33 changes: 33 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala
@@ -0,0 +1,33 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.api.python

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
import org.apache.spark.rdd.RDD

/**
 * A wrapper of FPGrowthModel to provide a helper method for Python.
 */
private[python] class FPGrowthModelWrapper(model: FPGrowthModel[Any])
  extends FPGrowthModel(model.freqItemsets) {

  def getFreqItemsets: RDD[Array[Any]] = {
    SerDe.fromTuple2RDD(model.freqItemsets.map(x => (x.javaItems, x.freq)))
  }
}
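For context on the shape this produces: SerDe.fromTuple2RDD pickles each frequent itemset as a plain 2-tuple, so on the Python side each record arrives as (items, freq). A minimal illustration in plain Python, with values borrowed from the doctest in fpm.py further down (no Spark required):

# Records deserialized from getFreqItemsets are (items, freq) pairs.
records = [([u'a'], 4), ([u'c'], 3), ([u'c', u'a'], 3)]
for items, freq in records:
    print("%s: %d" % (items, freq))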
mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -34,6 +34,7 @@ import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.feature._
import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.random.{RandomRDDs => RG}
@@ -344,9 +345,7 @@ private[python] class PythonMLLibAPI extends Serializable {
    val model = new GaussianMixtureModel(weight, gaussians)
    model.predictSoft(data)
  }




  /**
   * Java stub for Python mllib ALS.train(). This stub returns a handle
   * to the Java object instead of the content of the Java object. Extra care
@@ -406,6 +405,24 @@ private[python] class PythonMLLibAPI extends Serializable {
new MatrixFactorizationModelWrapper(model)
}

  /**
   * Java stub for Python mllib FPGrowth.train(). This stub returns a handle
   * to the Java object instead of the content of the Java object. Extra care
   * needs to be taken in the Python code to ensure it gets freed on exit; see
   * the Py4J documentation.
   */
  def trainFPGrowthModel(
      data: JavaRDD[java.lang.Iterable[Any]],
      minSupport: Double,
      numPartitions: Int): FPGrowthModel[Any] = {
    val fpg = new FPGrowth()
      .setMinSupport(minSupport)
      .setNumPartitions(numPartitions)

    val model = fpg.run(data.rdd.map(_.asScala.toArray))
    new FPGrowthModelWrapper(model)
  }
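Background on the "freed on exit" note above: on the Python side, pyspark's JavaModelWrapper (in pyspark.mllib.common) owns the Py4J handle and detaches it when the Python object is garbage-collected. Roughly, as a sketch from memory rather than a quote of that file:

from pyspark import SparkContext
from pyspark.mllib.common import callJavaFunc

class JavaModelWrapper(object):
    """Holds a handle to a model object that lives on the JVM side."""
    def __init__(self, java_model):
        self._sc = SparkContext._active_spark_context
        self._java_model = java_model

    def __del__(self):
        # Detach the JVM object so Py4J can release it once Python is done.
        self._sc._gateway.detach(self._java_model)

    def call(self, name, *a):
        """Invoke a method on the underlying JVM model."""
        return callJavaFunc(self._sc, getattr(self._java_model, name), *a)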

  /**
   * Java stub for Normalizer.transform()
   */
7 changes: 7 additions & 0 deletions python/docs/pyspark.mllib.rst
@@ -31,6 +31,13 @@ pyspark.mllib.feature module
    :undoc-members:
    :show-inheritance:

pyspark.mllib.fpm module
------------------------

.. automodule:: pyspark.mllib.fpm
    :members:
    :undoc-members:

pyspark.mllib.linalg module
---------------------------

2 changes: 1 addition & 1 deletion python/pyspark/mllib/__init__.py
@@ -25,7 +25,7 @@
if numpy.version.version < '1.4':
    raise Exception("MLlib requires NumPy 1.4+")

__all__ = ['classification', 'clustering', 'feature', 'linalg', 'random',
__all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random',
'recommendation', 'regression', 'stat', 'tree', 'util']

import sys
67 changes: 67 additions & 0 deletions python/pyspark/mllib/fpm.py
@@ -0,0 +1,67 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark import SparkContext
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc

__all__ = ['FPGrowth', 'FPGrowthModel']


@inherit_doc
class FPGrowthModel(JavaModelWrapper):

    """An FP-Growth model for mining frequent itemsets using the Parallel FP-Growth algorithm.
Contributor: In Python docs we limit the line width to 72 (following PEP 8); this doesn't include code examples in the doc. Please update the docstrings in your PR.

    >>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]]
    >>> rdd = sc.parallelize(data, 2)
    >>> model = FPGrowth.train(rdd, 0.6, 2)
    >>> result = model.freqItemsets().collect()
Contributor: Use

    >>> sorted(model.freqItemsets().collect())

and put the results as expected output to verify.
Contributor: Maybe we should increase the threshold to make the expected output shorter.

Contributor: Remove this line.

    >>> sorted(model.freqItemsets().collect())
    [([u'a'], 4), ([u'c'], 3), ([u'c', u'a'], 3)]
    """
    def freqItemsets(self):
Contributor: An empty line before this line and a doc string are needed. It might be convenient to follow the Java/Scala implementation and use a namedtuple to wrap the result, so users can call items and freq instead of [0] and [1] (see the sketch after this method).

Contributor: Add a blank line before def .. and add a doc string to this function.

        return self.call("getFreqItemsets")
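A minimal sketch of the namedtuple idea from the comment above; the FreqItemset name and the mapping step are hypothetical, not part of this PR:

from collections import namedtuple

# Hypothetical record type so callers can write fi.items / fi.freq
# instead of indexing with [0] / [1].
FreqItemset = namedtuple("FreqItemset", ["items", "freq"])

def freqItemsets(self):
    """Return an RDD of FreqItemset(items, freq) records."""
    return self.call("getFreqItemsets").map(lambda r: FreqItemset(r[0], r[1]))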


class FPGrowth(object):
Contributor: Add a doc string.


    @classmethod
    def train(cls, data, minSupport=0.3, numPartitions=-1):
        """
        Computes an FP-Growth model that contains frequent itemsets.
        :param data: The input data set, each element contains a transaction.
Contributor: This line is too wide.

        :param minSupport: The minimal support level (default: `0.3`).
        :param numPartitions: The number of partitions used by parallel FP-growth
            (default: same as input data).
        """
        model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions))
        return FPGrowthModel(model)
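To make the support arithmetic concrete: with minSupport=0.6 over the doctest's 4 transactions, an itemset must occur in at least ceil(0.6 * 4) = 3 of them, which is why only a, c, and {a, c} survive. A brute-force check in plain Python (no Spark needed; data copied from the doctest):

from collections import Counter
from itertools import combinations
from math import ceil

data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]]
min_count = int(ceil(0.6 * len(data)))  # = 3

# Count every itemset that appears in each transaction.
counts = Counter()
for txn in data:
    items = sorted(set(txn))
    for k in range(1, len(items) + 1):
        for combo in combinations(items, k):
            counts[combo] += 1

frequent = sorted((list(c), n) for c, n in counts.items() if n >= min_count)
print(frequent)  # [(['a'], 4), (['a', 'c'], 3), (['c'], 3)]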


def _test():
    import doctest
    import pyspark.mllib.fpm
    globs = pyspark.mllib.fpm.__dict__.copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest')
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()
1 change: 1 addition & 0 deletions python/run-tests
@@ -77,6 +77,7 @@ function run_mllib_tests() {
    run_test "pyspark/mllib/clustering.py"
    run_test "pyspark/mllib/evaluation.py"
    run_test "pyspark/mllib/feature.py"
    run_test "pyspark/mllib/fpm.py"
    run_test "pyspark/mllib/linalg.py"
    run_test "pyspark/mllib/rand.py"
    run_test "pyspark/mllib/recommendation.py"