@@ -23,16 +23,16 @@ import org.json4s.JsonDSL._
2323
2424import org .apache .spark .annotation .Since
2525import org .apache .spark .internal .Logging
26- import org .apache .spark .ml .{PredictionModel , Predictor }
2726import org .apache .spark .ml .feature .LabeledPoint
28- import org .apache .spark .ml .linalg .Vector
27+ import org .apache .spark .ml .linalg .{ DenseVector , SparseVector , Vector , Vectors }
2928import org .apache .spark .ml .param .ParamMap
3029import org .apache .spark .ml .regression .DecisionTreeRegressionModel
3130import org .apache .spark .ml .tree ._
3231import org .apache .spark .ml .tree .impl .GradientBoostedTrees
3332import org .apache .spark .ml .util ._
3433import org .apache .spark .ml .util .DefaultParamsReader .Metadata
3534import org .apache .spark .mllib .tree .configuration .{Algo => OldAlgo }
35+ import org .apache .spark .mllib .tree .loss .LogLoss
3636import org .apache .spark .mllib .tree .model .{GradientBoostedTreesModel => OldGBTModel }
3737import org .apache .spark .rdd .RDD
3838import org .apache .spark .sql .{DataFrame , Dataset , Row }
@@ -58,7 +58,7 @@ import org.apache.spark.sql.functions._
5858@ Since (" 1.4.0" )
5959class GBTClassifier @ Since (" 1.4.0" ) (
6060 @ Since (" 1.4.0" ) override val uid : String )
61- extends Predictor [Vector , GBTClassifier , GBTClassificationModel ]
61+ extends ProbabilisticClassifier [Vector , GBTClassifier , GBTClassificationModel ]
6262 with GBTClassifierParams with DefaultParamsWritable with Logging {
6363
6464 @ Since (" 1.4.0" )
@@ -158,12 +158,19 @@ class GBTClassifier @Since("1.4.0") (
158158 val numFeatures = oldDataset.first().features.size
159159 val boostingStrategy = super .getOldBoostingStrategy(categoricalFeatures, OldAlgo .Classification )
160160
161+ val numClasses = 2
162+ if (isDefined(thresholds)) {
163+ require($(thresholds).length == numClasses, this .getClass.getSimpleName +
164+ " .train() called with non-matching numClasses and thresholds.length." +
165+ s " numClasses= $numClasses, but thresholds has length ${$(thresholds).length}" )
166+ }
167+
161168 val instr = Instrumentation .create(this , oldDataset)
162169 instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
163170 maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
164171 seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
165172 instr.logNumFeatures(numFeatures)
166- instr.logNumClasses(2 )
173+ instr.logNumClasses(numClasses )
167174
168175 val (baseLearners, learnerWeights) = GradientBoostedTrees .run(oldDataset, boostingStrategy,
169176 $(seed))
@@ -202,15 +209,30 @@ class GBTClassificationModel private[ml](
202209 @ Since (" 1.6.0" ) override val uid : String ,
203210 private val _trees : Array [DecisionTreeRegressionModel ],
204211 private val _treeWeights : Array [Double ],
205- @ Since (" 1.6.0" ) override val numFeatures : Int )
206- extends PredictionModel [Vector , GBTClassificationModel ]
212+ @ Since (" 1.6.0" ) override val numFeatures : Int ,
213+ @ Since (" 2.2.0" ) override val numClasses : Int )
214+ extends ProbabilisticClassificationModel [Vector , GBTClassificationModel ]
207215 with GBTClassifierParams with TreeEnsembleModel [DecisionTreeRegressionModel ]
208216 with MLWritable with Serializable {
209217
210218 require(_trees.nonEmpty, " GBTClassificationModel requires at least 1 tree." )
211219 require(_trees.length == _treeWeights.length, " GBTClassificationModel given trees, treeWeights" +
212220 s " of non-matching lengths ( ${_trees.length}, ${_treeWeights.length}, respectively). " )
213221
222+ /**
223+ * Construct a GBTClassificationModel
224+ *
225+ * @param _trees Decision trees in the ensemble.
226+ * @param _treeWeights Weights for the decision trees in the ensemble.
227+ * @param numFeatures The number of features.
228+ */
229+ private [ml] def this (
230+ uid : String ,
231+ _trees : Array [DecisionTreeRegressionModel ],
232+ _treeWeights : Array [Double ],
233+ numFeatures : Int ) =
234+ this (uid, _trees, _treeWeights, numFeatures, 2 )
235+
214236 /**
215237 * Construct a GBTClassificationModel
216238 *
@@ -219,7 +241,7 @@ class GBTClassificationModel private[ml](
219241 */
220242 @ Since (" 1.6.0" )
221243 def this (uid : String , _trees : Array [DecisionTreeRegressionModel ], _treeWeights : Array [Double ]) =
222- this (uid, _trees, _treeWeights, - 1 )
244+ this (uid, _trees, _treeWeights, - 1 , 2 )
223245
224246 @ Since (" 1.4.0" )
225247 override def trees : Array [DecisionTreeRegressionModel ] = _trees
@@ -242,19 +264,37 @@ class GBTClassificationModel private[ml](
242264 }
243265
244266 override protected def predict (features : Vector ): Double = {
245- // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
246- // Classifies by thresholding sum of weighted tree predictions
247- val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
248- val prediction = blas.ddot(numTrees, treePredictions, 1 , _treeWeights, 1 )
249- if (prediction > 0.0 ) 1.0 else 0.0
267+ // If thresholds defined, use predictRaw to get probabilities, otherwise use optimization
268+ if (isDefined(thresholds)) {
269+ super .predict(features)
270+ } else {
271+ if (margin(features) > 0.0 ) 1.0 else 0.0
272+ }
273+ }
274+
275+ override protected def predictRaw (features : Vector ): Vector = {
276+ val prediction : Double = margin(features)
277+ Vectors .dense(Array (- prediction, prediction))
278+ }
279+
280+ override protected def raw2probabilityInPlace (rawPrediction : Vector ): Vector = {
281+ rawPrediction match {
282+ case dv : DenseVector =>
283+ dv.values(0 ) = loss.computeProbability(dv.values(0 ))
284+ dv.values(1 ) = 1.0 - dv.values(0 )
285+ dv
286+ case sv : SparseVector =>
287+ throw new RuntimeException (" Unexpected error in GBTClassificationModel:" +
288+ " raw2probabilityInPlace encountered SparseVector" )
289+ }
250290 }
251291
252292 /** Number of trees in ensemble */
253293 val numTrees : Int = trees.length
254294
255295 @ Since (" 1.4.0" )
256296 override def copy (extra : ParamMap ): GBTClassificationModel = {
257- copyValues(new GBTClassificationModel (uid, _trees, _treeWeights, numFeatures),
297+ copyValues(new GBTClassificationModel (uid, _trees, _treeWeights, numFeatures, numClasses ),
258298 extra).setParent(parent)
259299 }
260300
@@ -276,18 +316,30 @@ class GBTClassificationModel private[ml](
276316 @ Since (" 2.0.0" )
277317 lazy val featureImportances : Vector = TreeEnsembleModel .featureImportances(trees, numFeatures)
278318
319+ /** Raw prediction for the positive class. */
320+ private def margin (features : Vector ): Double = {
321+ val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
322+ blas.ddot(numTrees, treePredictions, 1 , _treeWeights, 1 )
323+ }
324+
279325 /** (private[ml]) Convert to a model in the old API */
280326 private [ml] def toOld : OldGBTModel = {
281327 new OldGBTModel (OldAlgo .Classification , _trees.map(_.toOld), _treeWeights)
282328 }
283329
330+ // hard coded loss, which is not meant to be changed in the model
331+ private val loss = getOldLossType
332+
284333 @ Since (" 2.0.0" )
285334 override def write : MLWriter = new GBTClassificationModel .GBTClassificationModelWriter (this )
286335}
287336
288337@ Since (" 2.0.0" )
289338object GBTClassificationModel extends MLReadable [GBTClassificationModel ] {
290339
340+ private val numFeaturesKey : String = " numFeatures"
341+ private val numTreesKey : String = " numTrees"
342+
291343 @ Since (" 2.0.0" )
292344 override def read : MLReader [GBTClassificationModel ] = new GBTClassificationModelReader
293345
@@ -300,8 +352,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
300352 override protected def saveImpl (path : String ): Unit = {
301353
302354 val extraMetadata : JObject = Map (
303- " numFeatures " -> instance.numFeatures,
304- " numTrees " -> instance.getNumTrees)
355+ numFeaturesKey -> instance.numFeatures,
356+ numTreesKey -> instance.getNumTrees)
305357 EnsembleModelReadWrite .saveImpl(instance, path, sparkSession, extraMetadata)
306358 }
307359 }
@@ -316,8 +368,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
316368 implicit val format = DefaultFormats
317369 val (metadata : Metadata , treesData : Array [(Metadata , Node )], treeWeights : Array [Double ]) =
318370 EnsembleModelReadWrite .loadImpl(path, sparkSession, className, treeClassName)
319- val numFeatures = (metadata.metadata \ " numFeatures " ).extract[Int ]
320- val numTrees = (metadata.metadata \ " numTrees " ).extract[Int ]
371+ val numFeatures = (metadata.metadata \ numFeaturesKey ).extract[Int ]
372+ val numTrees = (metadata.metadata \ numTreesKey ).extract[Int ]
321373
322374 val trees : Array [DecisionTreeRegressionModel ] = treesData.map {
323375 case (treeMetadata, root) =>
@@ -328,7 +380,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
328380 }
329381 require(numTrees == trees.length, s " GBTClassificationModel.load expected $numTrees" +
330382 s " trees based on metadata but found ${trees.length} trees. " )
331- val model = new GBTClassificationModel (metadata.uid, trees, treeWeights, numFeatures)
383+ val model = new GBTClassificationModel (metadata.uid,
384+ trees, treeWeights, numFeatures)
332385 DefaultParamsReader .getAndSetParams(model, metadata)
333386 model
334387 }
@@ -339,14 +392,15 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
339392 oldModel : OldGBTModel ,
340393 parent : GBTClassifier ,
341394 categoricalFeatures : Map [Int , Int ],
342- numFeatures : Int = - 1 ): GBTClassificationModel = {
395+ numFeatures : Int = - 1 ,
396+ numClasses : Int = 2 ): GBTClassificationModel = {
343397 require(oldModel.algo == OldAlgo .Classification , " Cannot convert GradientBoostedTreesModel" +
344398 s " with algo= ${oldModel.algo} (old API) to GBTClassificationModel (new API). " )
345399 val newTrees = oldModel.trees.map { tree =>
346400 // parent for each tree is null since there is no good way to set this.
347401 DecisionTreeRegressionModel .fromOld(tree, null , categoricalFeatures)
348402 }
349403 val uid = if (parent != null ) parent.uid else Identifiable .randomUID(" gbtc" )
350- new GBTClassificationModel (uid, newTrees, oldModel.treeWeights, numFeatures)
404+ new GBTClassificationModel (uid, newTrees, oldModel.treeWeights, numFeatures, numClasses )
351405 }
352406}
0 commit comments