2 changes: 1 addition & 1 deletion docs/sql-migration-guide.md
@@ -63,7 +63,7 @@ license: |

- Since Spark 3.0, JSON datasource and JSON function `schema_of_json` infer TimestampType from string values if they match the pattern defined by the JSON option `timestampFormat`. Set the JSON option `inferTimestamp` to `false` to disable such type inference.

- In Spark version 2.4 and earlier, if `org.apache.spark.sql.functions.udf(Any, DataType)` gets a Scala closure with a primitive-type argument, the returned UDF will return null if the input value is null. Since Spark 3.0, the UDF will return the default value of the Java type if the input value is null. For example, for `val f = udf((x: Int) => x, IntegerType)`, `f($"x")` will return null in Spark 2.4 and earlier if column `x` is null, and return 0 in Spark 3.0. This behavior change is introduced because Spark 3.0 is built with Scala 2.12 by default.
- Since Spark 3.0, using `org.apache.spark.sql.functions.udf(AnyRef, DataType)` is not allowed by default. Set `spark.sql.legacy.allowUntypedScalaUDF` to true to keep using it. But please note that, in Spark version 2.4 and earlier, if `org.apache.spark.sql.functions.udf(AnyRef, DataType)` gets a Scala closure with a primitive-type argument, the returned UDF will return null if the input value is null. However, since Spark 3.0, the UDF will return the default value of the Java type if the input value is null. For example, for `val f = udf((x: Int) => x, IntegerType)`, `f($"x")` will return null in Spark 2.4 and earlier if column `x` is null, and return 0 in Spark 3.0. This behavior change is introduced because Spark 3.0 is built with Scala 2.12 by default.
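A minimal Scala sketch of the behavior described above, assuming a SparkSession available as the stable value `spark` (the column and data are illustrative):

    import org.apache.spark.sql.functions.udf
    import org.apache.spark.sql.types.IntegerType
    import spark.implicits._

    // Typed UDF: Spark knows the argument is a primitive Int, so a null input
    // row yields a null result rather than being unboxed to 0.
    val typed = udf((x: Int) => x)
    Seq(Integer.valueOf(1), null).toDF("x").select(typed($"x")).show()  // 1, null

    // Untyped UDF: returned null for the null row in Spark 2.4, but under
    // Scala 2.12 / Spark 3.0 it would return 0, so by default this call now
    // fails with an AnalysisException unless the legacy flag is enabled:
    // spark.conf.set("spark.sql.legacy.allowUntypedScalaUDF", "true")
    // val untyped = udf((x: Int) => x, IntegerType)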

- Since Spark 3.0, the Proleptic Gregorian calendar is used in parsing, formatting, and converting dates and timestamps, as well as in extracting sub-components such as years and days. Spark 3.0 uses Java 8 API classes from the java.time packages that are based on ISO chronology (https://docs.oracle.com/javase/8/docs/api/java/time/chrono/IsoChronology.html). In Spark version 2.4 and earlier, those operations are performed using the hybrid calendar (Julian + Gregorian, see https://docs.oracle.com/javase/7/docs/api/java/util/GregorianCalendar.html). The changes impact the results for dates before October 15, 1582 (Gregorian) and affect the following Spark 3.0 APIs:

5 changes: 3 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml

import scala.annotation.varargs
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.internal.Logging
@@ -79,7 +80,7 @@ abstract class Transformer extends PipelineStage {
* result as a new column.
*/
@DeveloperApi
abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
abstract class UnaryTransformer[IN: TypeTag, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
Member Author:
TypeTag is required for the typed UDF when creating the udf for createTransformFunc.

Contributor:
This is a breaking change, but I think it's better than silently changing results.

Contributor:
We can avoid this breaking change if we know that the type parameter won't be primitive types. cc @srowen @zhengruifeng

Member:
I don't disagree, but this is trading a possible error for a definite error. In light of the recent conversations about not-breaking things, is this wise? (I don't object though.)

Yes, let's restrict this to primitive types. I think Spark ML even uses some UDFs that accept AnyRef or something to work with tuples or triples, IIRC.

Contributor (@cloud-fan, Feb 17, 2020):
This is a developer API, so I'm wondering whether third-party implementations would use primitive types and hit the silent result change.

I think it's better to ask users to re-compile their Spark application than to just tell them that they may hit a result change.

extends Transformer with HasInputCol with HasOutputCol with Logging {

/** @group setParam */
@@ -118,7 +119,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]

override def transform(dataset: Dataset[_]): DataFrame = {
val outputSchema = transformSchema(dataset.schema, logging = true)
val transformUDF = udf(this.createTransformFunc, outputDataType)
val transformUDF = udf(this.createTransformFunc)
dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))),
outputSchema($(outputCol)).metadata)
}
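In practice, this means concrete `UnaryTransformer` subclasses written outside Spark must be recompiled against Spark 3.0; the `TypeTag`s come for free when `IN` and `OUT` are concrete types. A hypothetical minimal subclass (name and logic are illustrative, modeled on built-in transformers such as `Tokenizer`):

    import org.apache.spark.ml.UnaryTransformer
    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.ml.util.Identifiable
    import org.apache.spark.sql.types.{DataType, StringType}

    // Upper-cases a string column. Because IN and OUT are the concrete type
    // String, the new TypeTag context bounds are satisfied automatically and
    // createTransformFunc can be wrapped by the typed udf() in transform().
    class UpperCaseTransformer(override val uid: String)
      extends UnaryTransformer[String, String, UpperCaseTransformer] {

      def this() = this(Identifiable.randomUID("upperCase"))

      override protected def createTransformFunc: String => String = _.toUpperCase

      override protected def outputDataType: DataType = StringType

      override def copy(extra: ParamMap): UpperCaseTransformer = defaultCopy(extra)
    }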
11 changes: 5 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -98,7 +98,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val transformUDF = udf(hashFunction(_: Vector), DataTypes.createArrayType(new VectorUDT))
val transformUDF = udf(hashFunction(_: Vector))
dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
}

@@ -128,14 +128,13 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
}

// In the original dataset, find the rows whose hash values fall in the same bucket as the key
val sameBucketWithKeyUDF = udf((x: Seq[Vector]) =>
sameBucket(x, keyHash), DataTypes.BooleanType)
val sameBucketWithKeyUDF = udf((x: Seq[Vector]) => sameBucket(x, keyHash))

modelDataset.filter(sameBucketWithKeyUDF(col($(outputCol))))
} else {
// In the original dataset, find the hash value that is closest to the key
// Limit the use of hashDist since it's controversial
val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash), DataTypes.DoubleType)
val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash))
val hashDistCol = hashDistUDF(col($(outputCol)))
val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistCol)

@@ -172,7 +171,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
}

// Get the top k nearest neighbor by their distance to the key
val keyDistUDF = udf((x: Vector) => keyDistance(x, key), DataTypes.DoubleType)
val keyDistUDF = udf((x: Vector) => keyDistance(x, key))
val modelSubsetWithDistCol = modelSubset.withColumn(distCol, keyDistUDF(col($(inputCol))))
modelSubsetWithDistCol.sort(distCol).limit(numNearestNeighbors)
}
@@ -290,7 +289,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
.drop(explodeCols: _*).distinct()

// Add a new column to store the distance of the two rows.
val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y), DataTypes.DoubleType)
val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y))
val joinedDatasetWithDist = joinedDataset.select(col("*"),
distUDF(col(s"$leftColName.${$(inputCol)}"), col(s"$rightColName.${$(inputCol)}")).as(distCol)
)
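The same substitution recurs throughout this file: the typed `udf` overload derives the return type (`DoubleType`, `BooleanType`, or an array of vectors) and the `VectorUDT` input types from the closure via reflection, so the explicit `DataTypes.*` argument is no longer needed. A small sketch with a stubbed distance function (the stub body is illustrative):

    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import org.apache.spark.sql.functions.udf

    // Stand-in for LSHModel.keyDistance; only the signature matters here.
    def keyDistance(x: Vector, y: Vector): Double = math.sqrt(Vectors.sqdist(x, y))

    // Return type DoubleType and the vector input types are inferred from the
    // closure, so no DataTypes.DoubleType argument is passed.
    val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y))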
11 changes: 7 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -29,10 +29,10 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.HasPredictionCol
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules,
FPGrowth => MLlibFPGrowth}
import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules, FPGrowth => MLlibFPGrowth}
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
@@ -286,14 +286,17 @@ class FPGrowthModel private[ml] (

val dt = dataset.schema($(itemsCol)).dataType
// For each rule, examine the input items and summarize the consequents
val predictUDF = udf((items: Seq[Any]) => {
val predictUDF = SparkUserDefinedFunction((items: Seq[Any]) => {
if (items != null) {
val itemset = items.toSet
brRules.value.filter(_._1.forall(itemset.contains))
.flatMap(_._2.filter(!itemset.contains(_))).distinct
} else {
Seq.empty
}}, dt)
}},
dt,
Nil
)
dataset.withColumn($(predictionCol), predictUDF(col($(itemsCol))))
}
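Unlike the other call sites in this PR, the prediction closure here takes `Seq[Any]`: the element type is only known at runtime from the items column's schema (`dt`), so the typed `udf` overload cannot infer the types. The code therefore builds a `SparkUserDefinedFunction` directly (now `private[spark]`, so this is only possible inside Spark's own packages), passing the runtime return type and an empty input-schema list. A simplified sketch of the pattern (the data type and closure body are illustrative):

    import org.apache.spark.sql.expressions.SparkUserDefinedFunction
    import org.apache.spark.sql.types.{ArrayType, DataType, StringType}

    // The element type of the items column is only known at runtime, e.g.:
    val dt: DataType = ArrayType(StringType)

    // Construct the UDF with an explicit return type and no input schemas,
    // mirroring the predictUDF construction above.
    val dedupUDF = SparkUserDefinedFunction(
      (items: Seq[Any]) => if (items != null) items.distinct else Seq.empty,
      dt,
      Nil)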

@@ -76,9 +76,8 @@ private[ml] object LSHTest {

// Perform a cross join and label each pair of same_bucket and distance
val pairs = transformedData.as("a").crossJoin(transformedData.as("b"))
val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y), DataTypes.DoubleType)
val sameBucket = udf((x: Seq[Vector], y: Seq[Vector]) => model.hashDistance(x, y) == 0.0,
DataTypes.BooleanType)
val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y))
val sameBucket = udf((x: Seq[Vector], y: Seq[Vector]) => model.hashDistance(x, y) == 0.0)
val result = pairs
.withColumn("same_bucket", sameBucket(col(s"a.$outputCol"), col(s"b.$outputCol")))
.withColumn("distance", distUDF(col(s"a.$inputCol"), col(s"b.$inputCol")))
@@ -110,7 +109,7 @@ private[ml] object LSHTest {
val model = lsh.fit(dataset)

// Compute expected
val distUDF = udf((x: Vector) => model.keyDistance(x, key), DataTypes.DoubleType)
val distUDF = udf((x: Vector) => model.keyDistance(x, key))
val expected = dataset.sort(distUDF(col(model.getInputCol))).limit(k)

// Compute actual
@@ -148,7 +147,7 @@ private[ml] object LSHTest {
val inputCol = model.getInputCol

// Compute expected
val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y), DataTypes.DoubleType)
val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y))
val expected = datasetA.as("a").crossJoin(datasetB.as("b"))
.filter(distUDF(col(s"a.$inputCol"), col(s"b.$inputCol")) < threshold)

3 changes: 3 additions & 0 deletions project/MimaExcludes.scala
@@ -74,6 +74,9 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.getRuns"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.setRuns"),

// [SPARK-26580][SQL][ML][FOLLOW-UP] Throw exception when use untyped UDF by default
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.UnaryTransformer.this"),

// [SPARK-27090][CORE] Removing old LEGACY_DRIVER_IDENTIFIER ("<driver>")
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.LEGACY_DRIVER_IDENTIFIER"),
// [SPARK-25838] Remove formatVersion from Saveable
@@ -2016,6 +2016,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val LEGACY_ALLOW_UNTYPED_SCALA_UDF =
buildConf("spark.sql.legacy.allowUntypedScalaUDF")
.internal()
.doc("When set to true, user is allowed to use org.apache.spark.sql.functions." +
"udf(f: AnyRef, dataType: DataType). Otherwise, exception will be throw.")
Member:
"exception will be throw" -> "an exception will be thrown at runtime."

.booleanConf
.createWithDefault(false)

val TRUNCATE_TABLE_IGNORE_PERMISSION_ACL =
buildConf("spark.sql.truncateTable.ignorePermissionAcl.enabled")
.internal()
@@ -90,7 +90,7 @@ sealed abstract class UserDefinedFunction {
def asNondeterministic(): UserDefinedFunction
}

private[sql] case class SparkUserDefinedFunction(
private[spark] case class SparkUserDefinedFunction(
f: AnyRef,
dataType: DataType,
inputSchemas: Seq[Option[ScalaReflection.Schema]],
9 changes: 9 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4732,6 +4732,15 @@ object functions {
* @since 2.0.0
*/
def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = {
if (!SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF)) {
val errorMsg = "You're using untyped Scala UDF, which does not have the input type " +
"information. Spark may blindly pass null to the Scala closure with primitive-type " +
"argument, and the closure will see the default value of the Java type for the null " +
"argument, e.g. `udf((x: Int) => x, IntegerType)`, the result is 0 for null input. " +
"You could use other typed Scala UDF APIs to avoid this problem, or set " +
Member:
In the error message, we should give an example showing how to use the typed Scala UDF to implement `udf((x: Int) => x, IntegerType)`.

Member Author:
I see.

s"${SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF.key} to true and use this API with caution."
throw new AnalysisException(errorMsg)
}
SparkUserDefinedFunction(f, dataType, inputSchemas = Nil)
}
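For applications that still rely on the untyped API, the legacy flag restores the old behavior at runtime; a minimal sketch, assuming an existing SparkSession named `spark`:

    import org.apache.spark.sql.functions.udf
    import org.apache.spark.sql.types.IntegerType

    // Opt back in to udf(AnyRef, DataType). The typed overload below remains
    // the recommended replacement because it preserves null inputs.
    spark.conf.set("spark.sql.legacy.allowUntypedScalaUDF", "true")
    val legacyUDF = udf((x: Int) => x, IntegerType)

    // Preferred typed form: the return type is inferred and nulls stay null.
    val typedUDF = udf((x: Int) => x)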

33 changes: 21 additions & 12 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -134,10 +134,12 @@ class UDFSuite extends QueryTest with SharedSparkSession {
assert(df1.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic))
assert(df1.head().getDouble(0) >= 0.0)

val bar = udf(() => Math.random(), DataTypes.DoubleType).asNondeterministic()
val df2 = testData.select(bar())
assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic))
assert(df2.head().getDouble(0) >= 0.0)
withSQLConf(SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF.key -> "true") {
val bar = udf(() => Math.random(), DataTypes.DoubleType).asNondeterministic()
val df2 = testData.select(bar())
assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic))
assert(df2.head().getDouble(0) >= 0.0)
}

val javaUdf = udf(new UDF0[Double] {
override def call(): Double = Math.random()
@@ -441,16 +443,23 @@ class UDFSuite extends QueryTest with SharedSparkSession {
}

test("SPARK-25044 Verify null input handling for primitive types - with udf(Any, DataType)") {
val f = udf((x: Int) => x, IntegerType)
checkAnswer(
Seq(Integer.valueOf(1), null).toDF("x").select(f($"x")),
Row(1) :: Row(0) :: Nil)
withSQLConf(SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF.key -> "true") {
val f = udf((x: Int) => x, IntegerType)
checkAnswer(
Seq(Integer.valueOf(1), null).toDF("x").select(f($"x")),
Row(1) :: Row(0) :: Nil)

val f2 = udf((x: Double) => x, DoubleType)
checkAnswer(
Seq(java.lang.Double.valueOf(1.1), null).toDF("x").select(f2($"x")),
Row(1.1) :: Row(0.0) :: Nil)
}

val f2 = udf((x: Double) => x, DoubleType)
checkAnswer(
Seq(java.lang.Double.valueOf(1.1), null).toDF("x").select(f2($"x")),
Row(1.1) :: Row(0.0) :: Nil)
}

test("use untyped Scala UDF should fail by default") {
val e = intercept[AnalysisException](udf((x: Int) => x, IntegerType))
assert(e.getMessage.contains("You're using untyped Scala UDF"))
}

test("SPARK-26308: udf with decimal") {