Skip to content

Commit 9afad56

Browse files
committed
[SPARK-7127] Simplified calls by overriding transformImpl and using a broadcasted model in callUDF to make predictions
1 parent 1f34be4 commit 9afad56

File tree

5 files changed

+34
-50
lines changed

5 files changed

+34
-50
lines changed

mllib/src/main/scala/org/apache/spark/ml/Predictor.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,22 +174,20 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
174174
* @return transformed dataset with [[predictionCol]] of type [[Double]]
175175
*/
176176
override def transform(dataset: DataFrame): DataFrame = {
177-
transformImpl(dataset, predict)
178-
}
179-
180-
protected def transformImpl(
181-
dataset: DataFrame,
182-
predictFunc: (FeaturesType) => Double): DataFrame = {
183177
transformSchema(dataset.schema, logging = true)
184178
if ($(predictionCol).nonEmpty) {
185-
dataset.withColumn($(predictionCol), callUDF(predictFunc, DoubleType, col($(featuresCol))))
179+
transformImpl(dataset)
186180
} else {
187181
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
188182
" since no output columns were set.")
189183
dataset
190184
}
191185
}
192186

187+
protected def transformImpl(dataset: DataFrame): DataFrame = {
188+
dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol))))
189+
}
190+
193191
/**
194192
* Predict label for the given features.
195193
* This internal method is used to implement [[transform()]] and output [[predictionCol]].

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
3434
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
3535
import org.apache.spark.rdd.RDD
3636
import org.apache.spark.sql.DataFrame
37+
import org.apache.spark.sql.functions._
38+
import org.apache.spark.sql.types.DoubleType
3739

3840
/**
3941
* :: AlphaComponent ::
@@ -176,23 +178,17 @@ final class GBTClassificationModel(
176178

177179
override def treeWeights: Array[Double] = _treeWeights
178180

179-
override def transform(dataset: DataFrame): DataFrame = {
181+
override protected def transformImpl(dataset: DataFrame): DataFrame = {
180182
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
181-
val predictFunc = (features: Vector) => predictImpl(features, () => bcastModel.value)
182-
transformImpl(dataset, predictFunc)
183+
dataset.withColumn($(predictionCol), callUDF(bcastModel.value.predict _, DoubleType,
184+
col($(featuresCol))))
183185
}
184186

185187
override protected def predict(features: Vector): Double = {
186-
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
187-
// Predict without using a broadcasted model
188-
predictImpl(features, () => this)
189-
}
190-
191-
protected def predictImpl(features: Vector, modelAccesor: () => TreeEnsembleModel): Double = {
188+
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
192189
// Classifies by thresholding sum of weighted tree predictions
193-
val treePredictions = modelAccesor().trees.map(_.rootNode.predict(features))
194-
val prediction = blas.ddot(modelAccesor().numTrees, treePredictions, 1,
195-
modelAccesor().treeWeights, 1)
190+
val treePredictions = _trees.map(_.rootNode.predict(features))
191+
val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
196192
if (prediction > 0.0) 1.0 else 0.0
197193
}
198194

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
3131
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
3232
import org.apache.spark.rdd.RDD
3333
import org.apache.spark.sql.DataFrame
34+
import org.apache.spark.sql.functions._
35+
import org.apache.spark.sql.types.DoubleType
3436

3537
/**
3638
* :: AlphaComponent ::
@@ -134,23 +136,18 @@ final class RandomForestClassificationModel private[ml] (
134136

135137
override def treeWeights: Array[Double] = _treeWeights
136138

137-
override def transform(dataset: DataFrame): DataFrame = {
139+
override protected def transformImpl(dataset: DataFrame): DataFrame = {
138140
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
139-
val predictFunc = (features: Vector) => predictImpl(features, () => bcastModel.value)
140-
transformImpl(dataset, predictFunc)
141+
dataset.withColumn($(predictionCol), callUDF(bcastModel.value.predict _, DoubleType,
142+
col($(featuresCol))))
141143
}
142144

143145
override protected def predict(features: Vector): Double = {
144146
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
145-
// Predict without using a broadcasted model
146-
predictImpl(features, () => this)
147-
}
148-
149-
protected def predictImpl(features: Vector, modelAccesor: () => TreeEnsembleModel): Double = {
150147
// Classifies using majority votes.
151148
// Ignore the weights since all are 1.0 for now.
152149
val votes = mutable.Map.empty[Int, Double]
153-
modelAccesor().trees.view.foreach { tree =>
150+
_trees.view.foreach { tree =>
154151
val prediction = tree.rootNode.predict(features).toInt
155152
votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight
156153
}

mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss
3333
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
3434
import org.apache.spark.rdd.RDD
3535
import org.apache.spark.sql.DataFrame
36+
import org.apache.spark.sql.functions._
37+
import org.apache.spark.sql.types.DoubleType
3638

3739
/**
3840
* :: AlphaComponent ::
@@ -165,23 +167,17 @@ final class GBTRegressionModel(
165167

166168
override def treeWeights: Array[Double] = _treeWeights
167169

168-
override def transform(dataset: DataFrame): DataFrame = {
170+
override protected def transformImpl(dataset: DataFrame): DataFrame = {
169171
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
170-
val predictFunc = (features: Vector) => predictImpl(features, () => bcastModel.value)
171-
transformImpl(dataset, predictFunc)
172+
dataset.withColumn($(predictionCol), callUDF(bcastModel.value.predict _, DoubleType,
173+
col($(featuresCol))))
172174
}
173175

174176
override protected def predict(features: Vector): Double = {
175-
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
176-
// Predict without using a broadcasted model
177-
predictImpl(features, () => this)
178-
}
179-
180-
protected def predictImpl(features: Vector, modelAccesor: () => TreeEnsembleModel): Double = {
177+
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
181178
// Classifies by thresholding sum of weighted tree predictions
182-
val treePredictions = modelAccesor().trees.map(_.rootNode.predict(features))
183-
val prediction = blas.ddot(modelAccesor().numTrees, treePredictions, 1,
184-
modelAccesor().treeWeights, 1)
179+
val treePredictions = _trees.map(_.rootNode.predict(features))
180+
val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
185181
if (prediction > 0.0) 1.0 else 0.0
186182
}
187183

mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
2929
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
3030
import org.apache.spark.rdd.RDD
3131
import org.apache.spark.sql.DataFrame
32+
import org.apache.spark.sql.functions._
33+
import org.apache.spark.sql.types.DoubleType
3234

3335
/**
3436
* :: AlphaComponent ::
@@ -121,22 +123,17 @@ final class RandomForestRegressionModel private[ml] (
121123

122124
override def treeWeights: Array[Double] = _treeWeights
123125

124-
override def transform(dataset: DataFrame): DataFrame = {
126+
override protected def transformImpl(dataset: DataFrame): DataFrame = {
125127
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
126-
val predictFunc = (features: Vector) => predictImpl(features, () => bcastModel.value)
127-
transformImpl(dataset, predictFunc)
128+
dataset.withColumn($(predictionCol), callUDF(bcastModel.value.predict _, DoubleType,
129+
col($(featuresCol))))
128130
}
129131

130132
override protected def predict(features: Vector): Double = {
131-
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
132-
// Predict without using a broadcasted model
133-
predictImpl(features, () => this)
134-
}
135-
136-
protected def predictImpl(features: Vector, modelAccesor: () => TreeEnsembleModel): Double = {
133+
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
137134
// Predict average of tree predictions.
138135
// Ignore the weights since all are 1.0 for now.
139-
modelAccesor().trees.map(_.rootNode.predict(features)).sum / modelAccesor().numTrees
136+
_trees.map(_.rootNode.predict(features)).sum / numTrees
140137
}
141138

142139
override def copy(extra: ParamMap): RandomForestRegressionModel = {

0 commit comments

Comments
 (0)