Skip to content

Commit aaad77b

Browse files
committed
[SPARK-7127] Removed abstract class for broadcasting model, instead passing a prediction function as param to transform
1 parent 83904bb commit aaad77b

File tree

2 files changed

+12
-49
lines changed

2 files changed

+12
-49
lines changed

mllib/src/main/scala/org/apache/spark/ml/Predictor.scala

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
package org.apache.spark.ml
1919

2020
import org.apache.spark.annotation.DeveloperApi
21-
import org.apache.spark.broadcast.Broadcast
2221
import org.apache.spark.ml.param._
2322
import org.apache.spark.ml.param.shared._
2423
import org.apache.spark.ml.util.SchemaUtils
@@ -175,55 +174,25 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
175174
* @return transformed dataset with [[predictionCol]] of type [[Double]]
176175
*/
177176
override def transform(dataset: DataFrame): DataFrame = {
177+
transformImpl(dataset, predict)
178+
}
179+
180+
protected def transformImpl(
181+
dataset: DataFrame,
182+
predictFunc: (FeaturesType) => Double): DataFrame = {
178183
transformSchema(dataset.schema, logging = true)
179184
if ($(predictionCol).nonEmpty) {
180-
transformImpl(dataset)
185+
dataset.withColumn($(predictionCol), callUDF(predictFunc, DoubleType, col($(featuresCol))))
181186
} else {
182187
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
183188
" since no output columns were set.")
184189
dataset
185190
}
186191
}
187192

188-
protected def transformImpl(dataset: DataFrame): DataFrame = {
189-
dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol))))
190-
}
191-
192193
/**
193194
* Predict label for the given features.
194195
* This internal method is used to implement [[transform()]] and output [[predictionCol]].
195196
*/
196197
protected def predict(features: FeaturesType): Double
197198
}
198-
199-
200-
/**
201-
* :: DeveloperApi ::
202-
*
203-
* Abstraction for a model for prediction tasks that will broadcast the model used to predict.
204-
*
205-
* @tparam FeaturesType Type of features.
206-
* E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
207-
* @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type
208-
* parameter to specify the concrete type for the corresponding model.
209-
*/
210-
@DeveloperApi
211-
abstract class PredictionModelBroadcasting[
212-
FeaturesType, M <: PredictionModelBroadcasting[FeaturesType, M]
213-
]
214-
extends PredictionModel[FeaturesType, M] {
215-
216-
protected def transformImpl(dataset: DataFrame, bcastModel: Broadcast[M]): DataFrame = {
217-
218-
dataset.withColumn($(predictionCol),
219-
callUDF((features: FeaturesType) => predictWithBroadcastModel(features, bcastModel),
220-
DoubleType, col($(featuresCol)))
221-
)
222-
}
223-
224-
/**
225-
* Predict label for the given features using a broadcasted model.
226-
* This internal method is used to implement [[transform()]] and output [[predictionCol]].
227-
*/
228-
protected def predictWithBroadcastModel(features: FeaturesType, bcastModel: Broadcast[M]): Double
229-
}

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ package org.apache.spark.ml.classification
2020
import scala.collection.mutable
2121

2222
import org.apache.spark.annotation.AlphaComponent
23-
import org.apache.spark.broadcast.Broadcast
24-
import org.apache.spark.ml.{PredictionModelBroadcasting, Predictor}
23+
import org.apache.spark.ml.{PredictionModel, Predictor}
2524
import org.apache.spark.ml.param.ParamMap
2625
import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel}
2726
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
@@ -124,7 +123,7 @@ object RandomForestClassifier {
124123
final class RandomForestClassificationModel private[ml] (
125124
override val uid: String,
126125
private val _trees: Array[DecisionTreeClassificationModel])
127-
extends PredictionModelBroadcasting[Vector, RandomForestClassificationModel]
126+
extends PredictionModel[Vector, RandomForestClassificationModel]
128127
with TreeEnsembleModel with Serializable {
129128

130129
require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
@@ -138,21 +137,16 @@ final class RandomForestClassificationModel private[ml] (
138137

139138
override def transform(dataset: DataFrame): DataFrame = {
140139
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
141-
transformImpl(dataset, bcastModel)
140+
val predictFunc = (features: Vector) => predictImpl(features, () => bcastModel.value)
141+
transformImpl(dataset, predictFunc)
142142
}
143143

144144
override protected def predict(features: Vector): Double = {
145145
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
146-
// Predict without using a broadcasted mode
146+
// Predict without using a broadcasted model
147147
predictImpl(features, () => this)
148148
}
149149

150-
override protected def predictWithBroadcastModel(features: Vector,
151-
bcastModel: Broadcast[RandomForestClassificationModel]): Double = {
152-
// Predict using the given broadcasted model
153-
predictImpl(features, () => bcastModel.value)
154-
}
155-
156150
protected def predictImpl(features: Vector, modelAccessor: () => TreeEnsembleModel): Double = {
157151
// Classifies using majority votes.
158152
// Ignore the weights since all are 1.0 for now.

0 commit comments

Comments (0)