Commit 721860d (parent: 9714fa5)

First attempt to resolve the issue with inferring UDF function types in Scala 2.12 by instead using information captured when the UDF is registered -- specifically, which of its types are nullable (i.e. not primitive).
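Background for the approach: in Scala 2.12, closures compile to LambdaMetaFactory lambdas, so the runtime reflection Spark previously used to recover a UDF's argument types (and hence whether an argument is a primitive that can never be null) only sees Object. The idea here is to capture nullability while the compiler still knows the real types, i.e. at the point the UDF is registered. A minimal, self-contained sketch of that idea follows; NullabilityCapture and its udfInternal are illustrative stand-ins, not this commit's actual implementation.

import scala.reflect.runtime.universe.{typeOf, TypeTag}

// Illustrative stand-in: record, while the compiler can still see the real
// argument type, whether it is nullable (a reference type) or not (a
// primitive). This is the information the 2.12 lambda encoding erases.
object NullabilityCapture {
  // AnyVal subtypes (Int, Double, ...) map to JVM primitives and can never
  // be null; everything else is a reference type and may be.
  def isNullable[T: TypeTag]: Boolean = !(typeOf[T] <:< typeOf[AnyVal])

  // Stand-in for udfInternal: returns the function together with the
  // captured per-argument nullability.
  def udfInternal[A1: TypeTag, RT: TypeTag](f: A1 => RT): (A1 => RT, Seq[Boolean]) =
    (f, Seq(isNullable[A1]))
}

// Usage: NullabilityCapture.udfInternal { x: Int => x + 1 }._2 == Seq(false)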

23 files changed: +180 −148 lines


core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
Lines changed: 1 addition & 1 deletion

@@ -239,7 +239,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     }
     // Close ServerSocket on task completion.
     serverSocket.foreach { server =>
-      context.addTaskCompletionListener(_ => server.close())
+      context.addTaskCompletionListener[Unit](_ => server.close())
     }
     val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
     if (boundPort == -1) {
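This change is a separate, classic Scala 2.12 migration fix rather than part of the UDF work: TaskContext.addTaskCompletionListener is overloaded to accept either a TaskCompletionListener or a function, and under 2.12 a lambda can also SAM-convert to the listener trait, making the call ambiguous; the explicit [Unit] type argument pins the function overload. A toy reproduction of the ambiguity, with Listener and addListener as hypothetical stand-ins for the Spark types:

// Hypothetical toy types mirroring the TaskContext overloads.
trait Listener { def onComplete(ctx: String): Unit }

object SamAmbiguityDemo {
  def addListener(l: Listener): Unit = l.onComplete("done")
  def addListener[U](f: String => U): Unit = { f("done"); () }

  def main(args: Array[String]): Unit = {
    // addListener(_ => println("bye"))      // fails on 2.12: the lambda
    //                                       // also SAM-converts to Listener
    addListener[Unit](_ => println("bye"))   // type argument selects the overload
  }
}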

mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
Lines changed: 2 additions & 2 deletions

@@ -211,8 +211,8 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
   }

   protected def transformImpl(dataset: Dataset[_]): DataFrame = {
-    val predictUDF = udf { (features: Any) =>
-      predict(features.asInstanceOf[FeaturesType])
+    val predictUDF = udfInternal { features: FeaturesType =>
+      predict(features)
     }
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
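The mllib changes below all follow this same shape: replace an untyped (features: Any) closure plus a runtime cast with a closure whose parameter is properly typed, so the type information exists at the point udfInternal is called. Why the old shape stopped working can be seen with plain JDK reflection, which is roughly what Spark's 2.11-era ScalaReflection.getParameterTypes relied on; the demo below is a simplified sketch of that helper, not the actual Spark code.

object ClosureReflectionDemo {
  // Simplified version of the 2.11-era trick: find the closure's non-bridge
  // apply method and read its parameter types.
  def parameterTypes(func: AnyRef): Seq[Class[_]] =
    func.getClass.getMethods.toSeq
      .filter(m => m.getName == "apply" && !m.isBridge)
      .flatMap(_.getParameterTypes)

  def main(args: Array[String]): Unit = {
    val f = (s: String) => s.length
    println(parameterTypes(f))
    // Scala 2.11 (anonymous class):   the real type, java.lang.String
    // Scala 2.12 (LambdaMetaFactory): only java.lang.Object survives, so
    // primitive (never-null) arguments can no longer be identified here.
  }
}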

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
Lines changed: 5 additions & 5 deletions

@@ -18,7 +18,7 @@
 package org.apache.spark.ml.classification

 import org.apache.spark.SparkException
-import org.apache.spark.annotation.{DeveloperApi, Since}
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
@@ -164,8 +164,8 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
     var outputData = dataset
     var numColsOutput = 0
     if (getRawPredictionCol != "") {
-      val predictRawUDF = udf { (features: Any) =>
-        predictRaw(features.asInstanceOf[FeaturesType])
+      val predictRawUDF = udfInternal { features: FeaturesType =>
+        predictRaw(features)
       }
       outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
       numColsOutput += 1
@@ -174,8 +174,8 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
     val predUDF = if (getRawPredictionCol != "") {
       udf(raw2prediction _).apply(col(getRawPredictionCol))
     } else {
-      val predictUDF = udf { (features: Any) =>
-        predict(features.asInstanceOf[FeaturesType])
+      val predictUDF = udfInternal { features: FeaturesType =>
+        predict(features)
       }
       predictUDF(col(getFeaturesCol))
     }

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
Lines changed: 2 additions & 2 deletions

@@ -287,8 +287,8 @@ class GBTClassificationModel private[ml](

   override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
-    val predictUDF = udf { (features: Any) =>
-      bcastModel.value.predict(features.asInstanceOf[Vector])
+    val predictUDF = udfInternal { features: Vector =>
+      bcastModel.value.predict(features)
     }
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }

mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
Lines changed: 6 additions & 6 deletions

@@ -113,8 +113,8 @@ abstract class ProbabilisticClassificationModel[
     var outputData = dataset
     var numColsOutput = 0
     if ($(rawPredictionCol).nonEmpty) {
-      val predictRawUDF = udf { (features: Any) =>
-        predictRaw(features.asInstanceOf[FeaturesType])
+      val predictRawUDF = udfInternal { features: FeaturesType =>
+        predictRaw(features)
       }
       outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
       numColsOutput += 1
@@ -123,8 +123,8 @@ abstract class ProbabilisticClassificationModel[
     val probUDF = if ($(rawPredictionCol).nonEmpty) {
       udf(raw2probability _).apply(col($(rawPredictionCol)))
     } else {
-      val probabilityUDF = udf { (features: Any) =>
-        predictProbability(features.asInstanceOf[FeaturesType])
+      val probabilityUDF = udfInternal { features: FeaturesType =>
+        predictProbability(features)
       }
       probabilityUDF(col($(featuresCol)))
     }
@@ -137,8 +137,8 @@ abstract class ProbabilisticClassificationModel[
     } else if ($(probabilityCol).nonEmpty) {
       udf(probability2prediction _).apply(col($(probabilityCol)))
     } else {
-      val predictUDF = udf { (features: Any) =>
-        predict(features.asInstanceOf[FeaturesType])
+      val predictUDF = udfInternal { features: FeaturesType =>
+        predict(features)
       }
       predictUDF(col($(featuresCol)))
     }

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
Lines changed: 2 additions & 2 deletions

@@ -209,8 +209,8 @@ class RandomForestClassificationModel private[ml] (

   override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
-    val predictUDF = udf { (features: Any) =>
-      bcastModel.value.predict(features.asInstanceOf[Vector])
+    val predictUDF = udfInternal { features: Vector =>
+      bcastModel.value.predict(features)
     }
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }

mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala
Lines changed: 1 addition & 1 deletion

@@ -165,7 +165,7 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme
       }
     }

-    val hashFeatures = udf { row: Row =>
+    val hashFeatures = udfInternal { row: Row =>
       val map = new OpenHashMap[Int, Double]()
       localInputCols.foreach { colName =>
         val fieldIndex = row.fieldIndex(colName)

mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
Lines changed: 2 additions & 2 deletions

@@ -25,7 +25,7 @@ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
 import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.functions.{col, udfInternal}
 import org.apache.spark.sql.types.{ArrayType, StructType}

 /**
@@ -95,7 +95,7 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
     val outputSchema = transformSchema(dataset.schema)
     val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
     // TODO: Make the hashingTF.transform natively in ml framework to avoid extra conversion.
-    val t = udf { terms: Seq[_] => hashingTF.transform(terms).asML }
+    val t = udfInternal { terms: Seq[_] => hashingTF.transform(terms).asML }
     val metadata = outputSchema($(outputCol)).metadata
     dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
   }

mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext
     val featureEncoders = getFeatureEncoders(inputFeatures)
     val featureAttrs = getFeatureAttrs(inputFeatures)

-    def interactFunc = udf { row: Row =>
+    def interactFunc = udfInternal { row: Row =>
       var indices = ArrayBuilder.make[Int]
       var values = ArrayBuilder.make[Double]
       var size = 1

mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
Lines changed: 1 addition & 1 deletion

@@ -140,7 +140,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
       case VectorAssembler.ERROR_INVALID => (dataset, false)
     }
     // Data transformation.
-    val assembleFunc = udf { r: Row =>
+    val assembleFunc = udfInternal { r: Row =>
       VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*)
     }.asNondeterministic()
     val args = $(inputCols).map { c =>
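The feature transformers in this commit (FeatureHasher, Interaction, VectorAssembler) are a slightly different case from the prediction models: their closures take a whole Row, for which no useful schema can be derived from a TypeTag, and which, being a reference type, is always nullable. Such UDFs are applied by packing the input columns into a struct, roughly as in the shell-style sketch below (written against the public udf; the DataFrame and column names are made up):

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{struct, udf}

// A Row-taking UDF sees all of its input columns packed into one struct value.
val concatCols = udf { r: Row => r.toSeq.mkString("|") }

// Hypothetical usage, assuming a DataFrame df with columns "a" and "b":
// df.withColumn("packed", concatCols(struct(df("a"), df("b"))))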
