1717
1818package org .apache .spark .mllib .fpm
1919
20+ import java .lang .{Iterable => JavaIterable }
2021import java .{util => ju }
2122
23+ import scala .collection .JavaConverters ._
2224import scala .collection .mutable
25+ import scala .reflect .ClassTag
2326
24- import org .apache .spark .{ SparkException , HashPartitioner , Logging , Partitioner }
27+ import org .apache .spark .api . java . JavaRDD
2528import org .apache .spark .rdd .RDD
2629import org .apache .spark .storage .StorageLevel
30+ import org .apache .spark .{HashPartitioner , Logging , Partitioner , SparkException }
2731
/**
 * Model trained by [[FPGrowth]], holding the frequent itemsets found in the input data.
 *
 * @param freqItemsets frequent itemsets paired with their occurrence counts
 * @tparam Item item type
 */
class FPGrowthModel[Item](val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {

  /** Returns the frequent itemsets as a `JavaRDD`, for use from the Java API. */
  def javaFreqItemsets(): JavaRDD[(Array[Item], Long)] = freqItemsets.toJavaRDD()
}
2937
3038/**
3139 * This class implements Parallel FP-growth algorithm to do frequent pattern matching on input data.
@@ -69,32 +77,36 @@ class FPGrowth private (
6977 * @param data input data set, each element contains a transaction
7078 * @return an [[FPGrowthModel ]]
7179 */
72- def run (data : RDD [Array [ String ]] ): FPGrowthModel = {
80+ def run [ Item : ClassTag , Basket <: Iterable [ Item ]] (data : RDD [Basket ] ): FPGrowthModel [ Item ] = {
7381 if (data.getStorageLevel == StorageLevel .NONE ) {
7482 logWarning(" Input data is not cached." )
7583 }
7684 val count = data.count()
7785 val minCount = math.ceil(minSupport * count).toLong
7886 val numParts = if (numPartitions > 0 ) numPartitions else data.partitions.length
7987 val partitioner = new HashPartitioner (numParts)
80- val freqItems = genFreqItems(data, minCount, partitioner)
81- val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
88+ val freqItems = genFreqItems[ Item , Basket ] (data, minCount, partitioner)
89+ val freqItemsets = genFreqItemsets[ Item , Basket ] (data, minCount, freqItems, partitioner)
8290 new FPGrowthModel (freqItemsets)
8391 }
8492
93+ def run [Item : ClassTag , Basket <: JavaIterable [Item ]](data : JavaRDD [Basket ]): FPGrowthModel [Item ] = {
94+ this .run(data.rdd.map(_.asScala))
95+ }
96+
8597 /**
8698 * Generates frequent items by filtering the input data using minimal support level.
8799 * @param minCount minimum count for frequent itemsets
88100 * @param partitioner partitioner used to distribute items
89101 * @return array of frequent pattern ordered by their frequencies
90102 */
// NOTE(review): diff hunk — `genFreqItems` is generalized from RDD[Array[String]] to a
// generic Item/Basket signature; the method body continues past this view (truncated at
// the next hunk boundary), so only the uniqueness check and flatMap opening are visible.
91- private def genFreqItems (
92- data : RDD [Array [ String ] ],
103+ private def genFreqItems [ Item : ClassTag , Basket <: Iterable [ Item ]] (
104+ data : RDD [Basket ],
93105 minCount : Long ,
94- partitioner : Partitioner ): Array [String ] = {
106+ partitioner : Partitioner ): Array [Item ] = {
95107 data.flatMap { t =>
// Reject transactions containing duplicate items: a set collapses duplicates,
// so a size mismatch proves the basket was not duplicate-free.
96108 val uniq = t.toSet
// `.size` replaces `.length` because a generic Basket need not be an Array.
97- if (t.length != uniq.size) {
109+ if (t.size != uniq.size) {
98110 throw new SparkException (s " Items in a transaction must be unique but got ${t.toSeq}. " )
99111 }
100112 t
@@ -114,11 +126,11 @@ class FPGrowth private (
114126 * @param partitioner partitioner used to distribute transactions
115127 * @return an RDD of (frequent itemset, count)
116128 */
// NOTE(review): diff hunk — `genFreqItemsets` is generalized to the Item/Basket signature;
// the body continues past this view (truncated at the next hunk boundary). The visible part
// ranks each frequent item by frequency order, then emits per-partition conditional
// transactions via genCondTransactions.
117- private def genFreqItemsets (
118- data : RDD [Array [ String ] ],
129+ private def genFreqItemsets [ Item : ClassTag , Basket <: Iterable [ Item ]] (
130+ data : RDD [Basket ],
119131 minCount : Long ,
120- freqItems : Array [String ],
121- partitioner : Partitioner ): RDD [(Array [String ], Long )] = {
132+ freqItems : Array [Item ],
133+ partitioner : Partitioner ): RDD [(Array [Item ], Long )] = {
// Map each frequent item to its rank (index in the frequency-ordered array).
122134 val itemToRank = freqItems.zipWithIndex.toMap
123135 data.flatMap { transaction =>
124136 genCondTransactions(transaction, itemToRank, partitioner)
@@ -139,13 +151,13 @@ class FPGrowth private (
139151 * @param partitioner partitioner used to distribute transactions
140152 * @return a map of (target partition, conditional transaction)
141153 */
// NOTE(review): diff hunk — `genCondTransactions` now takes a generic Basket instead of
// Array[String]; the body continues beyond the last visible line, so only the filtering
// and sorting prologue is shown here.
142- private def genCondTransactions (
143- transaction : Array [ String ] ,
144- itemToRank : Map [String , Int ],
154+ private def genCondTransactions [ Item : ClassTag , Basket <: Iterable [ Item ]] (
155+ transaction : Basket ,
156+ itemToRank : Map [Item , Int ],
145157 partitioner : Partitioner ): mutable.Map [Int , Array [Int ]] = {
146158 val output = mutable.Map .empty[Int , Array [Int ]]
147159 // Filter the basket by frequent items pattern and sort their ranks.
// `.toArray` is required now that a generic Basket's flatMap no longer yields an Array,
// and ju.Arrays.sort below needs an array of primitive ranks.
148- val filtered = transaction.flatMap(itemToRank.get)
160+ val filtered = transaction.flatMap(itemToRank.get).toArray
149161 ju.Arrays .sort(filtered)
150162 val n = filtered.length
151163 var i = n - 1
0 commit comments