Skip to content

Commit a0411ae

Browse files
yanboliangmengxr
authored andcommitted
[SPARK-6264] [MLLIB] Support FPGrowth algorithm in Python API
Support FPGrowth algorithm in Python API. Should we remove "Experimental" which were marked for FPGrowth and FPGrowthModel in Scala? jkbradley Author: Yanbo Liang <[email protected]> Closes #5213 from yanboliang/spark-6264 and squashes the following commits: ed62ead [Yanbo Liang] trigger jenkins 8ce0359 [Yanbo Liang] fix docstring style 544c725 [Yanbo Liang] address comments a2d7cf7 [Yanbo Liang] add doc for FPGrowth.train() dcf7d73 [Yanbo Liang] add python doc b18fd07 [Yanbo Liang] trigger jenkins 2c951b8 [Yanbo Liang] fix typos 7f62c8f [Yanbo Liang] add fpm to __init__.py b96206a [Yanbo Liang] Support FPGrowth algorithm in Python API
1 parent 7d92db3 commit a0411ae

File tree

6 files changed

+143
-4
lines changed

6 files changed

+143
-4
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.api.python
19+
20+
import org.apache.spark.api.java.JavaRDD
21+
import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
22+
import org.apache.spark.rdd.RDD
23+
24+
/**
25+
* A Wrapper of FPGrowthModel to provide helper method for Python
26+
*/
27+
private[python] class FPGrowthModelWrapper(model: FPGrowthModel[Any])
28+
extends FPGrowthModel(model.freqItemsets) {
29+
30+
def getFreqItemsets: RDD[Array[Any]] = {
31+
SerDe.fromTuple2RDD(model.freqItemsets.map(x => (x.javaItems, x.freq)))
32+
}
33+
}

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import org.apache.spark.api.python.SerDeUtil
3434
import org.apache.spark.mllib.classification._
3535
import org.apache.spark.mllib.clustering._
3636
import org.apache.spark.mllib.feature._
37+
import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
3738
import org.apache.spark.mllib.linalg._
3839
import org.apache.spark.mllib.optimization._
3940
import org.apache.spark.mllib.random.{RandomRDDs => RG}
@@ -358,9 +359,7 @@ private[python] class PythonMLLibAPI extends Serializable {
358359
val model = new GaussianMixtureModel(weight, gaussians)
359360
model.predictSoft(data)
360361
}
361-
362-
363-
362+
364363
/**
365364
* Java stub for Python mllib ALS.train(). This stub returns a handle
366365
* to the Java object instead of the content of the Java object. Extra care
@@ -420,6 +419,24 @@ private[python] class PythonMLLibAPI extends Serializable {
420419
new MatrixFactorizationModelWrapper(model)
421420
}
422421

422+
/**
423+
* Java stub for Python mllib FPGrowth.train(). This stub returns a handle
424+
* to the Java object instead of the content of the Java object. Extra care
425+
* needs to be taken in the Python code to ensure it gets freed on exit; see
426+
* the Py4J documentation.
427+
*/
428+
def trainFPGrowthModel(
429+
data: JavaRDD[java.lang.Iterable[Any]],
430+
minSupport: Double,
431+
numPartitions: Int): FPGrowthModel[Any] = {
432+
val fpg = new FPGrowth()
433+
.setMinSupport(minSupport)
434+
.setNumPartitions(numPartitions)
435+
436+
val model = fpg.run(data.rdd.map(_.asScala.toArray))
437+
new FPGrowthModelWrapper(model)
438+
}
439+
423440
/**
424441
* Java stub for Normalizer.transform()
425442
*/

python/docs/pyspark.mllib.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ pyspark.mllib.feature module
3131
:undoc-members:
3232
:show-inheritance:
3333

34+
pyspark.mllib.fpm module
35+
------------------------
36+
37+
.. automodule:: pyspark.mllib.fpm
38+
:members:
39+
:undoc-members:
40+
3441
pyspark.mllib.linalg module
3542
---------------------------
3643

python/pyspark/mllib/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
if numpy.version.version < '1.4':
2626
raise Exception("MLlib requires NumPy 1.4+")
2727

28-
__all__ = ['classification', 'clustering', 'feature', 'linalg', 'random',
28+
__all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random',
2929
'recommendation', 'regression', 'stat', 'tree', 'util']
3030

3131
import sys

python/pyspark/mllib/fpm.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from pyspark import SparkContext
19+
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
20+
21+
__all__ = ['FPGrowth', 'FPGrowthModel']
22+
23+
24+
@inherit_doc
25+
class FPGrowthModel(JavaModelWrapper):
26+
27+
"""
28+
.. note:: Experimental
29+
30+
A FP-Growth model for mining frequent itemsets
31+
using the Parallel FP-Growth algorithm.
32+
33+
>>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]]
34+
>>> rdd = sc.parallelize(data, 2)
35+
>>> model = FPGrowth.train(rdd, 0.6, 2)
36+
>>> sorted(model.freqItemsets().collect())
37+
[([u'a'], 4), ([u'c'], 3), ([u'c', u'a'], 3)]
38+
"""
39+
40+
def freqItemsets(self):
41+
"""
42+
Get the frequent itemsets of this model
43+
"""
44+
return self.call("getFreqItemsets")
45+
46+
47+
class FPGrowth(object):
48+
"""
49+
.. note:: Experimental
50+
51+
A Parallel FP-growth algorithm to mine frequent itemsets.
52+
"""
53+
54+
@classmethod
55+
def train(cls, data, minSupport=0.3, numPartitions=-1):
56+
"""
57+
Computes an FP-Growth model that contains frequent itemsets.
58+
:param data: The input data set, each element
59+
contains a transaction.
60+
:param minSupport: The minimal support level
61+
(default: `0.3`).
62+
:param numPartitions: The number of partitions used by parallel
63+
FP-growth (default: same as input data).
64+
"""
65+
model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions))
66+
return FPGrowthModel(model)
67+
68+
69+
def _test():
70+
import doctest
71+
import pyspark.mllib.fpm
72+
globs = pyspark.mllib.fpm.__dict__.copy()
73+
globs['sc'] = SparkContext('local[4]', 'PythonTest')
74+
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
75+
globs['sc'].stop()
76+
if failure_count:
77+
exit(-1)
78+
79+
80+
if __name__ == "__main__":
81+
_test()

python/run-tests

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ function run_mllib_tests() {
7777
run_test "pyspark/mllib/clustering.py"
7878
run_test "pyspark/mllib/evaluation.py"
7979
run_test "pyspark/mllib/feature.py"
80+
run_test "pyspark/mllib/fpm.py"
8081
run_test "pyspark/mllib/linalg.py"
8182
run_test "pyspark/mllib/rand.py"
8283
run_test "pyspark/mllib/recommendation.py"

0 commit comments

Comments
 (0)