|
| 1 | +# |
| 2 | +# Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | +# contributor license agreements. See the NOTICE file distributed with |
| 4 | +# this work for additional information regarding copyright ownership. |
| 5 | +# The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | +# (the "License"); you may not use this file except in compliance with |
| 7 | +# the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +# |
| 17 | + |
| 18 | +from pyspark import SparkContext |
| 19 | +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc |
| 20 | + |
# Names exported as this module's public API.
__all__ = ['FPGrowth', 'FPGrowthModel']
| 22 | + |
| 23 | + |
@inherit_doc
class FPGrowthModel(JavaModelWrapper):
    """
    Model of frequent itemsets mined with the Parallel FP-Growth algorithm.

    Instances wrap the Java-side model and are created by ``FPGrowth.train``.

    >>> r1 = ["r","z","h","k","p"]
    >>> r2 = ["z","y","x","w","v","u","t","s"]
    >>> r3 = ["s","x","o","n","r"]
    >>> r4 = ["x","z","y","m","t","s","q","e"]
    >>> r5 = ["z"]
    >>> r6 = ["x","z","y","r","q","t","p"]
    >>> rdd = sc.parallelize([r1,r2,r3,r4,r5,r6], 2)
    >>> model = FPGrowth.train(rdd, 0.5, 2)
    >>> result = model.freqItemsets().collect()
    >>> expected = [([u"s"], 3), ([u"z"], 5), ([u"x"], 4), ([u"t"], 3), ([u"y"], 3), ([u"r"],3),
    ...     ([u"x", u"z"], 3), ([u"y", u"t"], 3), ([u"t", u"x"], 3), ([u"s",u"x"], 3),
    ...     ([u"y", u"x"], 3), ([u"y", u"z"], 3), ([u"t", u"z"], 3), ([u"y", u"x", u"z"], 3),
    ...     ([u"t", u"x", u"z"], 3), ([u"y", u"t", u"z"], 3), ([u"y", u"t", u"x"], 3),
    ...     ([u"y", u"t", u"x", u"z"], 3)]
    >>> diff1 = [x for x in result if x not in expected]
    >>> len(diff1)
    0
    >>> diff2 = [x for x in expected if x not in result]
    >>> len(diff2)
    0
    """

    def freqItemsets(self):
        """
        Return the frequent itemsets held by the wrapped Java model.

        The value is whatever the Java-side ``getFreqItemsets`` call yields;
        the class doctest consumes it via ``.collect()`` as pairs of
        (list of items, frequency).
        """
        itemsets = self.call("getFreqItemsets")
        return itemsets
| 52 | + |
| 53 | + |
class FPGrowth(object):
    """
    Trainer for :class:`FPGrowthModel`; use the :meth:`train` classmethod.
    """

    @classmethod
    def train(cls, data, minSupport=0.3, numPartition=-1):
        """
        Build an :class:`FPGrowthModel` via the ``trainFPGrowthModel`` MLlib call.

        :param data: Transactions to mine, forwarded unchanged to the MLlib
                     function (presumably an RDD of item lists -- confirm
                     against the Java/Scala side).
        :param minSupport: Coerced with ``float()`` before the call
                           (default: 0.3). Presumably the minimum support
                           fraction -- confirm.
        :param numPartition: Coerced with ``int()`` before the call
                             (default: -1). NOTE(review): the meaning of the
                             -1 default is decided by the Java/Scala side.
        :return: An :class:`FPGrowthModel` wrapping the result of the call.
        """
        # Coerce in the same order the arguments are passed so that a bad
        # minSupport is reported before a bad numPartition.
        support = float(minSupport)
        partitions = int(numPartition)
        java_model = callMLlibFunc("trainFPGrowthModel", data, support, partitions)
        return FPGrowthModel(java_model)
| 60 | + |
| 61 | + |
def _test():
    """
    Run the doctests with a local SparkContext bound to the name ``sc``.

    The doctest globals are a copy of ``pyspark.mllib.fpm``'s namespace plus
    ``sc``. The context is always stopped, even if the doctest run itself
    raises. Exits the interpreter with status -1 when any doctest fails.
    """
    import doctest
    import sys
    import pyspark.mllib.fpm
    globs = pyspark.mllib.fpm.__dict__.copy()
    sc = SparkContext('local[4]', 'PythonTest')
    globs['sc'] = sc
    try:
        failure_count, _ = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Release the context regardless of how the doctest run ended.
        sc.stop()
    if failure_count:
        # sys.exit rather than the bare ``exit`` builtin: the latter is
        # injected by the ``site`` module and is not guaranteed to exist.
        sys.exit(-1)
| 71 | + |
| 72 | + |
# Run the doctests only when this file is executed directly as a script.
if __name__ == "__main__":
    _test()
0 commit comments