Commit fe409f3

imatiach-msft authored and jkbradley committed
[SPARK-14975][ML] Fixed GBTClassifier to predict probability per training instance and fixed interfaces
## What changes were proposed in this pull request?

All of the classifiers in MLlib can predict probabilities except for GBTClassifier. Also, all classifiers inherit from ProbabilisticClassifier, but GBTClassifier inherited from Predictor, which was a bug. This change corrects the interface and adds the ability for the classifier to produce a vector of probabilities.

## How was this patch tested?

The basic ML tests were run after making the changes. I've marked this as WIP as I need to add more tests.

Author: Ilya Matiach <[email protected]>

Closes #16441 from imatiach-msft/ilmat/fix-GBT.
1 parent a81e336 commit fe409f3

File tree

5 files changed (+248, -29 lines)
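For context on the user-facing effect: after this change GBTClassificationModel behaves like the other probabilistic classification models. A minimal usage sketch, assuming a SparkSession and a binary-labeled DataFrame `training` with the usual "label" and "features" columns (the data setup is an assumption for illustration, not part of this commit):

```scala
import org.apache.spark.ml.classification.GBTClassifier

// `training` is an assumed DataFrame with "label" (0.0/1.0) and "features" columns.
val gbt = new GBTClassifier()
  .setMaxIter(10)
  .setLossType("logistic")
val model = gbt.fit(training)

// The model is now a ProbabilisticClassificationModel, so transform() emits
// rawPrediction and probability columns alongside prediction.
model.transform(training)
  .select("label", "rawPrediction", "probability", "prediction")
  .show(5)
```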

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 74 additions & 20 deletions
@@ -23,16 +23,16 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.loss.LogLoss
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -58,7 +58,7 @@ import org.apache.spark.sql.functions._
 @Since("1.4.0")
 class GBTClassifier @Since("1.4.0") (
     @Since("1.4.0") override val uid: String)
-  extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
+  extends ProbabilisticClassifier[Vector, GBTClassifier, GBTClassificationModel]
   with GBTClassifierParams with DefaultParamsWritable with Logging {
 
   @Since("1.4.0")
@@ -158,12 +158,19 @@ class GBTClassifier @Since("1.4.0") (
     val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
 
+    val numClasses = 2
+    if (isDefined(thresholds)) {
+      require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+        ".train() called with non-matching numClasses and thresholds.length." +
+        s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+    }
+
     val instr = Instrumentation.create(this, oldDataset)
     instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
       maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
       seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
     instr.logNumFeatures(numFeatures)
-    instr.logNumClasses(2)
+    instr.logNumClasses(numClasses)
 
     val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
       $(seed))
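Since GBTClassifier now extends ProbabilisticClassifier, callers can set per-class thresholds, and the new guard above fails fast unless exactly two are given. An illustrative sketch with hypothetical values:

```scala
import org.apache.spark.ml.classification.GBTClassifier

// Prediction picks the class maximizing probability(i) / threshold(i).
val gbt = new GBTClassifier().setThresholds(Array(0.4, 0.6))

// This would compile but throw at fit() time, since numClasses is fixed at 2:
// new GBTClassifier().setThresholds(Array(0.2, 0.3, 0.5))
```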
@@ -202,15 +209,30 @@ class GBTClassificationModel private[ml](
     @Since("1.6.0") override val uid: String,
     private val _trees: Array[DecisionTreeRegressionModel],
     private val _treeWeights: Array[Double],
-    @Since("1.6.0") override val numFeatures: Int)
-  extends PredictionModel[Vector, GBTClassificationModel]
+    @Since("1.6.0") override val numFeatures: Int,
+    @Since("2.2.0") override val numClasses: Int)
+  extends ProbabilisticClassificationModel[Vector, GBTClassificationModel]
   with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel]
   with MLWritable with Serializable {
 
   require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.")
   require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
     s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
 
+  /**
+   * Construct a GBTClassificationModel
+   *
+   * @param _trees  Decision trees in the ensemble.
+   * @param _treeWeights  Weights for the decision trees in the ensemble.
+   * @param numFeatures  The number of features.
+   */
+  private[ml] def this(
+      uid: String,
+      _trees: Array[DecisionTreeRegressionModel],
+      _treeWeights: Array[Double],
+      numFeatures: Int) =
+    this(uid, _trees, _treeWeights, numFeatures, 2)
+
   /**
    * Construct a GBTClassificationModel
    *
@@ -219,7 +241,7 @@ class GBTClassificationModel private[ml](
    */
   @Since("1.6.0")
   def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
-    this(uid, _trees, _treeWeights, -1)
+    this(uid, _trees, _treeWeights, -1, 2)
 
   @Since("1.4.0")
   override def trees: Array[DecisionTreeRegressionModel] = _trees
@@ -242,19 +264,37 @@ class GBTClassificationModel private[ml](
   }
 
   override protected def predict(features: Vector): Double = {
-    // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
-    // Classifies by thresholding sum of weighted tree predictions
-    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
-    val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
-    if (prediction > 0.0) 1.0 else 0.0
+    // If thresholds defined, use predictRaw to get probabilities, otherwise use optimization
+    if (isDefined(thresholds)) {
+      super.predict(features)
+    } else {
+      if (margin(features) > 0.0) 1.0 else 0.0
+    }
+  }
+
+  override protected def predictRaw(features: Vector): Vector = {
+    val prediction: Double = margin(features)
+    Vectors.dense(Array(-prediction, prediction))
+  }
+
+  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+    rawPrediction match {
+      case dv: DenseVector =>
+        dv.values(0) = loss.computeProbability(dv.values(0))
+        dv.values(1) = 1.0 - dv.values(0)
+        dv
+      case sv: SparseVector =>
+        throw new RuntimeException("Unexpected error in GBTClassificationModel:" +
+          " raw2probabilityInPlace encountered SparseVector")
+    }
   }
 
   /** Number of trees in ensemble */
   val numTrees: Int = trees.length
 
   @Since("1.4.0")
   override def copy(extra: ParamMap): GBTClassificationModel = {
-    copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
+    copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures, numClasses),
       extra).setParent(parent)
   }
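To trace the flow just added: margin(features) (defined in the next hunk) is the weighted sum of the trees' outputs, predictRaw packs it as (-margin, margin), and raw2probabilityInPlace maps the first component through the loss and complements it. A self-contained sketch of that mapping, mirroring but not reusing the code above:

```scala
// Illustration of the margin -> raw -> probability pipeline (not MLlib code).
def rawToProbabilities(margin: Double): Array[Double] = {
  val raw = Array(-margin, margin)                // predictRaw
  val p0 = 1.0 / (1.0 + math.exp(-2.0 * raw(0)))  // LogLoss.computeProbability on raw(0)
  Array(p0, 1.0 - p0)                             // (P(label = 0), P(label = 1))
}
```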

@@ -276,18 +316,30 @@ class GBTClassificationModel private[ml](
   @Since("2.0.0")
   lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
 
+  /** Raw prediction for the positive class. */
+  private def margin(features: Vector): Double = {
+    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
+    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+  }
+
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldGBTModel = {
     new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
   }
 
+  // hard coded loss, which is not meant to be changed in the model
+  private val loss = getOldLossType
+
   @Since("2.0.0")
   override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
 }
 
 @Since("2.0.0")
 object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
 
+  private val numFeaturesKey: String = "numFeatures"
+  private val numTreesKey: String = "numTrees"
+
   @Since("2.0.0")
   override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader
 
@@ -300,8 +352,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
     override protected def saveImpl(path: String): Unit = {
 
       val extraMetadata: JObject = Map(
-        "numFeatures" -> instance.numFeatures,
-        "numTrees" -> instance.getNumTrees)
+        numFeaturesKey -> instance.numFeatures,
+        numTreesKey -> instance.getNumTrees)
       EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
     }
   }
@@ -316,8 +368,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
         EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
-      val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
-      val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+      val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
+      val numTrees = (metadata.metadata \ numTreesKey).extract[Int]
 
       val trees: Array[DecisionTreeRegressionModel] = treesData.map {
         case (treeMetadata, root) =>
@@ -328,7 +380,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
       }
       require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" +
         s" trees based on metadata but found ${trees.length} trees.")
-      val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures)
+      val model = new GBTClassificationModel(metadata.uid,
+        trees, treeWeights, numFeatures)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
@@ -339,14 +392,15 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
       oldModel: OldGBTModel,
       parent: GBTClassifier,
       categoricalFeatures: Map[Int, Int],
-      numFeatures: Int = -1): GBTClassificationModel = {
+      numFeatures: Int = -1,
+      numClasses: Int = 2): GBTClassificationModel = {
     require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
       s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
     val newTrees = oldModel.trees.map { tree =>
       // parent for each tree is null since there is no good way to set this.
      DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
     }
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
-    new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures)
+    new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures, numClasses)
   }
 }

mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@ import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
-import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
+import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, ClassificationLoss => OldClassificationLoss, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
 /**
@@ -531,7 +531,7 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams
   def getLossType: String = $(lossType).toLowerCase
 
   /** (private[ml]) Convert new loss to old loss. */
-  override private[ml] def getOldLossType: OldLoss = {
+  override private[ml] def getOldLossType: OldClassificationLoss = {
     getLossType match {
       case "logistic" => OldLogLoss
       case _ =>
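The narrowed return type is what lets GBTClassificationModel call computeProbability on its stored loss. A sketch of why it matters; the enclosing object and package are hypothetical, and since ClassificationLoss is private[spark] this only compiles inside the org.apache.spark package tree:

```scala
package org.apache.spark.example

import org.apache.spark.mllib.tree.loss.{ClassificationLoss, LogLoss}

object LossTypeDemo {
  // The classification-only subtype exposes computeProbability; plain Loss does not.
  val loss: ClassificationLoss = LogLoss
  val pPositive: Double = loss.computeProbability(1.25) // estimated P(label = 1)
}
```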

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala

Lines changed: 8 additions & 2 deletions
@@ -20,7 +20,6 @@ package org.apache.spark.mllib.tree.loss
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.mllib.util.MLUtils
 
-
 /**
  * :: DeveloperApi ::
  * Class for log loss calculation (for classification).
@@ -32,7 +31,7 @@ import org.apache.spark.mllib.util.MLUtils
  */
 @Since("1.2.0")
 @DeveloperApi
-object LogLoss extends Loss {
+object LogLoss extends ClassificationLoss {
 
   /**
    * Method to calculate the loss gradients for the gradient boosting calculation for binary
@@ -52,4 +51,11 @@ object LogLoss extends Loss {
     // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
     2.0 * MLUtils.log1pExp(-margin)
   }
+
+  /**
+   * Returns the estimated probability of a label of 1.0.
+   */
+  override private[spark] def computeProbability(margin: Double): Double = {
+    1.0 / (1.0 + math.exp(-2.0 * margin))
+  }
 }
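A note on the constant 2 (a reasoning aid for this summary, not in the patch): MLlib's log loss is 2 * log(1 + exp(-2 * y * F(x))) for labels y in {-1, +1}, and its population minimizer is F(x) = 0.5 * log(p / (1 - p)), half the log-odds. Inverting that gives p = 1 / (1 + exp(-2F)), which is exactly computeProbability. A quick REPL check:

```scala
// Sanity check (illustrative): computeProbability inverts the half-log-odds.
def computeProbability(margin: Double): Double =
  1.0 / (1.0 + math.exp(-2.0 * margin))

val p = 0.8
val halfLogOdds = 0.5 * math.log(p / (1.0 - p))
assert(math.abs(computeProbability(halfLogOdds) - p) < 1e-12)
```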

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala

Lines changed: 7 additions & 1 deletion
@@ -22,7 +22,6 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.model.TreeEnsembleModel
 import org.apache.spark.rdd.RDD
 
-
 /**
  * :: DeveloperApi ::
  * Trait for adding "pluggable" loss functions for the gradient boosting algorithm.
@@ -67,3 +66,10 @@ trait Loss extends Serializable {
    */
   private[spark] def computeError(prediction: Double, label: Double): Double
 }
+
+private[spark] trait ClassificationLoss extends Loss {
+  /**
+   * Computes the class probability given the margin.
+   */
+  private[spark] def computeProbability(margin: Double): Double
+}
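In this patch LogLoss is the only implementation of the new trait. For illustration only, a hypothetical AdaBoost-style exponential loss could plug into the same contract, since its population minimizer is also half the log-odds; nothing like this ships in the commit:

```scala
package org.apache.spark.mllib.tree.loss

// Hypothetical example, placed in the same package so the private[spark]
// members are visible. Labels are assumed remapped to {-1, +1} as in the
// GBT internals.
object ExponentialLoss extends ClassificationLoss {
  // d/dF exp(-y * F) = -y * exp(-y * F)
  override def gradient(prediction: Double, label: Double): Double =
    -label * math.exp(-label * prediction)

  override private[spark] def computeError(prediction: Double, label: Double): Double =
    math.exp(-label * prediction)

  // Exponential loss shares the half-log-odds minimizer, so the same
  // logistic inverse applies.
  override private[spark] def computeProbability(margin: Double): Double =
    1.0 / (1.0 + math.exp(-2.0 * margin))
}
```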
