diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 568cdd11a12a..b6b02e77909b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -210,6 +210,9 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur outputData.toDF } + final override def transformImpl(dataset: Dataset[_]): DataFrame = + throw new UnsupportedOperationException(s"transformImpl is not supported in $getClass") + /** * Predict label for the given features. * This method is used to implement `transform()` and output [[predictionCol]]. diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index a5ed4a38a886..5e0c66b3ab41 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -34,7 +34,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions._ /** @@ -286,14 +286,6 @@ class GBTClassificationModel private[ml]( @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: Dataset[_]): DataFrame = { - val bcastModel = dataset.sparkSession.sparkContext.broadcast(this) - val predictUDF = udf { (features: Any) => - bcastModel.value.predict(features.asInstanceOf[Vector]) - } - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) - } - override def predict(features: Vector): Double = { // If thresholds defined, use predictRaw to get probabilities, otherwise use optimization if (isDefined(thresholds)) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 3500f2ad52a5..4424319a4c63 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree._ -import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, TreeEnsembleModel} +import org.apache.spark.ml.tree.{TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ import org.apache.spark.ml.util.{Identifiable, MetadataUtils} @@ -34,8 +34,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.Dataset /** * Random Forest learning algorithm for @@ -208,14 +207,6 @@ class RandomForestClassificationModel private[ml] ( @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: Dataset[_]): DataFrame = { - val bcastModel = dataset.sparkSession.sparkContext.broadcast(this) - val predictUDF = udf { (features: Any) => - bcastModel.value.predict(features.asInstanceOf[Vector]) - } - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) - } - override protected def predictRaw(features: Vector): Vector = { // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 // Classifies using majority votes.