Skip to content

Commit 8b8be1f

Browse files
BryanCutler authored and jkbradley committed
[SPARK-7127] [MLLIB] Adding broadcast of model before prediction for ensembles
Broadcast of ensemble models in transformImpl before call to predict Author: Bryan Cutler <[email protected]> Closes apache#6300 from BryanCutler/bcast-ensemble-models-7127 and squashes the following commits: 86e73de [Bryan Cutler] [SPARK-7127] Replaced deprecated callUDF with udf 40a139d [Bryan Cutler] Merge branch 'master' into bcast-ensemble-models-7127 9afad56 [Bryan Cutler] [SPARK-7127] Simplified calls by overriding transformImpl and using broadcasted model in callUDF to make prediction 1f34be4 [Bryan Cutler] [SPARK-7127] Removed accidental newline 171a6ce [Bryan Cutler] [SPARK-7127] Used modelAccessor parameter in predictImpl to access broadcasted model 6fd153c [Bryan Cutler] [SPARK-7127] Applied broadcasting to remaining ensemble models aaad77b [Bryan Cutler] [SPARK-7127] Removed abstract class for broadcasting model, instead passing a prediction function as param to transform 83904bb [Bryan Cutler] [SPARK-7127] Adding broadcast of model before prediction in RandomForestClassifier
1 parent 830666f commit 8b8be1f

File tree

5 files changed

+48
-8
lines changed

5 files changed

+48
-8
lines changed

mllib/src/main/scala/org/apache/spark/ml/Predictor.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,17 +169,21 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
169169
override def transform(dataset: DataFrame): DataFrame = {
170170
transformSchema(dataset.schema, logging = true)
171171
if ($(predictionCol).nonEmpty) {
172-
val predictUDF = udf { (features: Any) =>
173-
predict(features.asInstanceOf[FeaturesType])
174-
}
175-
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
172+
transformImpl(dataset)
176173
} else {
177174
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
178175
" since no output columns were set.")
179176
dataset
180177
}
181178
}
182179

180+
protected def transformImpl(dataset: DataFrame): DataFrame = {
181+
val predictUDF = udf { (features: Any) =>
182+
predict(features.asInstanceOf[FeaturesType])
183+
}
184+
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
185+
}
186+
183187
/**
184188
* Predict label for the given features.
185189
* This internal method is used to implement [[transform()]] and output [[predictionCol]].

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
3434
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
3535
import org.apache.spark.rdd.RDD
3636
import org.apache.spark.sql.DataFrame
37+
import org.apache.spark.sql.functions._
38+
import org.apache.spark.sql.types.DoubleType
3739

3840
/**
3941
* :: Experimental ::
@@ -177,8 +179,15 @@ final class GBTClassificationModel(
177179

178180
override def treeWeights: Array[Double] = _treeWeights
179181

182+
override protected def transformImpl(dataset: DataFrame): DataFrame = {
183+
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
184+
val predictUDF = udf { (features: Any) =>
185+
bcastModel.value.predict(features.asInstanceOf[Vector])
186+
}
187+
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
188+
}
189+
180190
override protected def predict(features: Vector): Double = {
181-
// TODO: Override transform() to broadcast model: SPARK-7127
182191
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
183192
// Classifies by thresholding sum of weighted tree predictions
184193
val treePredictions = _trees.map(_.rootNode.predict(features))

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
3131
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
3232
import org.apache.spark.rdd.RDD
3333
import org.apache.spark.sql.DataFrame
34+
import org.apache.spark.sql.functions._
35+
import org.apache.spark.sql.types.DoubleType
3436

3537
/**
3638
* :: Experimental ::
@@ -143,8 +145,15 @@ final class RandomForestClassificationModel private[ml] (
143145

144146
override def treeWeights: Array[Double] = _treeWeights
145147

148+
override protected def transformImpl(dataset: DataFrame): DataFrame = {
149+
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
150+
val predictUDF = udf { (features: Any) =>
151+
bcastModel.value.predict(features.asInstanceOf[Vector])
152+
}
153+
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
154+
}
155+
146156
override protected def predict(features: Vector): Double = {
147-
// TODO: Override transform() to broadcast model. SPARK-7127
148157
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
149158
// Classifies using majority votes.
150159
// Ignore the weights since all are 1.0 for now.

mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss
3333
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
3434
import org.apache.spark.rdd.RDD
3535
import org.apache.spark.sql.DataFrame
36+
import org.apache.spark.sql.functions._
37+
import org.apache.spark.sql.types.DoubleType
3638

3739
/**
3840
* :: Experimental ::
@@ -167,8 +169,15 @@ final class GBTRegressionModel(
167169

168170
override def treeWeights: Array[Double] = _treeWeights
169171

172+
override protected def transformImpl(dataset: DataFrame): DataFrame = {
173+
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
174+
val predictUDF = udf { (features: Any) =>
175+
bcastModel.value.predict(features.asInstanceOf[Vector])
176+
}
177+
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
178+
}
179+
170180
override protected def predict(features: Vector): Double = {
171-
// TODO: Override transform() to broadcast model. SPARK-7127
172181
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
173182
// Classifies by thresholding sum of weighted tree predictions
174183
val treePredictions = _trees.map(_.rootNode.predict(features))

mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
2929
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
3030
import org.apache.spark.rdd.RDD
3131
import org.apache.spark.sql.DataFrame
32+
import org.apache.spark.sql.functions._
33+
import org.apache.spark.sql.types.DoubleType
3234

3335
/**
3436
* :: Experimental ::
@@ -129,8 +131,15 @@ final class RandomForestRegressionModel private[ml] (
129131

130132
override def treeWeights: Array[Double] = _treeWeights
131133

134+
override protected def transformImpl(dataset: DataFrame): DataFrame = {
135+
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
136+
val predictUDF = udf { (features: Any) =>
137+
bcastModel.value.predict(features.asInstanceOf[Vector])
138+
}
139+
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
140+
}
141+
132142
override protected def predict(features: Vector): Double = {
133-
// TODO: Override transform() to broadcast model. SPARK-7127
134143
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
135144
// Predict average of tree predictions.
136145
// Ignore the weights since all are 1.0 for now.

0 commit comments

Comments (0)