@@ -26,22 +26,20 @@ import org.apache.spark.ml.param.ParamsSuite
2626import org .apache .spark .ml .regression .DecisionTreeRegressionModel
2727import org .apache .spark .ml .tree .LeafNode
2828import org .apache .spark .ml .tree .impl .TreeTests
29- import org .apache .spark .ml .util .{DefaultReadWriteTest , MLTestingUtils }
29+ import org .apache .spark .ml .util .{DefaultReadWriteTest , MLTest , MLTestingUtils }
3030import org .apache .spark .ml .util .TestingUtils ._
3131import org .apache .spark .mllib .regression .{LabeledPoint => OldLabeledPoint }
3232import org .apache .spark .mllib .tree .{EnsembleTestHelper , GradientBoostedTrees => OldGBT }
3333import org .apache .spark .mllib .tree .configuration .{Algo => OldAlgo }
3434import org .apache .spark .mllib .tree .loss .LogLoss
35- import org .apache .spark .mllib .util .MLlibTestSparkContext
3635import org .apache .spark .rdd .RDD
3736import org .apache .spark .sql .{DataFrame , Row }
3837import org .apache .spark .util .Utils
3938
4039/**
4140 * Test suite for [[GBTClassifier ]].
4241 */
43- class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
44- with DefaultReadWriteTest {
42+ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
4543
4644 import testImplicits ._
4745 import GBTClassifierSuite .compareAPIs
@@ -126,30 +124,34 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
126124
127125 // should predict all zeros
128126 binaryModel.setThresholds(Array (0.0 , 1.0 ))
129- val binaryZeroPredictions = binaryModel.transform(df).select(" prediction" ).collect()
130- assert(binaryZeroPredictions.forall(_.getDouble(0 ) === 0.0 ))
127+ testTransformer[(Double , Vector )](df, binaryModel, " prediction" ) {
128+ case Row (prediction : Double ) => prediction === 0.0
129+ }
131130
132131 // should predict all ones
133132 binaryModel.setThresholds(Array (1.0 , 0.0 ))
134- val binaryOnePredictions = binaryModel.transform(df).select( " prediction" ).collect()
135- assert(binaryOnePredictions.forall(_.getDouble( 0 ) === 1.0 ))
136-
133+ testTransformer[( Double , Vector )](df, binaryModel, " prediction" ) {
134+ case Row ( prediction : Double ) => prediction === 1.0
135+ }
137136
138137 val gbtBase = new GBTClassifier
139138 val model = gbtBase.fit(df)
140139 val basePredictions = model.transform(df).select(" prediction" ).collect()
141140
142141 // constant threshold scaling is the same as no thresholds
143142 binaryModel.setThresholds(Array (1.0 , 1.0 ))
144- val scaledPredictions = binaryModel.transform(df).select(" prediction" ).collect()
145- assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
146- scaled.getDouble(0 ) === base.getDouble(0 )
147- })
143+ testTransformerByGlobalCheckFunc[(Double , Vector )](df, binaryModel, " prediction" ) {
144+ scaledPredictions : Seq [Row ] =>
145+ assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
146+ scaled.getDouble(0 ) === base.getDouble(0 )
147+ })
148+ }
148149
149150 // force it to use the predict method
150151 model.setRawPredictionCol(" " ).setProbabilityCol(" " ).setThresholds(Array (0 , 1 ))
151- val predictionsWithPredict = model.transform(df).select(" prediction" ).collect()
152- assert(predictionsWithPredict.forall(_.getDouble(0 ) === 0.0 ))
152+ testTransformer[(Double , Vector )](df, model, " prediction" ) {
153+ case Row (prediction : Double ) => prediction === 0.0
154+ }
153155 }
154156
155157 test(" GBTClassifier: Predictor, Classifier methods" ) {
@@ -169,61 +171,30 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
169171 val blas = BLAS .getInstance()
170172
171173 val validationDataset = validationData.toDF(labelCol, featuresCol)
172- val results = gbtModel.transform(validationDataset)
173- // check that raw prediction is tree predictions dot tree weights
174- results.select(rawPredictionCol, featuresCol).collect().foreach {
175- case Row (raw : Vector , features : Vector ) =>
174+ testTransformer[(Double , Vector )](validationDataset, gbtModel,
175+ " rawPrediction" , " features" , " probability" , " prediction" ) {
176+ case Row (raw : Vector , features : Vector , prob : Vector , pred : Double ) =>
176177 assert(raw.size === 2 )
178+ // check that raw prediction is tree predictions dot tree weights
177179 val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction)
178180 val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1 , gbtModel.treeWeights, 1 )
179181 assert(raw ~== Vectors .dense(- prediction, prediction) relTol eps)
180- }
181182
182- // Compare rawPrediction with probability
183- results.select(rawPredictionCol, probabilityCol).collect().foreach {
184- case Row (raw : Vector , prob : Vector ) =>
185- assert(raw.size === 2 )
183+ // Compare rawPrediction with probability
186184 assert(prob.size === 2 )
187185 // Note: we should check other loss types for classification if they are added
188186 val predFromRaw = raw.toDense.values.map(value => LogLoss .computeProbability(value))
189187 assert(prob(0 ) ~== predFromRaw(0 ) relTol eps)
190188 assert(prob(1 ) ~== predFromRaw(1 ) relTol eps)
191189 assert(prob(0 ) + prob(1 ) ~== 1.0 absTol absEps)
192- }
193190
194- // Compare prediction with probability
195- results.select(predictionCol, probabilityCol).collect().foreach {
196- case Row (pred : Double , prob : Vector ) =>
191+ // Compare prediction with probability
197192 val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
198193 assert(pred == predFromProb)
199194 }
200195
201- // force it to use raw2prediction
202- gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol(" " )
203- val resultsUsingRaw2Predict =
204- gbtModel.transform(validationDataset).select(predictionCol).as[Double ].collect()
205- resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double ].collect()).foreach {
206- case (pred1, pred2) => assert(pred1 === pred2)
207- }
208-
209- // force it to use probability2prediction
210- gbtModel.setRawPredictionCol(" " ).setProbabilityCol(probabilityCol)
211- val resultsUsingProb2Predict =
212- gbtModel.transform(validationDataset).select(predictionCol).as[Double ].collect()
213- resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double ].collect()).foreach {
214- case (pred1, pred2) => assert(pred1 === pred2)
215- }
216-
217- // force it to use predict
218- gbtModel.setRawPredictionCol(" " ).setProbabilityCol(" " )
219- val resultsUsingPredict =
220- gbtModel.transform(validationDataset).select(predictionCol).as[Double ].collect()
221- resultsUsingPredict.zip(results.select(predictionCol).as[Double ].collect()).foreach {
222- case (pred1, pred2) => assert(pred1 === pred2)
223- }
224-
225196 ProbabilisticClassifierSuite .testPredictMethods[
226- Vector , GBTClassificationModel ](gbtModel, validationDataset)
197+ Vector , GBTClassificationModel ](this , gbtModel, validationDataset)
227198 }
228199
229200 test(" GBT parameter stepSize should be in interval (0, 1]" ) {
0 commit comments