|
18 | 18 | package org.apache.spark.ml |
19 | 19 |
|
20 | 20 | import org.apache.spark.annotation.DeveloperApi |
21 | | -import org.apache.spark.broadcast.Broadcast |
22 | 21 | import org.apache.spark.ml.param._ |
23 | 22 | import org.apache.spark.ml.param.shared._ |
24 | 23 | import org.apache.spark.ml.util.SchemaUtils |
@@ -175,55 +174,25 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, |
175 | 174 | * @return transformed dataset with [[predictionCol]] of type [[Double]] |
176 | 175 | */ |
177 | 176 | override def transform(dataset: DataFrame): DataFrame = { |
| 177 | + transformImpl(dataset, predict) |
| 178 | + } |
| 179 | + |
| 180 | + protected def transformImpl( |
| 181 | + dataset: DataFrame, |
| 182 | + predictFunc: (FeaturesType) => Double): DataFrame = { |
178 | 183 | transformSchema(dataset.schema, logging = true) |
179 | 184 | if ($(predictionCol).nonEmpty) { |
180 | | - transformImpl(dataset) |
| 185 | + dataset.withColumn($(predictionCol), callUDF(predictFunc, DoubleType, col($(featuresCol)))) |
181 | 186 | } else { |
182 | 187 | this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + |
183 | 188 | " since no output columns were set.") |
184 | 189 | dataset |
185 | 190 | } |
186 | 191 | } |
187 | 192 |
|
188 | | - protected def transformImpl(dataset: DataFrame): DataFrame = { |
189 | | - dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol)))) |
190 | | - } |
191 | | - |
192 | 193 | /** |
193 | 194 | * Predict label for the given features. |
194 | 195 | * This internal method is used to implement [[transform()]] and output [[predictionCol]]. |
195 | 196 | */ |
196 | 197 | protected def predict(features: FeaturesType): Double |
197 | 198 | } |
198 | | - |
199 | | - |
200 | | -/** |
201 | | - * :: DeveloperApi :: |
202 | | - * |
203 | | - * Abstraction for a model for prediction tasks that will broadcast the model used to predict. |
204 | | - * |
205 | | - * @tparam FeaturesType Type of features. |
206 | | - * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. |
207 | | - * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type |
208 | | - * parameter to specify the concrete type for the corresponding model. |
209 | | - */ |
210 | | -@DeveloperApi |
211 | | -abstract class PredictionModelBroadcasting[ |
212 | | - FeaturesType, M <: PredictionModelBroadcasting[FeaturesType, M] |
213 | | - ] |
214 | | - extends PredictionModel[FeaturesType, M] { |
215 | | - |
216 | | - protected def transformImpl(dataset: DataFrame, bcastModel: Broadcast[M]): DataFrame = { |
217 | | - |
218 | | - dataset.withColumn($(predictionCol), |
219 | | - callUDF((features: FeaturesType) => predictWithBroadcastModel(features, bcastModel), |
220 | | - DoubleType, col($(featuresCol))) |
221 | | - ) |
222 | | - } |
223 | | - |
224 | | - /** |
225 | | - * Predict label for the given features using a broadcasted model. |
226 | | - * This internal method is used to implement [[transform()]] and output [[predictionCol]]. |
227 | | - */ |
228 | | - protected def predictWithBroadcastModel(features: FeaturesType, bcastModel: Broadcast[M]): Double |
229 | | -} |
0 commit comments