Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,28 @@ class FPGrowthModel private[ml] (
def setPredictionCol(value: String): this.type = set(predictionCol, value)

/**
* Get association rules fitted by AssociationRules using the minConfidence. Returns a dataframe
* Cache minConfidence and associationRules to avoid redundant computation for association rules
* during transform. The associationRules will only be re-computed when minConfidence changed.
*/
@transient private var _cachedMinConf: Double = Double.NaN

@transient private var _cachedRules: DataFrame = _

/**
* Get association rules fitted using the minConfidence. Returns a dataframe
* with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and
* "consequent" are Array[T] and "confidence" is Double.
*/
@Since("2.2.0")
@transient lazy val associationRules: DataFrame = {
AssociationRules.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence))
@transient def associationRules: DataFrame = {
if ($(minConfidence) == _cachedMinConf) {
_cachedRules
} else {
_cachedRules = AssociationRules
.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence))
_cachedMinConf = $(minConfidence)
_cachedRules
}
}

/**
Expand Down
56 changes: 38 additions & 18 deletions mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package org.apache.spark.ml.fpm

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -85,38 +85,58 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty)
}

test("FPGrowth prediction should not contain duplicates") {
// This should generate rule 1 -> 3, 2 -> 3
val dataset = spark.createDataFrame(Seq(
Array("1", "3"),
Array("2", "3")
).map(Tuple1(_))).toDF("items")
val model = new FPGrowth().fit(dataset)

val prediction = model.transform(
spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items")
).first().getAs[Seq[String]]("prediction")

assert(prediction === Seq("3"))
}

test("FPGrowthModel setMinConfidence should affect rules generation and transform") {
val model = new FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset)
val oldRulesNum = model.associationRules.count()
val oldPredict = model.transform(dataset)

model.setMinConfidence(0.8765)
assert(oldRulesNum > model.associationRules.count())
assert(!model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet))

// association rules should stay the same for same minConfidence
model.setMinConfidence(0.1)
assert(oldRulesNum === model.associationRules.count())
assert(model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet))
}

test("FPGrowth parameter check") {
val fpGrowth = new FPGrowth().setMinSupport(0.4567)
val model = fpGrowth.fit(dataset)
.setMinConfidence(0.5678)
assert(fpGrowth.getMinSupport === 0.4567)
assert(model.getMinConfidence === 0.5678)
MLTestingUtils.checkCopy(model)
}

test("read/write") {
def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = {
assert(model.freqItemsets.sort("items").collect() ===
model2.freqItemsets.sort("items").collect())
assert(model.freqItemsets.collect().toSet.equals(
model2.freqItemsets.collect().toSet))
assert(model.associationRules.collect().toSet.equals(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to add these 2 since they are values computed from the model data. Checking freqItemsets is sufficient.

Copy link
Copy Markdown
Contributor Author

@hhbyyh hhbyyh Apr 4, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for comment. I added the check since we added some internal cache fields, I'd like to ensure it does not interfere with the model loading. Let me know it is still redundant.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough. Let's keep it

model2.associationRules.collect().toSet))
assert(model.setMinConfidence(0.9).associationRules.collect().toSet.equals(
model2.setMinConfidence(0.9).associationRules.collect().toSet))
}
val fPGrowth = new FPGrowth()
testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
FPGrowthSuite.allParamSettings, checkModelData)
}

test("FPGrowth prediction should not contain duplicates") {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the future, I'd prefer not to move stuff around unless it's necessary since it makes the diff larger. No need to revert this, though, since I already checked it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see.

// This should generate rule 1 -> 3, 2 -> 3
val dataset = spark.createDataFrame(Seq(
Array("1", "3"),
Array("2", "3")
).map(Tuple1(_))).toDF("items")
val model = new FPGrowth().fit(dataset)

val prediction = model.transform(
spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items")
).first().getAs[Seq[String]]("prediction")

assert(prediction === Seq("3"))
}
}

object FPGrowthSuite {
Expand Down