From 83904bbd31b8980bbad86c1f74d8308f5dbf837b Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 20 May 2015 16:10:30 -0700 Subject: [PATCH 1/7] [SPARK-7127] Adding broadcast of model before prediction in RandomForestClassifier --- .../scala/org/apache/spark/ml/Predictor.scala | 39 ++++++++++++++++++- .../RandomForestClassifier.scala | 25 ++++++++++-- 2 files changed, 59 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index ec0f76aa668b..65324b5f676e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils @@ -176,7 +177,7 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) if ($(predictionCol).nonEmpty) { - dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol)))) + transformImpl(dataset) } else { this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + " since no output columns were set.") @@ -184,9 +185,45 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, } } + protected def transformImpl(dataset: DataFrame): DataFrame = { + dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol)))) + } + /** * Predict label for the given features. * This internal method is used to implement [[transform()]] and output [[predictionCol]]. */ protected def predict(features: FeaturesType): Double } + + +/** + * :: DeveloperApi :: + * + * Abstraction for a model for prediction tasks that will broadcast the model used to predict. + * + * @tparam FeaturesType Type of features. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type + * parameter to specify the concrete type for the corresponding model. + */ +@DeveloperApi +abstract class PredictionModelBroadcasting[ + FeaturesType, M <: PredictionModelBroadcasting[FeaturesType, M] + ] + extends PredictionModel[FeaturesType, M] { + + protected def transformImpl(dataset: DataFrame, bcastModel: Broadcast[M]): DataFrame = { + + dataset.withColumn($(predictionCol), + callUDF((features: FeaturesType) => predictWithBroadcastModel(features, bcastModel), + DoubleType, col($(featuresCol))) + ) + } + + /** + * Predict label for the given features using a broadcasted model. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. + */ + protected def predictWithBroadcastModel(features: FeaturesType, bcastModel: Broadcast[M]): Double +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index a1de7919859e..71764cf4db1a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -20,7 +20,8 @@ package org.apache.spark.ml.classification import scala.collection.mutable import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.{PredictionModel, Predictor} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.{PredictionModelBroadcasting, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} @@ -109,6 +110,7 @@ object RandomForestClassifier { RandomForestParams.supportedFeatureSubsetStrategies } + /** * :: AlphaComponent :: * @@ -122,7 +124,7 @@ object RandomForestClassifier { final class RandomForestClassificationModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeClassificationModel]) - extends PredictionModel[Vector, RandomForestClassificationModel] + extends PredictionModelBroadcasting[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") @@ -134,13 +136,28 @@ final class RandomForestClassificationModel private[ml] ( override def treeWeights: Array[Double] = _treeWeights + override def transform(dataset: DataFrame): DataFrame = { + val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) + transformImpl(dataset, bcastModel) + } + override protected def predict(features: Vector): Double = { - // TODO: Override transform() to broadcast model. SPARK-7127 // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 + // Predict without using a broadcasted mode + predictImpl(features, () => this) + } + + override protected def predictWithBroadcastModel(features: Vector, + bcastModel: Broadcast[RandomForestClassificationModel]): Double = { + // Predict using the given broadcasted model + predictImpl(features, () => bcastModel.value) + } + + protected def predictImpl(features: Vector, modelAccesor: () => TreeEnsembleModel): Double = { // Classifies using majority votes. // Ignore the weights since all are 1.0 for now. val votes = mutable.Map.empty[Int, Double] - _trees.view.foreach { tree => + modelAccesor().trees.view.foreach { tree => val prediction = tree.rootNode.predict(features).toInt votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight } From aaad77b8a973af6870d2292a33361747cbff97e2 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Sun, 24 May 2015 20:41:47 -0700 Subject: [PATCH 2/7] [SPARK-7127] Removed abstract class for broadcasting model, instead passing a prediction function as param to transform --- .../scala/org/apache/spark/ml/Predictor.scala | 45 +++---------------- .../RandomForestClassifier.scala | 16 +++---- 2 files changed, 12 insertions(+), 49 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index 65324b5f676e..9fde12388a52 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -18,7 +18,6 @@ package org.apache.spark.ml import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils @@ -175,9 +174,15 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, * @return transformed dataset with [[predictionCol]] of type [[Double]] */ override def transform(dataset: DataFrame): DataFrame = { + transformImpl(dataset, predict) + } + + protected def transformImpl( + dataset: DataFrame, + predictFunc: (FeaturesType) => Double): DataFrame = { transformSchema(dataset.schema, logging = true) if ($(predictionCol).nonEmpty) { - transformImpl(dataset) + dataset.withColumn($(predictionCol), callUDF(predictFunc, DoubleType, col($(featuresCol)))) } else { this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + " since no output columns were set.") @@ -185,45 +190,9 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, } } - protected def transformImpl(dataset: DataFrame): DataFrame = { - dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol)))) - } - /** * Predict label for the given features. * This internal method is used to implement [[transform()]] and output [[predictionCol]]. */ protected def predict(features: FeaturesType): Double } - - -/** - * :: DeveloperApi :: - * - * Abstraction for a model for prediction tasks that will broadcast the model used to predict. - * - * @tparam FeaturesType Type of features. - * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. - * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type - * parameter to specify the concrete type for the corresponding model. - */ -@DeveloperApi -abstract class PredictionModelBroadcasting[ - FeaturesType, M <: PredictionModelBroadcasting[FeaturesType, M] - ] - extends PredictionModel[FeaturesType, M] { - - protected def transformImpl(dataset: DataFrame, bcastModel: Broadcast[M]): DataFrame = { - - dataset.withColumn($(predictionCol), - callUDF((features: FeaturesType) => predictWithBroadcastModel(features, bcastModel), - DoubleType, col($(featuresCol))) - ) - } - - /** - * Predict label for the given features using a broadcasted model. - * This internal method is used to implement [[transform()]] and output [[predictionCol]]. - */ - protected def predictWithBroadcastModel(features: FeaturesType, bcastModel: Broadcast[M]): Double -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 71764cf4db1a..d97a7347c5cd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -20,8 +20,7 @@ package org.apache.spark.ml.classification import scala.collection.mutable import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.ml.{PredictionModelBroadcasting, Predictor} +import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} @@ -124,7 +123,7 @@ object RandomForestClassifier { final class RandomForestClassificationModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeClassificationModel]) - extends PredictionModelBroadcasting[Vector, RandomForestClassificationModel] + extends PredictionModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") @@ -138,21 +137,16 @@ final class RandomForestClassificationModel private[ml] ( override def transform(dataset: DataFrame): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) - transformImpl(dataset, bcastModel) + val predictFunc = (features: Vector) => predictImpl(features, () => bcastModel.value) + transformImpl(dataset, predictFunc) } override protected def predict(features: Vector): Double = { // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 - // Predict without using a broadcasted mode + // Predict without using a broadcasted model predictImpl(features, () => this) } - override protected def predictWithBroadcastModel(features: Vector, - bcastModel: Broadcast[RandomForestClassificationModel]): Double = { - // Predict using the given broadcasted model - predictImpl(features, () => bcastModel.value) - } - protected def predictImpl(features: Vector, modelAccesor: () => TreeEnsembleModel): Double = { // Classifies using majority votes. // Ignore the weights since all are 1.0 for now. From 6fd153c6d725404110de67c4672d47e1fe21fe2a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 15 Jun 2015 15:21:14 -0700 Subject: [PATCH 3/7] [SPARK-7127] Applied broadcasting to remaining ensemble models --- .../spark/ml/classification/GBTClassifier.scala | 14 ++++++++++++-- .../apache/spark/ml/regression/GBTRegressor.scala | 14 ++++++++++++-- .../ml/regression/RandomForestRegressor.scala | 14 ++++++++++++-- 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index d504d84beb91..0fe6a86c21a0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -176,9 +176,19 @@ final class GBTClassificationModel( override def treeWeights: Array[Double] = _treeWeights + override def transform(dataset: DataFrame): DataFrame = { + val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) + val predictFunc = (features: Vector) => predictImpl(features, () => bcastModel.value) + transformImpl(dataset, predictFunc) + } + override protected def predict(features: Vector): Double = { - // TODO: Override transform() to broadcast model: SPARK-7127 - // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 + // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 + // Predict without using a broadcasted model + predictImpl(features, () => this) + } + + protected def predictImpl(features: Vector, modelAccesor: () => TreeEnsembleModel): Double = { // Classifies by thresholding sum of weighted tree predictions val treePredictions = _trees.map(_.rootNode.predict(features)) val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 4249ff5c1ebc..28e6399cf5cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -165,9 +165,19 @@ final class GBTRegressionModel( override def treeWeights: Array[Double] = _treeWeights + override def transform(dataset: DataFrame): DataFrame = { + val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) + val predictFunc = (features: Vector) => predictImpl(features, () => bcastModel.value) + transformImpl(dataset, predictFunc) + } + override protected def predict(features: Vector): Double = { - // TODO: Override transform() to broadcast model. SPARK-7127 - // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 + // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 + // Predict without using a broadcasted model + predictImpl(features, () => this) + } + + protected def predictImpl(features: Vector, modelAccesor: () => TreeEnsembleModel): Double = { // Classifies by thresholding sum of weighted tree predictions val treePredictions = _trees.map(_.rootNode.predict(features)) val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 82437aa8de29..bce8d62079d0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -121,9 +121,19 @@ final class RandomForestRegressionModel private[ml] ( override def treeWeights: Array[Double] = _treeWeights + override def transform(dataset: DataFrame): DataFrame = { + val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) + val predictFunc = (features: Vector) => predictImpl(features, () => bcastModel.value) + transformImpl(dataset, predictFunc) + } + override protected def predict(features: Vector): Double = { - // TODO: Override transform() to broadcast model. SPARK-7127 - // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 + // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 + // Predict without using a broadcasted model + predictImpl(features, () => this) + } + + protected def predictImpl(features: Vector, modelAccesor: () => TreeEnsembleModel): Double = { // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. _trees.map(_.rootNode.predict(features)).sum / numTrees From 171a6cee7ec83415da0705f2fe15ccca4c3f883c Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 15 Jun 2015 15:35:49 -0700 Subject: [PATCH 4/7] [SPARK-7127] Used modelAccessor parameter in predictImpl to access broadcasted model --- .../org/apache/spark/ml/classification/GBTClassifier.scala | 5 +++-- .../scala/org/apache/spark/ml/regression/GBTRegressor.scala | 5 +++-- .../apache/spark/ml/regression/RandomForestRegressor.scala | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 0fe6a86c21a0..f8345bc7b261 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -190,8 +190,9 @@ final class GBTClassificationModel( protected def predictImpl(features: Vector, modelAccesor: () => TreeEnsembleModel): Double = { // Classifies by thresholding sum of weighted tree predictions - val treePredictions = _trees.map(_.rootNode.predict(features)) - val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) + val treePredictions = modelAccesor().trees.map(_.rootNode.predict(features)) + val prediction = blas.ddot(modelAccesor().numTrees, treePredictions, 1, + modelAccesor().treeWeights, 1) if (prediction > 0.0) 1.0 else 0.0 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 28e6399cf5cf..245c6530df79 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -179,8 +179,9 @@ final class GBTRegressionModel( protected def predictImpl(features: Vector, modelAccesor: () => TreeEnsembleModel): Double = { // Classifies by thresholding sum of weighted tree predictions - val treePredictions = _trees.map(_.rootNode.predict(features)) - val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) + val treePredictions = modelAccesor().trees.map(_.rootNode.predict(features)) + val prediction = blas.ddot(modelAccesor().numTrees, treePredictions, 1, + modelAccesor().treeWeights, 1) if (prediction > 0.0) 1.0 else 0.0 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index bce8d62079d0..a24e5e29841c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -136,7 +136,7 @@ final class RandomForestRegressionModel private[ml] ( protected def predictImpl(features: Vector, modelAccesor: () => TreeEnsembleModel): Double = { // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. - _trees.map(_.rootNode.predict(features)).sum / numTrees + modelAccesor().trees.map(_.rootNode.predict(features)).sum / modelAccesor().numTrees } override def copy(extra: ParamMap): RandomForestRegressionModel = { From 1f34be4de921a7088cadb47494a03d0f61e45634 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 15 Jun 2015 15:39:25 -0700 Subject: [PATCH 5/7] [SPARK-7127] Removed accidental newline --- .../apache/spark/ml/classification/RandomForestClassifier.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index d97a7347c5cd..4ba06d0f37a6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -109,7 +109,6 @@ object RandomForestClassifier { RandomForestParams.supportedFeatureSubsetStrategies } - /** * :: AlphaComponent :: * From 9afad56fa0a7a07510e27259dbc14805ae63a2bf Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 14 Jul 2015 14:32:30 -0700 Subject: [PATCH 6/7] [SPARK-7127] Simplified calls by overriding transformImpl and using broadcasted model in callUDF to make prediction --- .../scala/org/apache/spark/ml/Predictor.scala | 12 +++++------ .../ml/classification/GBTClassifier.scala | 20 ++++++++----------- .../RandomForestClassifier.scala | 15 ++++++-------- .../spark/ml/regression/GBTRegressor.scala | 20 ++++++++----------- .../ml/regression/RandomForestRegressor.scala | 17 +++++++--------- 5 files changed, 34 insertions(+), 50 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index 9fde12388a52..71ccd244c81b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -174,15 +174,9 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, * @return transformed dataset with [[predictionCol]] of type [[Double]] */ override def transform(dataset: DataFrame): DataFrame = { - transformImpl(dataset, predict) - } - - protected def transformImpl( - dataset: DataFrame, - predictFunc: (FeaturesType) => Double): DataFrame = { transformSchema(dataset.schema, logging = true) if ($(predictionCol).nonEmpty) { - dataset.withColumn($(predictionCol), callUDF(predictFunc, DoubleType, col($(featuresCol)))) + transformImpl(dataset) } else { this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + " since no output columns were set.") @@ -190,6 +184,10 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, } } + protected def transformImpl(dataset: DataFrame): DataFrame = { + dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol)))) + } + /** * Predict label for the given features. * This internal method is used to implement [[transform()]] and output [[predictionCol]]. diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index f8345bc7b261..07928822295e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -34,6 +34,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType /** * :: AlphaComponent :: @@ -176,23 +178,17 @@ final class GBTClassificationModel( override def treeWeights: Array[Double] = _treeWeights - override def transform(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: DataFrame): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) - val predictFunc = (features: Vector) => predictImpl(features, () => bcastModel.value) - transformImpl(dataset, predictFunc) + dataset.withColumn($(predictionCol), callUDF(bcastModel.value.predict _, DoubleType, + col($(featuresCol)))) } override protected def predict(features: Vector): Double = { - // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 - // Predict without using a broadcasted model - predictImpl(features, () => this) - } - - protected def predictImpl(features: Vector, modelAccesor: () => TreeEnsembleModel): Double = { + // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions - val treePredictions = modelAccesor().trees.map(_.rootNode.predict(features)) - val prediction = blas.ddot(modelAccesor().numTrees, treePredictions, 1, - modelAccesor().treeWeights, 1) + val treePredictions = _trees.map(_.rootNode.predict(features)) + val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) if (prediction > 0.0) 1.0 else 0.0 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 4ba06d0f37a6..90f228e41836 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -31,6 +31,8 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType /** * :: AlphaComponent :: @@ -134,23 +136,18 @@ final class RandomForestClassificationModel private[ml] ( override def treeWeights: Array[Double] = _treeWeights - override def transform(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: DataFrame): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) - val predictFunc = (features: Vector) => predictImpl(features, () => bcastModel.value) - transformImpl(dataset, predictFunc) + dataset.withColumn($(predictionCol), callUDF(bcastModel.value.predict _, DoubleType, + col($(featuresCol)))) } override protected def predict(features: Vector): Double = { // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 - // Predict without using a broadcasted model - predictImpl(features, () => this) - } - - protected def predictImpl(features: Vector, modelAccesor: () => TreeEnsembleModel): Double = { // Classifies using majority votes. // Ignore the weights since all are 1.0 for now. val votes = mutable.Map.empty[Int, Double] - modelAccesor().trees.view.foreach { tree => + _trees.view.foreach { tree => val prediction = tree.rootNode.predict(features).toInt votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 245c6530df79..cd494fc77b33 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -33,6 +33,8 @@ import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType /** * :: AlphaComponent :: @@ -165,23 +167,17 @@ final class GBTRegressionModel( override def treeWeights: Array[Double] = _treeWeights - override def transform(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: DataFrame): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) - val predictFunc = (features: Vector) => predictImpl(features, () => bcastModel.value) - transformImpl(dataset, predictFunc) + dataset.withColumn($(predictionCol), callUDF(bcastModel.value.predict _, DoubleType, + col($(featuresCol)))) } override protected def predict(features: Vector): Double = { - // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 - // Predict without using a broadcasted model - predictImpl(features, () => this) - } - - protected def predictImpl(features: Vector, modelAccesor: () => TreeEnsembleModel): Double = { + // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions - val treePredictions = modelAccesor().trees.map(_.rootNode.predict(features)) - val prediction = blas.ddot(modelAccesor().numTrees, treePredictions, 1, - modelAccesor().treeWeights, 1) + val treePredictions = _trees.map(_.rootNode.predict(features)) + val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) if (prediction > 0.0) 1.0 else 0.0 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index a24e5e29841c..50a52f055129 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -29,6 +29,8 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType /** * :: AlphaComponent :: @@ -121,22 +123,17 @@ final class RandomForestRegressionModel private[ml] ( override def treeWeights: Array[Double] = _treeWeights - override def transform(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: DataFrame): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) - val predictFunc = (features: Vector) => predictImpl(features, () => bcastModel.value) - transformImpl(dataset, predictFunc) + dataset.withColumn($(predictionCol), callUDF(bcastModel.value.predict _, DoubleType, + col($(featuresCol)))) } override protected def predict(features: Vector): Double = { - // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 - // Predict without using a broadcasted model - predictImpl(features, () => this) - } - - protected def predictImpl(features: Vector, modelAccesor: () => TreeEnsembleModel): Double = { + // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. - modelAccesor().trees.map(_.rootNode.predict(features)).sum / modelAccesor().numTrees + _trees.map(_.rootNode.predict(features)).sum / numTrees } override def copy(extra: ParamMap): RandomForestRegressionModel = { From 86e73ded458c8cc8790c13a98c4a0f3f7efb7dbc Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 14 Jul 2015 17:35:31 -0700 Subject: [PATCH 7/7] [SPARK-7127] Replaced deprecated callUDF with udf --- .../org/apache/spark/ml/classification/GBTClassifier.scala | 6 ++++-- .../spark/ml/classification/RandomForestClassifier.scala | 6 ++++-- .../scala/org/apache/spark/ml/regression/GBTRegressor.scala | 6 ++++-- .../apache/spark/ml/regression/RandomForestRegressor.scala | 6 ++++-- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 3f6800971656..eb0b1a0a405f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -181,8 +181,10 @@ final class GBTClassificationModel( override protected def transformImpl(dataset: DataFrame): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) - dataset.withColumn($(predictionCol), callUDF(bcastModel.value.predict _, DoubleType, - col($(featuresCol)))) + val predictUDF = udf { (features: Any) => + bcastModel.value.predict(features.asInstanceOf[Vector]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } override protected def predict(features: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index a5ea148dd02a..991b551c4338 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -139,8 +139,10 @@ final class RandomForestClassificationModel private[ml] ( override protected def transformImpl(dataset: DataFrame): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) - dataset.withColumn($(predictionCol), callUDF(bcastModel.value.predict _, DoubleType, - col($(featuresCol)))) + val predictUDF = udf { (features: Any) => + bcastModel.value.predict(features.asInstanceOf[Vector]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } override protected def predict(features: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index b46d1f43b93b..e38dc73ee0ba 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -171,8 +171,10 @@ final class GBTRegressionModel( override protected def transformImpl(dataset: DataFrame): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) - dataset.withColumn($(predictionCol), callUDF(bcastModel.value.predict _, DoubleType, - col($(featuresCol)))) + val predictUDF = udf { (features: Any) => + bcastModel.value.predict(features.asInstanceOf[Vector]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } override protected def predict(features: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 2aaebdc183d2..745637a88d6c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -126,8 +126,10 @@ final class RandomForestRegressionModel private[ml] ( override protected def transformImpl(dataset: DataFrame): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) - dataset.withColumn($(predictionCol), callUDF(bcastModel.value.predict _, DoubleType, - col($(featuresCol)))) + val predictUDF = udf { (features: Any) => + bcastModel.value.predict(features.asInstanceOf[Vector]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } override protected def predict(features: Vector): Double = {