-
Notifications
You must be signed in to change notification settings - Fork 29.2k
[SPARK-20003] [ML] FPGrowthModel setMinConfidence should affect rules generation and transform #17336
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-20003] [ML] FPGrowthModel setMinConfidence should affect rules generation and transform #17336
Changes from all commits
3398d62
f761ffd
9c046c3
d81fb2f
a95a07a
5ef84f1
81bce96
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,7 @@ | |
| package org.apache.spark.ml.fpm | ||
|
|
||
| import org.apache.spark.SparkFunSuite | ||
| import org.apache.spark.ml.util.DefaultReadWriteTest | ||
| import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} | ||
| import org.apache.spark.mllib.util.MLlibTestSparkContext | ||
| import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} | ||
| import org.apache.spark.sql.functions._ | ||
|
|
@@ -85,38 +85,58 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul | |
| assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty) | ||
| } | ||
|
|
||
| test("FPGrowth prediction should not contain duplicates") { | ||
| // This should generate rule 1 -> 3, 2 -> 3 | ||
| val dataset = spark.createDataFrame(Seq( | ||
| Array("1", "3"), | ||
| Array("2", "3") | ||
| ).map(Tuple1(_))).toDF("items") | ||
| val model = new FPGrowth().fit(dataset) | ||
|
|
||
| val prediction = model.transform( | ||
| spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items") | ||
| ).first().getAs[Seq[String]]("prediction") | ||
|
|
||
| assert(prediction === Seq("3")) | ||
| } | ||
|
|
||
| test("FPGrowthModel setMinConfidence should affect rules generation and transform") { | ||
| val model = new FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset) | ||
| val oldRulesNum = model.associationRules.count() | ||
| val oldPredict = model.transform(dataset) | ||
|
|
||
| model.setMinConfidence(0.8765) | ||
| assert(oldRulesNum > model.associationRules.count()) | ||
| assert(!model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet)) | ||
|
|
||
| // association rules should stay the same for same minConfidence | ||
| model.setMinConfidence(0.1) | ||
| assert(oldRulesNum === model.associationRules.count()) | ||
| assert(model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet)) | ||
| } | ||
|
|
||
| test("FPGrowth parameter check") { | ||
| val fpGrowth = new FPGrowth().setMinSupport(0.4567) | ||
| val model = fpGrowth.fit(dataset) | ||
| .setMinConfidence(0.5678) | ||
| assert(fpGrowth.getMinSupport === 0.4567) | ||
| assert(model.getMinConfidence === 0.5678) | ||
| MLTestingUtils.checkCopy(model) | ||
| } | ||
|
|
||
| test("read/write") { | ||
| def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = { | ||
| assert(model.freqItemsets.sort("items").collect() === | ||
| model2.freqItemsets.sort("items").collect()) | ||
| assert(model.freqItemsets.collect().toSet.equals( | ||
| model2.freqItemsets.collect().toSet)) | ||
| assert(model.associationRules.collect().toSet.equals( | ||
| model2.associationRules.collect().toSet)) | ||
| assert(model.setMinConfidence(0.9).associationRules.collect().toSet.equals( | ||
| model2.setMinConfidence(0.9).associationRules.collect().toSet)) | ||
| } | ||
| val fPGrowth = new FPGrowth() | ||
| testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings, | ||
| FPGrowthSuite.allParamSettings, checkModelData) | ||
| } | ||
|
|
||
| test("FPGrowth prediction should not contain duplicates") { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the future, I'd prefer not to move stuff around unless it's necessary since it makes the diff larger. No need to revert this, though, since I already checked it.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. |
||
| // This should generate rule 1 -> 3, 2 -> 3 | ||
| val dataset = spark.createDataFrame(Seq( | ||
| Array("1", "3"), | ||
| Array("2", "3") | ||
| ).map(Tuple1(_))).toDF("items") | ||
| val model = new FPGrowth().fit(dataset) | ||
|
|
||
| val prediction = model.transform( | ||
| spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items") | ||
| ).first().getAs[Seq[String]]("prediction") | ||
|
|
||
| assert(prediction === Seq("3")) | ||
| } | ||
| } | ||
|
|
||
| object FPGrowthSuite { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to add these 2 since they are values computed from the model data. Checking freqItemsets is sufficient.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for comment. I added the check since we added some internal cache fields, I'd like to ensure it does not interfere with the model loading. Let me know it is still redundant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fair enough. Let's keep it