38 changes: 8 additions & 30 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -36,7 +36,7 @@ import org.apache.spark.storage.StorageLevel
* :: Experimental ::
*
* Model trained by [[FPGrowth]], which holds frequent itemsets.
- * @param freqItemsets frequent itemsets, which is an RDD of [[FreqItemset]]
+ * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]]
* @tparam Item item type
*/
@Experimental
@@ -62,14 +62,13 @@ class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) ex
@Experimental
class FPGrowth private (
private var minSupport: Double,
- private var numPartitions: Int,
- private var ordered: Boolean) extends Logging with Serializable {
+ private var numPartitions: Int) extends Logging with Serializable {

/**
* Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same
- * as the input data, ordered: `false`}.
+ * as the input data}.
*/
- def this() = this(0.3, -1, false)
+ def this() = this(0.3, -1)

/**
* Sets the minimal support level (default: `0.3`).
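With the sequence flag reverted, minSupport and numPartitions are the only remaining knobs on the builder. A minimal usage sketch of the resulting API, assuming an existing SparkContext named sc and made-up transactions:

    import org.apache.spark.mllib.fpm.FPGrowth
    import org.apache.spark.rdd.RDD

    // Each element of the RDD is one transaction (basket) of items.
    val transactions: RDD[Array[String]] = sc.parallelize(Seq(
      Array("a", "b", "c"),
      Array("a", "b", "d")))

    val model = new FPGrowth()
      .setMinSupport(0.5)   // keep itemsets occurring in at least half the transactions
      .setNumPartitions(2)  // parallelism for distributing conditional transactions
      .run(transactions)

    model.freqItemsets.collect().foreach { is =>
      println(is.items.mkString("{", ",", "}") + ": " + is.freq)
    }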
@@ -87,15 +86,6 @@ class FPGrowth private (
this
}

- /**
- * Indicates whether to mine itemsets (unordered) or sequences (ordered) (default: false, mine
- * itemsets).
- */
- def setOrdered(ordered: Boolean): this.type = {
- this.ordered = ordered
- this
- }

/**
* Computes an FP-Growth model that contains frequent itemsets.
* @param data input data set, each element contains a transaction
@@ -165,7 +155,7 @@ class FPGrowth private (
.flatMap { case (part, tree) =>
tree.extract(minCount, x => partitioner.getPartition(x) == part)
}.map { case (ranks, count) =>
- new FreqItemset(ranks.map(i => freqItems(i)).reverse.toArray, count, ordered)
+ new FreqItemset(ranks.map(i => freqItems(i)).toArray, count)
}
}
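The extract step above returns each pattern as an array of ranks plus a count, and mapping every rank through freqItems recovers the concrete items. A tiny illustration with hypothetical values; since itemsets are unordered again, the element order of the result carries no meaning, which is why the .reverse call could be dropped:

    // freqItems holds the frequent items in descending frequency order,
    // so a rank is simply an index into this array.
    val freqItems = Array("z", "x", "y")        // ranks 0, 1, 2
    val (ranks, count) = (Array(2, 0), 3L)      // one (ranks, count) pair from the FP-tree
    val itemset = ranks.map(i => freqItems(i))  // Array("y", "z") with frequency 3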

@@ -181,12 +171,9 @@
itemToRank: Map[Item, Int],
partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
val output = mutable.Map.empty[Int, Array[Int]]
- // Filter the basket by frequent items pattern
+ // Filter the basket by frequent items pattern and sort their ranks.
val filtered = transaction.flatMap(itemToRank.get)
- if (!this.ordered) {
- ju.Arrays.sort(filtered)
- }
- // Generate conditional transactions
+ ju.Arrays.sort(filtered)
val n = filtered.length
var i = n - 1
while (i >= 0) {
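With the ordered flag gone, the surviving ranks of every transaction are sorted unconditionally before conditional transactions are generated. A standalone sketch of that filter-and-sort step, reusing the names from the diff on made-up data:

    // itemToRank maps each frequent item to its frequency rank; infrequent
    // items have no entry, so flatMap over Option drops them silently.
    val itemToRank = Map("z" -> 0, "x" -> 1, "y" -> 2)
    val transaction = Array("q", "y", "z")             // "q" is infrequent
    val filtered = transaction.flatMap(itemToRank.get) // Array(2, 0)
    java.util.Arrays.sort(filtered)                    // now Array(0, 2)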
@@ -211,18 +198,9 @@ object FPGrowth {
* Frequent itemset.
* @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead.
* @param freq frequency
- * @param ordered indicates if items represents an itemset (false) or sequence (true)
* @tparam Item item type
*/
- class FreqItemset[Item](val items: Array[Item], val freq: Long, val ordered: Boolean)
- extends Serializable {
-
- /**
- * Auxillary constructor, assumes unordered by default.
- */
- def this(items: Array[Item], freq: Long) {
- this(items, freq, false)
- }
+ class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable {

/**
* Returns items in a Java List.
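After the revert, a FreqItemset is just the pair of items and freq. A small sketch of constructing and reading one directly, with hypothetical values; Java callers would use javaItems rather than items:

    import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset

    val itemset = new FreqItemset[String](Array("a", "b"), 3L)
    println(itemset.items.mkString("{", ",", "}") + ": " + itemset.freq)  // {a,b}: 3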
mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {


test("FP-Growth frequent itemsets using String type") {
test("FP-Growth using String type") {
val transactions = Seq(
"r z h k p",
"z y x w v u t s",
@@ -38,14 +38,12 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
val model6 = fpg
.setMinSupport(0.9)
.setNumPartitions(1)
- .setOrdered(false)
.run(rdd)
assert(model6.freqItemsets.count() === 0)

val model3 = fpg
.setMinSupport(0.5)
.setNumPartitions(2)
- .setOrdered(false)
.run(rdd)
val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
(itemset.items.toSet, itemset.freq)
@@ -63,59 +61,17 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
val model2 = fpg
.setMinSupport(0.3)
.setNumPartitions(4)
- .setOrdered(false)
.run(rdd)
assert(model2.freqItemsets.count() === 54)

val model1 = fpg
.setMinSupport(0.1)
.setNumPartitions(8)
- .setOrdered(false)
.run(rdd)
assert(model1.freqItemsets.count() === 625)
}

test("FP-Growth frequent sequences using String type"){
val transactions = Seq(
"r z h k p",
"z y x w v u t s",
"s x o n r",
"x z y m t s q e",
"z",
"x z y r q t p")
.map(_.split(" "))
val rdd = sc.parallelize(transactions, 2).cache()

val fpg = new FPGrowth()

val model1 = fpg
.setMinSupport(0.5)
.setNumPartitions(2)
.setOrdered(true)
.run(rdd)

/*
Use the following R code to verify association rules using arulesSequences package.

data = read_baskets("path", info = c("sequenceID","eventID","SIZE"))
freqItemSeq = cspade(data, parameter = list(support = 0.5))
resSeq = as(freqItemSeq, "data.frame")
resSeq$support = resSeq$support * length(transactions)
names(resSeq)[names(resSeq) == "support"] = "freq"
resSeq
*/
val expected = Set(
(Seq("r"), 3L), (Seq("s"), 3L), (Seq("t"), 3L), (Seq("x"), 4L), (Seq("y"), 3L),
(Seq("z"), 5L), (Seq("z", "y"), 3L), (Seq("x", "t"), 3L), (Seq("y", "t"), 3L),
(Seq("z", "t"), 3L), (Seq("z", "y", "t"), 3L)
)
val freqItemseqs1 = model1.freqItemsets.collect().map { itemset =>
(itemset.items.toSeq, itemset.freq)
}.toSet
assert(freqItemseqs1 == expected)
}

test("FP-Growth frequent itemsets using Int type") {
test("FP-Growth using Int type") {
val transactions = Seq(
"1 2 3",
"1 2 3 4",
@@ -132,14 +88,12 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
val model6 = fpg
.setMinSupport(0.9)
.setNumPartitions(1)
- .setOrdered(false)
.run(rdd)
assert(model6.freqItemsets.count() === 0)

val model3 = fpg
.setMinSupport(0.5)
.setNumPartitions(2)
- .setOrdered(false)
.run(rdd)
assert(model3.freqItemsets.first().items.getClass === Array(1).getClass,
"frequent itemsets should use primitive arrays")
@@ -155,14 +109,12 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
val model2 = fpg
.setMinSupport(0.3)
.setNumPartitions(4)
- .setOrdered(false)
.run(rdd)
assert(model2.freqItemsets.count() === 15)

val model1 = fpg
.setMinSupport(0.1)
.setNumPartitions(8)
- .setOrdered(false)
.run(rdd)
assert(model1.freqItemsets.count() === 65)
}
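The expected counts asserted in these tests (15, 65, 54, 625) can be cross-checked with a brute-force enumerator on inputs this small. A sketch of such a helper; it is not part of the patch, and it is exponential in transaction size, so it is only viable for toy data:

    // Enumerate every non-empty subset of every transaction, then keep the
    // subsets contained in at least ceil(minSupport * n) transactions.
    def bruteForceFreqItemsets(
        transactions: Seq[Set[String]],
        minSupport: Double): Map[Set[String], Long] = {
      val minCount = math.ceil(minSupport * transactions.size).toLong
      transactions.flatMap(_.subsets().filter(_.nonEmpty))
        .distinct
        .map(s => s -> transactions.count(t => s.subsetOf(t)).toLong)
        .filter(_._2 >= minCount)
        .toMap
    }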
4 changes: 2 additions & 2 deletions python/pyspark/mllib/fpm.py
@@ -39,8 +39,8 @@ class FPGrowthModel(JavaModelWrapper):
>>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]]
>>> rdd = sc.parallelize(data, 2)
>>> model = FPGrowth.train(rdd, 0.6, 2)
- >>> sorted(model.freqItemsets().collect(), key=lambda x: x.items)
- [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'a', u'c'], freq=3), ...
+ >>> sorted(model.freqItemsets().collect())
+ [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ...
"""

def freqItemsets(self):