Skip to content
12 changes: 8 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -169,17 +169,21 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
if ($(predictionCol).nonEmpty) {
val predictUDF = udf { (features: Any) =>
predict(features.asInstanceOf[FeaturesType])
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
transformImpl(dataset)
} else {
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
" since no output columns were set.")
dataset
}
}

protected def transformImpl(dataset: DataFrame): DataFrame = {
val predictUDF = udf { (features: Any) =>
predict(features.asInstanceOf[FeaturesType])
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}

/**
* Predict label for the given features.
* This internal method is used to implement [[transform()]] and output [[predictionCol]].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType

/**
* :: Experimental ::
Expand Down Expand Up @@ -177,8 +179,15 @@ final class GBTClassificationModel(

override def treeWeights: Array[Double] = _treeWeights

override protected def transformImpl(dataset: DataFrame): DataFrame = {
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}

override protected def predict(features: Vector): Double = {
// TODO: Override transform() to broadcast model: SPARK-7127
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
val treePredictions = _trees.map(_.rootNode.predict(features))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType

/**
* :: Experimental ::
Expand Down Expand Up @@ -135,8 +137,15 @@ final class RandomForestClassificationModel private[ml] (

override def treeWeights: Array[Double] = _treeWeights

override protected def transformImpl(dataset: DataFrame): DataFrame = {
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mentioned that we might want to selectively broadcast the model, only if it's large enough. Do you think that is something we can do here automatically, or would it need to be a configuration setting?

val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}

override protected def predict(features: Vector): Double = {
// TODO: Override transform() to broadcast model. SPARK-7127
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
// Classifies using majority votes.
// Ignore the weights since all are 1.0 for now.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType

/**
* :: Experimental ::
Expand Down Expand Up @@ -167,8 +169,15 @@ final class GBTRegressionModel(

override def treeWeights: Array[Double] = _treeWeights

override protected def transformImpl(dataset: DataFrame): DataFrame = {
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}

override protected def predict(features: Vector): Double = {
// TODO: Override transform() to broadcast model. SPARK-7127
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
val treePredictions = _trees.map(_.rootNode.predict(features))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType

/**
* :: Experimental ::
Expand Down Expand Up @@ -122,8 +124,15 @@ final class RandomForestRegressionModel private[ml] (

override def treeWeights: Array[Double] = _treeWeights

override protected def transformImpl(dataset: DataFrame): DataFrame = {
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}

override protected def predict(features: Vector): Double = {
// TODO: Override transform() to broadcast model. SPARK-7127
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.
// Ignore the weights since all are 1.0 for now.
Expand Down