@@ -339,10 +339,15 @@ private object BisectingKMeans extends Serializable {
    assignments.map { case (index, v) =>
      if (divisibleIndices.contains(index)) {
        val children = Seq(leftChildIndex(index), rightChildIndex(index))
-        val selected = children.minBy { child =>
-          KMeans.fastSquaredDistance(newClusterCenters(child), v)
-        }
-        (selected, v)
+        val newClusterChildren = children.filter(newClusterCenters.contains(_))
+        if (newClusterChildren.nonEmpty) {
+          val selected = newClusterChildren.minBy { child =>
+            KMeans.fastSquaredDistance(newClusterCenters(child), v)
+          }
+          (selected, v)
+        } else {
+          (index, v)
+        }
      } else {
        (index, v)
      }
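This guard is the heart of the fix: a bisecting step can leave one child cluster empty, in which case newClusterCenters has no entry for that child's index and the old unfiltered minBy fails with a key-not-found lookup (the symptom reported in SPARK-16473). A minimal self-contained sketch of the pattern, REPL-pasteable; the toy centers map and sqDist helper stand in for newClusterCenters and KMeans.fastSquaredDistance:

    def sqDist(a: Array[Double], b: Array[Double]): Double =
      a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

    val point = Array(1.0, 1.0)
    val parent = 2
    // In this toy, node 2's children are 4 and 5; suppose child 5 came out
    // empty, so no center was ever computed for it.
    val centers = Map(4 -> Array(0.0, 0.0))
    val children = Seq(2 * parent, 2 * parent + 1)

    val newClusterChildren = children.filter(centers.contains)
    val assigned =
      if (newClusterChildren.nonEmpty) newClusterChildren.minBy(c => sqDist(centers(c), point))
      else parent // no surviving child: keep the point at the parent index
    // assigned == 4; without the filter, centers(5) would throw NoSuchElementException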
@@ -372,12 +377,12 @@ private object BisectingKMeans extends Serializable {
      internalIndex -= 1
      val leftIndex = leftChildIndex(rawIndex)
      val rightIndex = rightChildIndex(rawIndex)
-      val height = math.sqrt(Seq(leftIndex, rightIndex).map { childIndex =>
+      val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_))
+      val height = math.sqrt(indexes.map { childIndex =>
        KMeans.fastSquaredDistance(center, clusters(childIndex).center)
      }.max)
-      val left = buildSubTree(leftIndex)
-      val right = buildSubTree(rightIndex)
-      new ClusteringTreeNode(index, size, center, cost, height, Array(left, right))
+      val children = indexes.map(buildSubTree(_)).toArray
+      new ClusteringTreeNode(index, size, center, cost, height, children)
    } else {
      val index = leafIndex
      leafIndex += 1
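The tree-construction change follows the same principle: compute the node height and recurse only over child indexes that are actually present in clusters, so a node whose split left one child empty simply gets a single subtree. A toy sketch of that filter-then-recurse shape (illustrative names and types, not Spark's):

    case class Node(index: Int, children: Array[Node])

    def build(index: Int, existing: Set[Int]): Node = {
      val childIndexes = Seq(2 * index, 2 * index + 1).filter(existing.contains)
      Node(index, childIndexes.map(i => build(i, existing)).toArray)
    }

    // Node 2's right child (5) is missing, so node 2 ends up with one subtree
    // instead of crashing on a lookup for 5.
    val tree = build(1, Set(1, 2, 3, 4))
    // tree.children.map(_.index).toSeq == Seq(2, 3)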
@@ -29,9 +29,12 @@ class BisectingKMeansSuite
  final val k = 5
  @transient var dataset: Dataset[_] = _

+  @transient var sparseDataset: Dataset[_] = _
+
  override def beforeAll(): Unit = {
    super.beforeAll()
    dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
+    sparseDataset = KMeansSuite.generateSparseData(spark, 100, 1000, k, 42)
Member: Does the test really need to be this large? It takes ~1 sec, which is long-ish for a unit test.

Contributor Author: I can only repro the issue with very sparse data where the number of columns is around 1k. I was able to reduce the number of rows to only 10, however. I hope that is a small enough dataset.
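For intuition on why extreme sparsity hits the edge case (an illustration, not the actual Spark code path): with ~1k columns and only a handful of non-zeros, many points coincide or sit equidistant from both candidate child centers, so tie-breaking sends every point to one side and the other child ends up empty. In miniature:

    def sqDist(a: Array[Double], b: Array[Double]): Double =
      a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

    val points = Seq.fill(5)(Array(0.0, 0.0)) // coincident points, as very sparse data often yields
    val (left, right) = (Array(0.0, 0.0), Array(1.0, 1.0))
    val byChild = points.groupBy(p => if (sqDist(p, left) <= sqDist(p, right)) "left" else "right")
    // byChild.get("right") == None: the right child of the split is empty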

  }

  test("default parameters") {
@@ -51,6 +54,23 @@ class BisectingKMeansSuite
    assert(copiedModel.hasSummary)
  }

test("SPARK-16473: Verify Bisecting K-Means does not fail in edge case where" +
"one cluster is empty after split") {
val bkm = new BisectingKMeans().setK(k).setMinDivisibleClusterSize(4).setMaxIter(4)
Member: Please set the seed too, since this test seems at risk of flakiness.

Contributor Author: Done, added seed.
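In code, the follow-up presumably amounts to chaining a seed setter onto the builder above (the exact value is illustrative; this diff snapshot predates the change):

    val bkm = new BisectingKMeans().setK(k).setMinDivisibleClusterSize(4).setMaxIter(4).setSeed(123L)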


+    assert(bkm.getK === k)
Member: Don't bother testing getK, getMaxIter, or getMinDivisibleClusterSize. The setters should be tested elsewhere, so you may assume they work here.

Contributor Author: Done, removed.

+    assert(bkm.getFeaturesCol === "features")
Member: There's no need to test featuresCol, predictionCol, and other things which aren't relevant to this unit test. I'd simplify it.

Contributor Author: Done, removed.

+    assert(bkm.getPredictionCol === "prediction")
+    assert(bkm.getMaxIter === 4)
+    assert(bkm.getMinDivisibleClusterSize === 4)
+    // Verify fit does not fail on very sparse data
Member: It's not clear to me that this unit test actually tests the issue fixed in this PR. Is there a good way to see why it would? If not, then it would be great to write a tiny dataset by hand which would trigger the failure.

Contributor Author: I added this check to verify:

    // Verify we hit the edge case
    assert(numClusters < k && numClusters > 1)

The issue only occurs for very sparse data, but it occurs very consistently (almost all the very sparse data that I generate can trigger the error).
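Pieced together from that reply, the strengthened tail of the test would look roughly like this (a hedged reconstruction; the hunk below still shows the earlier, weaker bound):

    val model = bkm.fit(sparseDataset)
    assert(model.hasSummary)
    val result = model.transform(sparseDataset)
    val numClusters = result.select("prediction").distinct().collect().length
    // Verify we hit the edge case: a split produced an empty child, so fewer
    // than k clusters survive, yet fit/transform completed without throwing.
    assert(numClusters < k && numClusters > 1)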

+    val model = bkm.fit(sparseDataset)
+    assert(model.hasSummary)
+    val result = model.transform(sparseDataset)
+    val numClusters = result.select("prediction").distinct().collect().length
+    assert(numClusters <= k && numClusters >= 1)
+  }

test("setter/getter") {
val bkm = new BisectingKMeans()
.setK(9)
@@ -17,6 +17,8 @@

package org.apache.spark.ml.clustering

+import scala.util.Random
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
@@ -160,6 +162,17 @@ object KMeansSuite {
    spark.createDataFrame(rdd)
  }

+  def generateSparseData(spark: SparkSession, rows: Int, dim: Int, k: Int, seed: Int): DataFrame = {
Member: k is never used.

Contributor Author: Modified the method to use it.

+    val sc = spark.sparkContext
+    val random = new Random(seed)
+    val nnz = random.nextInt(dim)
+    val rdd = sc.parallelize(1 to rows)
+      .map(i => Vectors.sparse(dim, random.shuffle(0 to dim - 1).slice(0, nnz).sorted.toArray,
+        Array.fill(nnz)(random.nextDouble())))
+      .map(v => new TestRow(v))
+    spark.createDataFrame(rdd)
+  }
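The updated method is not shown in this snapshot, so the following is only a hypothetical sketch of how generateSparseData could "use" k, e.g. by giving each of k row groups its own fixed non-zero pattern so the data has k plausible clusters; the name and scheme here are assumptions, not the author's actual change:

    def generateSparseDataUsingK(
        spark: SparkSession, rows: Int, dim: Int, k: Int, seed: Int): DataFrame = {
      val sc = spark.sparkContext
      val random = new Random(seed)
      val nnz = math.max(1, random.nextInt(dim))
      // One fixed sparsity pattern per intended cluster (hypothetical scheme).
      val patterns = Array.fill(k)(random.shuffle((0 until dim).toList).take(nnz).sorted.toArray)
      val rdd = sc.parallelize(1 to rows)
        .map(i => Vectors.sparse(dim, patterns(i % k), Array.fill(nnz)(random.nextDouble())))
        .map(v => new TestRow(v))
      spark.createDataFrame(rdd)
    }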

  /**
   * Mapping from all Params to valid settings which differ from the defaults.
   * This is useful for tests which need to exercise all Params, such as save/load.