
Commit d4a0895

smurakozi authored and mengxr committed
[SPARK-22884][ML] ML tests for StructuredStreaming: spark.ml.clustering
## What changes were proposed in this pull request?

Converting clustering tests to also check code with structured streaming, using the ML testing infrastructure implemented in SPARK-22882.

This PR is a new version of #20319

Author: Sandor Murakozi <[email protected]>
Author: Joseph K. Bradley <[email protected]>

Closes #21358 from jkbradley/smurakozi-SPARK-22884.
1 parent: 439c695 · commit: d4a0895

File tree: 4 files changed (+50, -65 lines)

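All four suites below switch from SparkFunSuite with MLlibTestSparkContext to the MLTest base class introduced in SPARK-22882, replacing direct model.transform(...)/collect() checks with testTransformer (a per-row check applied to both batch and streaming output) and testTransformerByGlobalCheckFunc (a check over all collected output rows). The following is a minimal sketch of that pattern, not part of this commit: the suite name, the two-point dataset, and the assertions are illustrative placeholders; only the helper names and call shapes mirror the diffs below.

package org.apache.spark.ml.clustering

import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql.Row

// Hypothetical suite, for illustration only -- not part of this commit.
class ClusteringStreamingSketchSuite extends MLTest with DefaultReadWriteTest {

  import testImplicits._

  test("predictions are checked on both batch and streaming input") {
    // Tiny illustrative dataset with the default "features" column.
    val df = Seq(
      Tuple1(Vectors.dense(0.0, 0.0)),
      Tuple1(Vectors.dense(10.0, 10.0))
    ).toDF("features")
    val model = new KMeans().setK(2).setSeed(1).fit(df)

    // Per-row check: the helper runs the transformer on the DataFrame and on
    // an equivalent stream, then applies this function to every output row.
    testTransformer[Tuple1[Vector]](df, model, "features", "prediction") {
      case Row(_, prediction: Int) => assert(prediction === 0 || prediction === 1)
    }

    // Global check: all output rows are collected and validated together.
    testTransformerByGlobalCheckFunc[Tuple1[Vector]](df, model, "prediction") { rows =>
      assert(rows.map(_.getInt(0)).toSet.subsetOf(Set(0, 1)))
    }
  }
}

In the diffs, testTransformerByGlobalCheckFunc is used where an assertion needs to see the whole prediction column at once (for example, counting distinct clusters), while testTransformer suffices for row-local checks such as comparing a prediction against its probability vector or topic distribution.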

mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala

Lines changed: 21 additions & 20 deletions
@@ -19,17 +19,18 @@ package org.apache.spark.ml.clustering
 
 import scala.language.existentials
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkException
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.clustering.DistanceMeasure
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.Dataset
 
-class BisectingKMeansSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest {
+
+  import testImplicits._
 
   final val k = 5
   @transient var dataset: Dataset[_] = _
@@ -68,10 +69,13 @@ class BisectingKMeansSuite
 
     // Verify fit does not fail on very sparse data
     val model = bkm.fit(sparseDataset)
-    val result = model.transform(sparseDataset)
-    val numClusters = result.select("prediction").distinct().collect().length
-    // Verify we hit the edge case
-    assert(numClusters < k && numClusters > 1)
+
+    testTransformerByGlobalCheckFunc[Tuple1[Vector]](sparseDataset.toDF(), model, "prediction") {
+      rows =>
+        val numClusters = rows.distinct.length
+        // Verify we hit the edge case
+        assert(numClusters < k && numClusters > 1)
+    }
   }
 
   test("setter/getter") {
@@ -104,19 +108,16 @@ class BisectingKMeansSuite
     val bkm = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
     val model = bkm.fit(dataset)
    assert(model.clusterCenters.length === k)
-
-    val transformed = model.transform(dataset)
-    val expectedColumns = Array("features", predictionColName)
-    expectedColumns.foreach { column =>
-      assert(transformed.columns.contains(column))
-    }
-    val clusters =
-      transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet
-    assert(clusters.size === k)
-    assert(clusters === Set(0, 1, 2, 3, 4))
     assert(model.computeCost(dataset) < 0.1)
     assert(model.hasParent)
 
+    testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model,
+      "features", predictionColName) { rows =>
+      val clusters = rows.map(_.getAs[Int](predictionColName)).toSet
+      assert(clusters.size === k)
+      assert(clusters === Set(0, 1, 2, 3, 4))
+    }
+
     // Check validity of model summary
     val numRows = dataset.count()
     assert(model.hasSummary)

mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala

Lines changed: 8 additions & 14 deletions
@@ -23,16 +23,15 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.stat.distribution.MultivariateGaussian
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Dataset, Row}
 
-class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
-  with DefaultReadWriteTest {
 
-  import testImplicits._
+class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest {
+
   import GaussianMixtureSuite._
+  import testImplicits._
 
   final val k = 5
   private val seed = 538009335
@@ -119,15 +118,10 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
     assert(model.weights.length === k)
     assert(model.gaussians.length === k)
 
-    val transformed = model.transform(dataset)
-    val expectedColumns = Array("features", predictionColName, probabilityColName)
-    expectedColumns.foreach { column =>
-      assert(transformed.columns.contains(column))
-    }
-
     // Check prediction matches the highest probability, and probabilities sum to one.
-    transformed.select(predictionColName, probabilityColName).collect().foreach {
-      case Row(pred: Int, prob: Vector) =>
+    testTransformer[Tuple1[Vector]](dataset.toDF(), model,
+      "features", predictionColName, probabilityColName) {
+      case Row(_, pred: Int, prob: Vector) =>
         val probArray = prob.toArray
         val predFromProb = probArray.zipWithIndex.maxBy(_._1)._2
         assert(pred === predFromProb)

mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

Lines changed: 14 additions & 17 deletions
@@ -22,20 +22,21 @@ import scala.util.Random
 
 import org.dmg.pmml.{ClusteringModel, PMML}
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkException
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils, PMMLReadWriteTest}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
+import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans,
+  KMeansModel => MLlibKMeansModel}
 import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 
 private[clustering] case class TestRow(features: Vector)
 
-class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest
-  with PMMLReadWriteTest {
+class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest {
+
+  import testImplicits._
 
   final val k = 5
   @transient var dataset: Dataset[_] = _
@@ -109,15 +110,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     val model = kmeans.fit(dataset)
     assert(model.clusterCenters.length === k)
 
-    val transformed = model.transform(dataset)
-    val expectedColumns = Array("features", predictionColName)
-    expectedColumns.foreach { column =>
-      assert(transformed.columns.contains(column))
+    testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model,
+      "features", predictionColName) { rows =>
+      val clusters = rows.map(_.getAs[Int](predictionColName)).toSet
+      assert(clusters.size === k)
+      assert(clusters === Set(0, 1, 2, 3, 4))
     }
-    val clusters =
-      transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet
-    assert(clusters.size === k)
-    assert(clusters === Set(0, 1, 2, 3, 4))
+
     assert(model.computeCost(dataset) < 0.1)
     assert(model.hasParent)
 
@@ -149,9 +148,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     model.setFeaturesCol(featuresColName).setPredictionCol(predictionColName)
 
     val transformed = model.transform(dataset.withColumnRenamed("features", featuresColName))
-    Seq(featuresColName, predictionColName).foreach { column =>
-      assert(transformed.columns.contains(column))
-    }
+    assert(transformed.schema.fieldNames.toSet === Set(featuresColName, predictionColName))
     assert(model.getFeaturesCol == featuresColName)
     assert(model.getPredictionCol == predictionColName)
   }

mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala

Lines changed: 7 additions & 14 deletions
@@ -21,11 +21,9 @@ import scala.language.existentials
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql._
 
 object LDASuite {
@@ -61,7 +59,7 @@ object LDASuite {
 }
 
 
-class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class LDASuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -186,16 +184,11 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
     assert(model.topicsMatrix.numCols === k)
     assert(!model.isDistributed)
 
-    // transform()
-    val transformed = model.transform(dataset)
-    val expectedColumns = Array("features", lda.getTopicDistributionCol)
-    expectedColumns.foreach { column =>
-      assert(transformed.columns.contains(column))
-    }
-    transformed.select(lda.getTopicDistributionCol).collect().foreach { r =>
-      val topicDistribution = r.getAs[Vector](0)
-      assert(topicDistribution.size === k)
-      assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0))
+    testTransformer[Tuple1[Vector]](dataset.toDF(), model,
+      "features", lda.getTopicDistributionCol) {
+      case Row(_, topicDistribution: Vector) =>
+        assert(topicDistribution.size === k)
+        assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0))
     }
 
     // logLikelihood, logPerplexity
