Skip to content

Commit 86e73de

Browse files
committed
[SPARK-7127] Replaced deprecated callUDF with udf
1 parent 40a139d commit 86e73de

File tree

4 files changed

+16
-8
lines changed

4 files changed

+16
-8
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,10 @@ final class GBTClassificationModel(
181181

182182
override protected def transformImpl(dataset: DataFrame): DataFrame = {
183183
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
184-
dataset.withColumn($(predictionCol), callUDF(bcastModel.value.predict _, DoubleType,
185-
col($(featuresCol))))
184+
val predictUDF = udf { (features: Any) =>
185+
bcastModel.value.predict(features.asInstanceOf[Vector])
186+
}
187+
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
186188
}
187189

188190
override protected def predict(features: Vector): Double = {

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,10 @@ final class RandomForestClassificationModel private[ml] (
139139

140140
override protected def transformImpl(dataset: DataFrame): DataFrame = {
141141
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
142-
dataset.withColumn($(predictionCol), callUDF(bcastModel.value.predict _, DoubleType,
143-
col($(featuresCol))))
142+
val predictUDF = udf { (features: Any) =>
143+
bcastModel.value.predict(features.asInstanceOf[Vector])
144+
}
145+
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
144146
}
145147

146148
override protected def predict(features: Vector): Double = {

mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,10 @@ final class GBTRegressionModel(
171171

172172
override protected def transformImpl(dataset: DataFrame): DataFrame = {
173173
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
174-
dataset.withColumn($(predictionCol), callUDF(bcastModel.value.predict _, DoubleType,
175-
col($(featuresCol))))
174+
val predictUDF = udf { (features: Any) =>
175+
bcastModel.value.predict(features.asInstanceOf[Vector])
176+
}
177+
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
176178
}
177179

178180
override protected def predict(features: Vector): Double = {

mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,10 @@ final class RandomForestRegressionModel private[ml] (
126126

127127
override protected def transformImpl(dataset: DataFrame): DataFrame = {
128128
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
129-
dataset.withColumn($(predictionCol), callUDF(bcastModel.value.predict _, DoubleType,
130-
col($(featuresCol))))
129+
val predictUDF = udf { (features: Any) =>
130+
bcastModel.value.predict(features.asInstanceOf[Vector])
131+
}
132+
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
131133
}
132134

133135
override protected def predict(features: Vector): Double = {

0 commit comments

Comments
 (0)