Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
Original file line number Diff line number Diff line change
Expand Up @@ -204,14 +204,25 @@ class FPGrowthModel private[ml] (
@Since("2.2.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)

@transient private var _cachedMinConf: Double = Double.NaN

@transient private var _cachedRules: DataFrame = null

/**
* Get association rules fitted by AssociationRules using the minConfidence. Returns a dataframe
* with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and
* "consequent" are Array[T] and "confidence" is Double.
*/
@Since("2.2.0")
@transient lazy val associationRules: DataFrame = {
AssociationRules.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence))
@transient def associationRules: DataFrame = {
if ($(minConfidence) == _cachedMinConf) {
_cachedRules
} else {
_cachedRules = AssociationRules
.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence))
_cachedMinConf = $(minConfidence)
_cachedRules
}
}

/**
Expand Down
53 changes: 36 additions & 17 deletions mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,36 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty)
}

test("FPGrowth prediction should not contain duplicates") {
// This should generate rule 1 -> 3, 2 -> 3
val dataset = spark.createDataFrame(Seq(
Array("1", "3"),
Array("2", "3")
).map(Tuple1(_))).toDF("features")
val model = new FPGrowth().fit(dataset)

val prediction = model.transform(
spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features")
).first().getAs[Seq[String]]("prediction")

assert(prediction === Seq("3"))
}

test("FPGrowthModel setMinConfidence should affect rules generation and transform") {
val model = new FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset)
val oldRulesNum = model.associationRules.count()
assert(oldRulesNum == model.associationRules.count())
val oldPredict = model.transform(dataset)

model.setMinConfidence(0.1)
assert(oldRulesNum === model.associationRules.count())
assert(model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet))

model.setMinConfidence(0.8765)
assert(oldRulesNum > model.associationRules.count())
assert(!model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet))
}

test("FPGrowth parameter check") {
val fpGrowth = new FPGrowth().setMinSupport(0.4567)
val model = fpGrowth.fit(dataset)
Expand All @@ -95,28 +125,17 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul

test("read/write") {
def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = {
assert(model.freqItemsets.sort("items").collect() ===
model2.freqItemsets.sort("items").collect())
assert(model.freqItemsets.collect().toSet.equals(
model2.freqItemsets.collect().toSet))
assert(model.associationRules.collect().toSet.equals(

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to add these 2 since they are values computed from the model data. Checking freqItemsets is sufficient.

@hhbyyh hhbyyh Apr 4, 2017

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for comment. I added the check since we added some internal cache fields, I'd like to ensure it does not interfere with the model loading. Let me know it is still redundant.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough. Let's keep it

model2.associationRules.collect().toSet))
assert(model.setMinConfidence(0.9).associationRules.collect().toSet.equals(
model2.setMinConfidence(0.9).associationRules.collect().toSet))
}
val fPGrowth = new FPGrowth()
testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
FPGrowthSuite.allParamSettings, checkModelData)
}

test("FPGrowth prediction should not contain duplicates") {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the future, I'd prefer not to move stuff around unless it's necessary since it makes the diff larger. No need to revert this, though, since I already checked it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see.

// This should generate rule 1 -> 3, 2 -> 3
val dataset = spark.createDataFrame(Seq(
Array("1", "3"),
Array("2", "3")
).map(Tuple1(_))).toDF("features")
val model = new FPGrowth().fit(dataset)

val prediction = model.transform(
spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features")
).first().getAs[Seq[String]]("prediction")

assert(prediction === Seq("3"))
}

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't change this one, just move it to keep parameter and save/load check at the bottom.

}

object FPGrowthSuite {
Expand Down