@@ -339,10 +339,15 @@ private object BisectingKMeans extends Serializable {
     assignments.map { case (index, v) =>
       if (divisibleIndices.contains(index)) {
         val children = Seq(leftChildIndex(index), rightChildIndex(index))
-        val selected = children.minBy { child =>
-          KMeans.fastSquaredDistance(newClusterCenters(child), v)
-        }
-        (selected, v)
+        val newClusterChildren = children.filter(newClusterCenters.contains(_))
+        if (newClusterChildren.nonEmpty) {
+          val selected = newClusterChildren.minBy { child =>
+            KMeans.fastSquaredDistance(newClusterCenters(child), v)
+          }
+          (selected, v)
+        } else {
+          (index, v)
+        }
       } else {
         (index, v)
       }
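To make the failure mode concrete, here is a minimal standalone sketch (editor's illustration, not code from the PR: the map contents, the 1-D distance, and `parentIndex` are invented stand-ins for the real cluster state). When a split leaves one child empty, that child's index never appears in `newClusterCenters`, so the old unguarded `minBy` lookup would throw `NoSuchElementException`; the guarded version falls back to keeping the point with its parent.

    import scala.math.abs

    object EmptyChildSketch {
      def main(args: Array[String]): Unit = {
        // Parent cluster 1 was split; only the left child (index 2 = 2 * 1)
        // received points, so only its center exists. The right child (3) is empty.
        val parentIndex = 1L
        val newClusterCenters = Map[Long, Double](2L -> 0.25)
        val children = Seq(2L, 3L)
        val point = 0.4 // 1-D stand-in for a feature vector

        // Old behavior: children.minBy(c => abs(newClusterCenters(c) - point))
        // throws NoSuchElementException on the missing key 3L.

        // Fixed behavior: only consider children whose centers exist,
        // and keep the parent assignment if no child survived the split.
        val newClusterChildren = children.filter(newClusterCenters.contains)
        val assigned =
          if (newClusterChildren.nonEmpty) {
            newClusterChildren.minBy(c => abs(newClusterCenters(c) - point))
          } else {
            parentIndex
          }
        println(s"point $point stays in cluster $assigned") // cluster 2
      }
    }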
@@ -372,12 +377,12 @@ private object BisectingKMeans extends Serializable {
         internalIndex -= 1
         val leftIndex = leftChildIndex(rawIndex)
         val rightIndex = rightChildIndex(rawIndex)
-        val height = math.sqrt(Seq(leftIndex, rightIndex).map { childIndex =>
+        val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_))
+        val height = math.sqrt(indexes.map { childIndex =>
           KMeans.fastSquaredDistance(center, clusters(childIndex).center)
         }.max)
-        val left = buildSubTree(leftIndex)
-        val right = buildSubTree(rightIndex)
-        new ClusteringTreeNode(index, size, center, cost, height, Array(left, right))
+        val children = indexes.map(buildSubTree(_)).toArray
+        new ClusteringTreeNode(index, size, center, cost, height, children)
       } else {
         val index = leafIndex
         leafIndex += 1
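The tree-building change follows the same principle. A quick standalone sketch (editor's illustration with a toy `Node` type and string payloads, not the PR's `ClusteringTreeNode`) of why `buildSubTree` must recurse only into children present in `clusters`: with the old `Array(left, right)`, the recursion would look up the empty child's missing entry and fail, whereas filtering first simply yields an internal node with a single child.

    object SubTreeSketch {
      final case class Node(index: Int, children: Array[Node]) {
        override def toString: String =
          s"Node($index, [${children.mkString(", ")}])"
      }

      // Cluster 3 (the right child of 1) ended up empty, so it has no entry.
      val clusters = Map(1 -> "root", 2 -> "left child")

      def buildSubTree(index: Int): Node = {
        val candidates = Seq(2 * index, 2 * index + 1) // left/right child indices
        // The old code recursed into both candidates and crashed on the
        // missing 3; filtering keeps the recursion on existing clusters only.
        val indexes = candidates.filter(clusters.contains)
        Node(index, indexes.map(buildSubTree).toArray)
      }

      def main(args: Array[String]): Unit = {
        println(buildSubTree(1)) // Node(1, [Node(2, [])])
      }
    }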
@@ -29,9 +29,12 @@ class BisectingKMeansSuite
   final val k = 5
   @transient var dataset: Dataset[_] = _
 
+  @transient var sparseDataset: Dataset[_] = _
+
   override def beforeAll(): Unit = {
     super.beforeAll()
     dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
+    sparseDataset = KMeansSuite.generateSparseData(spark, 10, 1000, 42)
   }
 
   test("default parameters") {
Expand All @@ -51,6 +54,18 @@ class BisectingKMeansSuite
assert(copiedModel.hasSummary)
}

test("SPARK-16473: Verify Bisecting K-Means does not fail in edge case where" +
"one cluster is empty after split") {
val bkm = new BisectingKMeans().setK(k).setMinDivisibleClusterSize(4).setMaxIter(4)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please set the seed too since this test seems at risk of flakiness

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, added seed


// Verify fit does not fail on very sparse data
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear to me that this unit test actually tests the issue fixed in this PR. Is there a good way to see why it would? If not, then it would be great to write a tiny dataset by hand which would trigger the failure.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this check to verify:
// Verify we hit the edge case
assert(numClusters < k && numClusters > 1)
the issue only occurs for very sparse data, but it occurs very consistently (almost all very sparse data that I generate can trigger the error)

val model = bkm.fit(sparseDataset)
val result = model.transform(sparseDataset)
val numClusters = result.select("prediction").distinct().collect().length
// Verify we hit the edge case
assert(numClusters < k && numClusters > 1)
}

test("setter/getter") {
val bkm = new BisectingKMeans()
.setK(9)
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.clustering
 
+import scala.util.Random
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
@@ -160,6 +162,17 @@ object KMeansSuite {
     spark.createDataFrame(rdd)
   }
 
+  def generateSparseData(spark: SparkSession, rows: Int, dim: Int, seed: Int): DataFrame = {
+    val sc = spark.sparkContext
+    val random = new Random(seed)
+    val nnz = random.nextInt(dim)
+    val rdd = sc.parallelize(1 to rows)
+      .map(i => Vectors.sparse(dim, random.shuffle(0 to dim - 1).slice(0, nnz).sorted.toArray,
+        Array.fill(nnz)(random.nextDouble())))
+      .map(v => new TestRow(v))
+    spark.createDataFrame(rdd)
+  }
+
   /**
    * Mapping from all Params to valid settings which differ from the defaults.
    * This is useful for tests which need to exercise all Params, such as save/load.
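A side note on the new helper's index sampling: the standalone sketch below (editor's illustration, not PR code) isolates the trick of shuffling the full index range, keeping the first `nnz` positions, and sorting them, which yields the distinct, increasing indices that `Vectors.sparse` expects. Since `nnz` is drawn once from the seeded `Random`, every generated row has the same number of non-zeros, but the shuffled positions differ per row.

    import scala.util.Random

    object SparseIndicesSketch {
      def main(args: Array[String]): Unit = {
        val random = new Random(42)
        val dim = 1000
        val nnz = random.nextInt(dim)
        // Shuffle 0..dim-1, take the first nnz positions, sort ascending.
        val indices = random.shuffle(0 to dim - 1).slice(0, nnz).sorted.toArray
        // Distinct by construction (shuffled permutation of a distinct range).
        assert(indices.length == nnz && indices.distinct.length == nnz)
        println(s"$nnz distinct sorted indices, e.g. ${indices.take(5).mkString(", ")}")
      }
    }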