Commit b96206a

Support FPGrowth algorithm in Python API
1 parent 5677557 commit b96206a

File tree

3 files changed: +103, -0 lines

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 28 additions & 0 deletions
@@ -34,6 +34,7 @@ import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.mllib.classification._
 import org.apache.spark.mllib.clustering._
 import org.apache.spark.mllib.feature._
+import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.random.{RandomRDDs => RG}
@@ -406,6 +407,33 @@ private[python] class PythonMLLibAPI extends Serializable {
     new MatrixFactorizationModelWrapper(model)
   }
 
+  /**
+   * A wrapper of FPGrowthModel to provide a helper method for Python.
+   */
+  private[python] class FPGrowthModelWrapper(model: FPGrowthModel[Any])
+    extends FPGrowthModel(model.freqItemsets) {
+    def getFreqItemsets: RDD[Array[Any]] = {
+      SerDe.fromTuple2RDD(model.freqItemsets.map(x => (x.javaItems, x.freq)))
+    }
+  }
+
+  /**
+   * Java stub for Python mllib FPGrowth.train(). This stub returns a handle
+   * to the Java object instead of the content of the Java object. Extra care
+   * needs to be taken in the Python code to ensure it gets freed on exit; see
+   * the Py4J documentation.
+   */
+  def trainFPGrowthModel(data: JavaRDD[java.lang.Iterable[Any]],
+      minSupport: Double,
+      numPartition: Int): FPGrowthModel[Any] = {
+    val fpm = new FPGrowth()
+      .setMinSupport(minSupport)
+      .setNumPartitions(numPartition)
+
+    val model = fpm.run(data.rdd.map(_.asScala.toArray))
+    new FPGrowthModelWrapper(model)
+  }
+
   /**
    * Java stub for Normalizer.transform()
    */
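For context on how this stub is reached: pyspark.mllib.common.callMLlibFunc serializes its arguments over Py4J, invokes the named method on the JVM-side PythonMLLibAPI instance, and returns a Py4J handle to the resulting FPGrowthModelWrapper rather than its contents. A minimal sketch of that round trip, assuming a running SparkContext named sc (the basket data is illustrative):

    from pyspark.mllib.common import callMLlibFunc

    # Each pickled Python list arrives on the JVM side as a
    # java.lang.Iterable[Any] after SerDe conversion.
    baskets = sc.parallelize([["a", "b"], ["a", "c"], ["a"]])

    # Returns a handle to the Java model, not its contents; FPGrowthModel
    # (next file) wraps the handle so Py4J frees it on garbage collection.
    handle = callMLlibFunc("trainFPGrowthModel", baskets, 0.5, 2)

In practice the FPGrowth.train wrapper in fpm.py below is the intended entry point; calling the stub directly is shown only to make the bridge concrete.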

python/pyspark/mllib/fpm.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark import SparkContext
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
+
+__all__ = ['FPGrowth', 'FPGrowthModel']
+
+
+@inherit_doc
+class FPGrowthModel(JavaModelWrapper):
+
+    """An FP-Growth model for mining frequent itemsets using the Parallel FP-Growth algorithm.
+
+    >>> r1 = ["r", "z", "h", "k", "p"]
+    >>> r2 = ["z", "y", "x", "w", "v", "u", "t", "s"]
+    >>> r3 = ["s", "x", "o", "n", "r"]
+    >>> r4 = ["x", "z", "y", "m", "t", "s", "q", "e"]
+    >>> r5 = ["z"]
+    >>> r6 = ["x", "z", "y", "r", "q", "t", "p"]
+    >>> rdd = sc.parallelize([r1, r2, r3, r4, r5, r6], 2)
+    >>> model = FPGrowth.train(rdd, 0.5, 2)
+    >>> result = model.freqItemsets().collect()
+    >>> expected = [([u"s"], 3), ([u"z"], 5), ([u"x"], 4), ([u"t"], 3), ([u"y"], 3), ([u"r"], 3),
+    ...     ([u"x", u"z"], 3), ([u"y", u"t"], 3), ([u"t", u"x"], 3), ([u"s", u"x"], 3),
+    ...     ([u"y", u"x"], 3), ([u"y", u"z"], 3), ([u"t", u"z"], 3), ([u"y", u"x", u"z"], 3),
+    ...     ([u"t", u"x", u"z"], 3), ([u"y", u"t", u"z"], 3), ([u"y", u"t", u"x"], 3),
+    ...     ([u"y", u"t", u"x", u"z"], 3)]
+    >>> diff1 = [x for x in result if x not in expected]
+    >>> len(diff1)
+    0
+    >>> diff2 = [x for x in expected if x not in result]
+    >>> len(diff2)
+    0
+    """
+    def freqItemsets(self):
+        return self.call("getFreqItemsets")
+
+
+class FPGrowth(object):
+
+    @classmethod
+    def train(cls, data, minSupport=0.3, numPartition=-1):
+        model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartition))
+        return FPGrowthModel(model)
+
+
+def _test():
+    import doctest
+    import pyspark.mllib.fpm
+    globs = pyspark.mllib.fpm.__dict__.copy()
+    globs['sc'] = SparkContext('local[4]', 'PythonTest')
+    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
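Taken together, FPGrowth.train hands the RDD of baskets to the Scala stub above, and freqItemsets() returns an RDD of (items, count) pairs, where count is the number of input baskets containing that itemset. A short usage sketch, assuming a live SparkContext named sc (the grocery baskets are made up for illustration):

    from pyspark.mllib.fpm import FPGrowth

    # One record per transaction (a basket of items).
    transactions = sc.parallelize([
        ["bread", "milk"],
        ["bread", "butter", "milk"],
        ["butter", "milk"],
    ])

    # minSupport=0.6 keeps itemsets present in at least 60% of the baskets;
    # numPartition sets the parallelism of the FP-Growth run (the default -1
    # falls back to the input RDD's partition count on the Scala side).
    model = FPGrowth.train(transactions, minSupport=0.6, numPartition=2)

    for items, freq in model.freqItemsets().collect():
        print("%s appears in %d baskets" % (items, freq))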

python/run-tests

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ function run_mllib_tests() {
     run_test "pyspark/mllib/clustering.py"
     run_test "pyspark/mllib/evaluation.py"
     run_test "pyspark/mllib/feature.py"
+    run_test "pyspark/mllib/fpm.py"
     run_test "pyspark/mllib/linalg.py"
     run_test "pyspark/mllib/rand.py"
     run_test "pyspark/mllib/recommendation.py"
