
Commit 4550673

WeichenXu123 authored and jkbradley committed
[SPARK-22882][ML][TESTS] ML test for structured streaming: ml.classification
## What changes were proposed in this pull request?

Adding Structured Streaming tests for all Models/Transformers in spark.ml.classification.

## How was this patch tested?

N/A

Author: WeichenXu <[email protected]>

Closes #20121 from WeichenXu123/ml_stream_test_classification.

(cherry picked from commit 98a5c0a)
Signed-off-by: Joseph K. Bradley <[email protected]>
1 parent 232b9f8 commit 4550673
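The suites in this commit all follow the same migration: extend MLTest instead of SparkFunSuite with MLlibTestSparkContext, and replace eager transform(...).collect() checks with the testTransformer / testTransformerByGlobalCheckFunc helpers, which run the fitted model against both batch and streaming versions of the input. As orientation, here is a minimal sketch of the per-row helper; the LogisticRegression model, data, and test name are illustrative and not taken from this commit:

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util.MLTest
import org.apache.spark.sql.Row

// Hypothetical suite (not part of this commit) illustrating the MLTest
// helpers that the diffs below migrate to.
class ExampleStreamingSuite extends MLTest {

  import testImplicits._

  test("prediction agrees with probability, in batch and streaming") {
    val df = Seq(
      (0.0, Vectors.dense(0.0, 1.0)),
      (1.0, Vectors.dense(2.0, 0.5))
    ).toDF("label", "features")
    val model = new LogisticRegression().fit(df)

    // The check function runs once per output row; MLTest exercises the
    // transformer against both a batch DataFrame and a streaming Dataset
    // built from the same rows.
    testTransformer[(Double, Vector)](df, model, "prediction", "probability") {
      case Row(pred: Double, prob: Vector) =>
        assert(pred === prob.argmax.toDouble)
    }
  }
}
```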

File tree: 9 files changed, +202 -305 lines changed

mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala

Lines changed: 13 additions & 16 deletions
@@ -23,15 +23,14 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
-import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
+  DecisionTreeSuite => OldDecisionTreeSuite}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 
-class DecisionTreeClassifierSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
 
   import DecisionTreeClassifierSuite.compareAPIs
   import testImplicits._

@@ -251,20 +250,18 @@ class DecisionTreeClassifierSuite
 
     MLTestingUtils.checkCopyAndUids(dt, newTree)
 
-    val predictions = newTree.transform(newData)
-      .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol)
-      .collect()
-
-    predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
-      assert(pred === rawPred.argmax,
-        s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
-      val sum = rawPred.toArray.sum
-      assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
-        "probability prediction mismatch")
+    testTransformer[(Vector, Double)](newData, newTree,
+      "prediction", "rawPrediction", "probability") {
+      case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
+        assert(pred === rawPred.argmax,
+          s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
+        val sum = rawPred.toArray.sum
+        assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
+          "probability prediction mismatch")
     }
 
     ProbabilisticClassifierSuite.testPredictMethods[
-      Vector, DecisionTreeClassificationModel](newTree, newData)
+      Vector, DecisionTreeClassificationModel](this, newTree, newData)
   }
 
   test("training with 1-category categorical feature") {
mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala

Lines changed: 24 additions & 53 deletions
@@ -26,22 +26,20 @@ import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.loss.LogLoss
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.util.Utils
 
 /**
  * Test suite for [[GBTClassifier]].
  */
-class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
-  with DefaultReadWriteTest {
+class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
   import GBTClassifierSuite.compareAPIs

@@ -126,30 +124,34 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
 
     // should predict all zeros
     binaryModel.setThresholds(Array(0.0, 1.0))
-    val binaryZeroPredictions = binaryModel.transform(df).select("prediction").collect()
-    assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0))
+    testTransformer[(Double, Vector)](df, binaryModel, "prediction") {
+      case Row(prediction: Double) => prediction === 0.0
+    }
 
     // should predict all ones
     binaryModel.setThresholds(Array(1.0, 0.0))
-    val binaryOnePredictions = binaryModel.transform(df).select("prediction").collect()
-    assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0))
-
+    testTransformer[(Double, Vector)](df, binaryModel, "prediction") {
+      case Row(prediction: Double) => prediction === 1.0
+    }
 
     val gbtBase = new GBTClassifier
     val model = gbtBase.fit(df)
     val basePredictions = model.transform(df).select("prediction").collect()
 
     // constant threshold scaling is the same as no thresholds
     binaryModel.setThresholds(Array(1.0, 1.0))
-    val scaledPredictions = binaryModel.transform(df).select("prediction").collect()
-    assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
-      scaled.getDouble(0) === base.getDouble(0)
-    })
+    testTransformerByGlobalCheckFunc[(Double, Vector)](df, binaryModel, "prediction") {
+      scaledPredictions: Seq[Row] =>
+        assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
+          scaled.getDouble(0) === base.getDouble(0)
+        })
+    }
 
     // force it to use the predict method
     model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1))
-    val predictionsWithPredict = model.transform(df).select("prediction").collect()
-    assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0))
+    testTransformer[(Double, Vector)](df, model, "prediction") {
+      case Row(prediction: Double) => prediction === 0.0
+    }
   }
 
   test("GBTClassifier: Predictor, Classifier methods") {

@@ -169,61 +171,30 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
     val blas = BLAS.getInstance()
 
     val validationDataset = validationData.toDF(labelCol, featuresCol)
-    val results = gbtModel.transform(validationDataset)
-    // check that raw prediction is tree predictions dot tree weights
-    results.select(rawPredictionCol, featuresCol).collect().foreach {
-      case Row(raw: Vector, features: Vector) =>
+    testTransformer[(Double, Vector)](validationDataset, gbtModel,
+      "rawPrediction", "features", "probability", "prediction") {
+      case Row(raw: Vector, features: Vector, prob: Vector, pred: Double) =>
         assert(raw.size === 2)
+        // check that raw prediction is tree predictions dot tree weights
         val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction)
         val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1)
         assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps)
-    }
 
-    // Compare rawPrediction with probability
-    results.select(rawPredictionCol, probabilityCol).collect().foreach {
-      case Row(raw: Vector, prob: Vector) =>
-        assert(raw.size === 2)
+        // Compare rawPrediction with probability
         assert(prob.size === 2)
         // Note: we should check other loss types for classification if they are added
         val predFromRaw = raw.toDense.values.map(value => LogLoss.computeProbability(value))
         assert(prob(0) ~== predFromRaw(0) relTol eps)
         assert(prob(1) ~== predFromRaw(1) relTol eps)
         assert(prob(0) + prob(1) ~== 1.0 absTol absEps)
-    }
 
-    // Compare prediction with probability
-    results.select(predictionCol, probabilityCol).collect().foreach {
-      case Row(pred: Double, prob: Vector) =>
+        // Compare prediction with probability
         val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
         assert(pred == predFromProb)
     }
 
-    // force it to use raw2prediction
-    gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol("")
-    val resultsUsingRaw2Predict =
-      gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
-    resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
-      case (pred1, pred2) => assert(pred1 === pred2)
-    }
-
-    // force it to use probability2prediction
-    gbtModel.setRawPredictionCol("").setProbabilityCol(probabilityCol)
-    val resultsUsingProb2Predict =
-      gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
-    resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
-      case (pred1, pred2) => assert(pred1 === pred2)
-    }
-
-    // force it to use predict
-    gbtModel.setRawPredictionCol("").setProbabilityCol("")
-    val resultsUsingPredict =
-      gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
-    resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach {
-      case (pred1, pred2) => assert(pred1 === pred2)
-    }
-
     ProbabilisticClassifierSuite.testPredictMethods[
-      Vector, GBTClassificationModel](gbtModel, validationDataset)
+      Vector, GBTClassificationModel](this, gbtModel, validationDataset)
   }
 
   test("GBT parameter stepSize should be in interval (0, 1]") {

mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala

Lines changed: 7 additions & 8 deletions
@@ -21,20 +21,18 @@ import scala.util.Random
 
 import breeze.linalg.{DenseVector => BDV}
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.classification.LinearSVCSuite._
 import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.optim.aggregator.HingeAggregator
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.functions.udf
 
 
-class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class LinearSVCSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._

@@ -141,10 +139,11 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
       threshold: Double,
       expected: Set[(Int, Double)]): Unit = {
     model.setThreshold(threshold)
-    val results = model.transform(df).select("id", "prediction").collect()
-      .map(r => (r.getInt(0), r.getDouble(1)))
-      .toSet
-    assert(results === expected, s"Failed for threshold = $threshold")
+    testTransformerByGlobalCheckFunc[(Int, Vector)](df, model, "id", "prediction") {
+      rows: Seq[Row] =>
+        val results = rows.map(r => (r.getInt(0), r.getDouble(1))).toSet
+        assert(results === expected, s"Failed for threshold = $threshold")
+    }
   }
 
   def checkResults(threshold: Double, expected: Set[(Int, Double)]): Unit = {
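For readers wondering how a fitted model can be exercised against streaming input at all, the sketch below shows the underlying idea using MemoryStream and a memory sink. It is a conceptual illustration under stated assumptions, not MLTest's actual implementation; the helper name transformStreaming and the query name model_output are invented for the example:

```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.sql.{Encoder, Row, SparkSession}
import org.apache.spark.sql.execution.streaming.MemoryStream

// Conceptual sketch only (assumed helper, not MLTest's real code): feed rows
// through a memory-backed streaming source, apply the fitted model, and
// collect the output from an in-memory sink.
def transformStreaming[A: Encoder](
    spark: SparkSession,
    data: Seq[A],
    model: Transformer): Seq[Row] = {
  implicit val sqlContext: org.apache.spark.sql.SQLContext = spark.sqlContext

  val source = MemoryStream[A]
  source.addData(data: _*)

  val query = model
    .transform(source.toDF())
    .writeStream
    .format("memory")          // results accumulate in an in-memory table
    .queryName("model_output") // invented name for this example
    .start()
  query.processAllAvailable()  // drain everything added to the source
  query.stop()

  spark.table("model_output").collect().toSeq
}
```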
