[SPARK-16473][MLLIB] Fix BisectingKMeans Algorithm failing in edge case #16355
Changes from 4 commits
**BisectingKMeansSuite.scala**

```diff
@@ -29,9 +29,12 @@ class BisectingKMeansSuite
   final val k = 5
   @transient var dataset: Dataset[_] = _
+  @transient var sparseDataset: Dataset[_] = _
+

   override def beforeAll(): Unit = {
     super.beforeAll()
     dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
+    sparseDataset = KMeansSuite.generateSparseData(spark, 100, 1000, k, 42)
   }

   test("default parameters") {
@@ -51,6 +54,23 @@ class BisectingKMeansSuite
     assert(copiedModel.hasSummary)
   }

+  test("SPARK-16473: Verify Bisecting K-Means does not fail in edge case where " +
+    "one cluster is empty after split") {
+    val bkm = new BisectingKMeans().setK(k).setMinDivisibleClusterSize(4).setMaxIter(4)
+
+    assert(bkm.getK === k)
+    assert(bkm.getFeaturesCol === "features")
+    assert(bkm.getPredictionCol === "prediction")
+    assert(bkm.getMaxIter === 4)
+    assert(bkm.getMinDivisibleClusterSize === 4)
+    // Verify fit does not fail on very sparse data
```
**Member** (review comment on the line above):
> It's not clear to me that this unit test actually tests the issue fixed in this PR. Is there a good way to see why it would? If not, then it would be great to write a tiny dataset by hand which would trigger the failure.

**Contributor (Author)**:
> I added this check to verify:
```diff
+    val model = bkm.fit(sparseDataset)
+    assert(model.hasSummary)
+    val result = model.transform(sparseDataset)
+    val numClusters = result.select("prediction").distinct().collect().length
+    assert(numClusters <= k && numClusters >= 1)
+  }
+
   test("setter/getter") {
     val bkm = new BisectingKMeans()
       .setK(9)
```
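For intuition about the failure mode this test targets (per the test name, one cluster becoming empty after a split): during a bisecting step, every point in a cluster can land nearer to one of the two candidate child centers, leaving the other child with no points. The following is a minimal, Spark-free sketch of that situation; all names here are illustrative and not from the Spark codebase.

```scala
// Sketch (not Spark code): a degenerate split where one child cluster
// ends up empty. Code that assumes both children of a split are
// non-empty would fail on inputs like this.
object EmptySplitDemo {
  // Plain Euclidean distance between two dense vectors.
  def dist(a: Array[Double], b: Array[Double]): Double =
    math.sqrt(a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum)

  def main(args: Array[String]): Unit = {
    // Highly sparse, near-identical points: both rows are all zeros.
    val points = Seq(Array(0.0, 0.0, 0.0), Array(0.0, 0.0, 0.0))
    // Two hypothetical child centers produced by a split.
    val left  = Array(0.0, 0.0, 0.0)
    val right = Array(1.0, 1.0, 1.0)
    // Assign each point to its nearest center.
    val (toLeft, toRight) =
      points.partition(p => dist(p, left) <= dist(p, right))
    // The right child receives no points at all.
    assert(toRight.isEmpty)
    assert(toLeft.size == points.size)
    println(s"left=${toLeft.size} right=${toRight.size}")
  }
}
```

This is why the assertion above only requires `1 <= numClusters <= k`: when splits degenerate, the model may legitimately end up with fewer than `k` clusters rather than failing.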
**KMeansSuite.scala**

```diff
@@ -17,6 +17,8 @@

 package org.apache.spark.ml.clustering

+import scala.util.Random
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
@@ -160,6 +162,17 @@ object KMeansSuite {
     spark.createDataFrame(rdd)
   }

+  def generateSparseData(spark: SparkSession, rows: Int, dim: Int, k: Int, seed: Int): DataFrame = {
+    val sc = spark.sparkContext
+    val random = new Random(seed)
+    val nnz = random.nextInt(dim)
+    val rdd = sc.parallelize(1 to rows)
+      .map(i => Vectors.sparse(dim, random.shuffle(0 to dim - 1).slice(0, nnz).sorted.toArray,
+        Array.fill(nnz)(random.nextDouble())))
+      .map(v => new TestRow(v))
+    spark.createDataFrame(rdd)
+  }

   /**
    * Mapping from all Params to valid settings which differ from the defaults.
    * This is useful for tests which need to exercise all Params, such as save/load.
```
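The index-sampling step in `generateSparseData` (shuffle the column indices, keep the first `nnz`, sort them) can be exercised without Spark. Below is a minimal plain-Scala sketch of just that step; the object and method names are our own, not from the Spark codebase.

```scala
import scala.util.Random

// Spark-free sketch of the sparse-index sampling used above: choose nnz
// distinct column indices out of dim by shuffling, taking a prefix, and
// sorting, as Vectors.sparse requires strictly increasing indices.
object SparseIndexSketch {
  def sampleIndices(dim: Int, random: Random): Array[Int] = {
    val nnz = random.nextInt(dim) // number of nonzeros; may legitimately be 0
    random.shuffle((0 until dim).toList).take(nnz).sorted.toArray
  }

  def main(args: Array[String]): Unit = {
    val dim = 1000
    val indices = sampleIndices(dim, new Random(42))
    // The sampled indices are distinct, sorted, and within [0, dim).
    assert(indices.distinct.length == indices.length)
    assert(indices.sameElements(indices.sorted))
    assert(indices.forall(i => i >= 0 && i < dim))
    println(s"sampled ${indices.length} of $dim indices")
  }
}
```

Note that `nnz` is drawn once per dataset, so every row has the same number of nonzeros, but the shuffle inside the row-level `map` gives each row its own index pattern.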
**Reviewer** (review comment on `generateSparseData`):
> Does the test really need to be this large? It takes ~1 sec, which is long-ish for a unit test.

**Author**:
> I can only repro the issue with very sparse data where the number of columns is around 1k. I was able to reduce the number of rows to only 10, however. I hope that is a small enough dataset.