@@ -19,8 +19,11 @@ package org.apache.spark.sql.catalyst

import java.lang.reflect.Constructor

import scala.util.Properties

import org.apache.commons.lang3.reflect.ConstructorUtils

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
@@ -879,7 +882,7 @@ object ScalaReflection extends ScalaReflection
* Support for generating catalyst schemas for scala objects. Note that unlike its companion
* object, this trait able to work in both the runtime and the compile time (macro) universe.
*/
trait ScalaReflection {
trait ScalaReflection extends Logging {
/** The universe we work in (runtime or macro) */
val universe: scala.reflect.api.Universe

@@ -932,6 +935,23 @@ trait ScalaReflection {
tpe.dealias.erasure.typeSymbol.asClass.fullName
}

/**
* Returns the nullability of the input parameter types of the scala function object.
*
* Note that this only works with Scala 2.11, and the information returned may be inaccurate if
* used with a different Scala version.
*/

Review thread on the Scala-version note above:

Contributor: Shall we explicitly return a seq of true if it's not Scala 2.11? Then the behavior would be more predictable than "may be inaccurate".

Contributor (author): I think it simply returns whatever comes out of the code path below. I should probably make the Javadoc clearer.

Contributor (author): The argument here is that it's not necessarily wrong when using Scala 2.12: if all inputs are of boxed types, the result can still be correct. I think it's enough to say "we don't support it; switch to the new interface, otherwise we can't guarantee correctness."
def getParameterTypeNullability(func: AnyRef): Seq[Boolean] = {
if (!Properties.versionString.contains("2.11")) {
logWarning(s"Scala ${Properties.versionString} cannot get type nullability correctly via " +
"reflection, thus Spark cannot add proper input null check for UDF. To avoid this " +
"problem, use the typed UDF interfaces instead.")
}
val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
assert(methods.length == 1)
methods.head.getParameterTypes.map(!_.isPrimitive)
}
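
To make the Scala-version caveat from the review thread above concrete, here is a minimal standalone sketch (not part of this patch; the object and method names are illustrative) of the same reflection check: it looks up the single non-bridge apply method of the function object and treats every JVM-primitive parameter as non-nullable.

object NullabilityProbe {
  // Mirror of the technique above: inspect the non-bridge `apply` method of the function
  // object and flag each JVM-primitive parameter as "not nullable".
  def probe(func: AnyRef): Seq[Boolean] = {
    val applyMethods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
    // The Spark method asserts there is exactly one such method; this sketch takes the first.
    applyMethods.head.getParameterTypes.map(cls => !cls.isPrimitive).toSeq
  }

  def main(args: Array[String]): Unit = {
    val f = (i: Int, s: String) => s + i
    // Scala 2.11 compiles the closure with an apply(int, String) method, so this prints
    // List(false, true). Scala 2.12 may expose only a boxed apply(Object, Object), making
    // every parameter look nullable -- the inaccuracy the warning above is about.
    println(probe(f).toList)
  }
}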

/**
* Returns the parameter names and types for the primary constructor of this type.
*
@@ -2137,36 +2137,27 @@ class Analyzer(

case p => p transformExpressionsUp {

case udf@ScalaUDF(func, _, inputs, _, _, _, _, nullableTypes) =>
if (nullableTypes.isEmpty) {
// If no nullability info is available, do nothing. No fields will be specially
// checked for null in the plan. If nullability info is incorrect, the results
// of the UDF could be wrong.
udf
} else {
// Otherwise, add special handling of null for fields that can't accept null.
// The result of operations like this, when passed null, is generally to return null.
assert(nullableTypes.length == inputs.length)

// TODO: skip null handling for not-nullable primitive inputs after we can completely
// trust the `nullable` information.
val needsNullCheck = (nullable: Boolean, expr: Expression) =>
nullable && !expr.isInstanceOf[KnownNotNull]
val inputsNullCheck = nullableTypes.zip(inputs)
.filter { case (nullableType, expr) => needsNullCheck(!nullableType, expr) }
.map { case (_, expr) => IsNull(expr) }
.reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
// Once we add an `If` check above the udf, it is safe to mark those checked inputs
// as not nullable (i.e., wrap them with `KnownNotNull`), because the null-returning
// branch of `If` will be called if any of these checked inputs is null. Thus we can
// prevent this rule from being applied repeatedly.
val newInputs = nullableTypes.zip(inputs).map { case (nullable, expr) =>
if (nullable) expr else KnownNotNull(expr)
}
inputsNullCheck
.map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs)))
.getOrElse(udf)
}
case udf @ ScalaUDF(_, _, inputs, inputsNullSafe, _, _, _, _)
if inputsNullSafe.contains(false) =>
// Otherwise, add special handling of null for fields that can't accept null.
// The result of operations like this, when passed null, is generally to return null.
assert(inputsNullSafe.length == inputs.length)

// TODO: skip null handling for not-nullable primitive inputs after we can completely
// trust the `nullable` information.
val inputsNullCheck = inputsNullSafe.zip(inputs)
.filter { case (nullSafe, _) => !nullSafe }
.map { case (_, expr) => IsNull(expr) }
.reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
// Once we add an `If` check above the udf, it is safe to mark those checked inputs
// as null-safe (i.e., set `inputsNullSafe` all `true`), because the null-returning
// branch of `If` will be called if any of these checked inputs is null. Thus we can
// prevent this rule from being applied repeatedly.
val newInputsNullSafe = inputsNullSafe.map(_ => true)
inputsNullCheck
.map(If(_, Literal.create(null, udf.dataType),
udf.copy(inputsNullSafe = newInputsNullSafe)))
.getOrElse(udf)
}
}
}
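
For reference, the guard that the rewritten rule assembles can be sketched outside Catalyst as follows (plain Scala with stand-in case classes; Attr, IsNullOf, and OrExpr are illustrative names, not Spark classes): only inputs whose inputsNullSafe slot is false contribute an IsNull check, and the checks are OR-ed into a single condition.

object NullGuardSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class IsNullOf(child: Expr) extends Expr
  case class OrExpr(left: Expr, right: Expr) extends Expr

  // Keep only the not-null-safe inputs, wrap each in an IsNull check, and OR the checks
  // together; None means every input is null-safe and no guard is needed.
  def nullGuard(inputsNullSafe: Seq[Boolean], inputs: Seq[Expr]): Option[Expr] =
    inputsNullSafe.zip(inputs)
      .collect { case (false, e) => IsNullOf(e) }
      .reduceLeftOption[Expr]((e1, e2) => OrExpr(e1, e2))

  def main(args: Array[String]): Unit = {
    // Prints Some(OrExpr(IsNullOf(Attr(i)),IsNullOf(Attr(d)))); the analyzer then wraps the
    // UDF as If(guard, Literal(null, udf.dataType), udf.copy(inputsNullSafe = all true)).
    println(nullGuard(Seq(true, false, false), Seq(Attr("s"), Attr("i"), Attr("d"))))
  }
}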
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.DataType
@@ -31,6 +31,9 @@ import org.apache.spark.sql.types.DataType
* null. Use boxed type or [[Option]] if you wanna do the null-handling yourself.
* @param dataType Return type of function.
* @param children The input expressions of this UDF.
* @param inputsNullSafe Whether the inputs are of non-primitive types or not nullable. Null values
* of Scala primitive types will be converted to the type's default value and
* lead to wrong results, thus need special handling before calling the UDF.
* @param inputTypes The expected input types of this UDF, used to perform type coercion. If we do
* not want to perform coercion, simply use "Nil". Note that it would've been
* better to use Option of Seq[DataType] so we can use "None" as the case for no
@@ -39,17 +42,16 @@
* @param nullable True if the UDF can return null value.
* @param udfDeterministic True if the UDF is deterministic. Deterministic UDF returns same result
* each time it is invoked with a particular input.
* @param nullableTypes which of the inputTypes are nullable (i.e. not primitive)
*/
case class ScalaUDF(
function: AnyRef,
dataType: DataType,
children: Seq[Expression],
inputsNullSafe: Seq[Boolean],
inputTypes: Seq[DataType] = Nil,
udfName: Option[String] = None,
nullable: Boolean = true,
udfDeterministic: Boolean = true,
nullableTypes: Seq[Boolean] = Nil)
udfDeterministic: Boolean = true)
extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression {

// The constructor for SPARK 2.1 and 2.2
@@ -60,8 +62,8 @@ case class ScalaUDF(
inputTypes: Seq[DataType],
udfName: Option[String]) = {
this(
function, dataType, children, inputTypes, udfName, nullable = true,
udfDeterministic = true, nullableTypes = Nil)
function, dataType, children, ScalaReflection.getParameterTypeNullability(function),
inputTypes, udfName, nullable = true, udfDeterministic = true)
}

override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
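
The inputsNullSafe parameter documented above exists because of how nulls interact with Scala primitives. A minimal demonstration of that failure mode, not Spark code and independent of this patch:

object PrimitiveNullDemo {
  def main(args: Array[String]): Unit = {
    // Catalyst calls the UDF through the erased function interface, so a null row value
    // reaches a primitive parameter as a boxed null and is unboxed to the type's default.
    val addOne = (i: Int) => i + 1
    val erased = addOne.asInstanceOf[AnyRef => AnyRef]
    // Prints "1": the null was silently unboxed to 0 and the UDF computed a wrong,
    // non-null result instead of returning null.
    println(erased(null))
  }
}

With inputsNullSafe reporting false for such a parameter, the analyzer rule shown earlier short-circuits to a null literal instead of letting the UDF see a spurious default value.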
@@ -314,24 +314,24 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}

// non-primitive parameters do not need special null handling
val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil)
val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil, true :: Nil)
val expected1 = udf1
checkUDF(udf1, expected1)

// only primitive parameter needs special null handling
val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil,
nullableTypes = true :: false :: Nil)
true :: false :: Nil)
val expected2 =
If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil))
If(IsNull(double), nullResult, udf2.copy(inputsNullSafe = true :: true :: Nil))
checkUDF(udf2, expected2)

// special null handling should apply to all primitive parameters
val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil,
nullableTypes = false :: false :: Nil)
false :: false :: Nil)
val expected3 = If(
IsNull(short) || IsNull(double),
nullResult,
udf3.copy(children = KnownNotNull(short) :: KnownNotNull(double) :: Nil))
udf3.copy(inputsNullSafe = true :: true :: Nil))
checkUDF(udf3, expected3)

// we can skip special null handling for primitive parameters that are not nullable
@@ -340,19 +340,19 @@
(s: Short, d: Double) => "x",
StringType,
short :: double.withNullability(false) :: Nil,
nullableTypes = false :: false :: Nil)
false :: false :: Nil)
val expected4 = If(
IsNull(short),
nullResult,
udf4.copy(children = KnownNotNull(short) :: double.withNullability(false) :: Nil))
udf4.copy(inputsNullSafe = true :: true :: Nil))
// checkUDF(udf4, expected4)
}

test("SPARK-24891 Fix HandleNullInputsForUDF rule") {
val a = testRelation.output(0)
val func = (x: Int, y: Int) => x + y
val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil, nullableTypes = false :: false :: Nil)
val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil, nullableTypes = false :: false :: Nil)
val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil, false :: false :: Nil)
val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil, false :: false :: Nil)
val plan = Project(Alias(udf2, "")() :: Nil, testRelation)
comparePlans(plan.analyze, plan.analyze.analyze)
}
@@ -26,18 +26,19 @@ import org.apache.spark.sql.types.{IntegerType, StringType}
class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

test("basic") {
val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil, true :: Nil)
checkEvaluation(intUdf, 2)

val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil)
checkEvaluation(stringUdf, "ax")
}

test("better error message for NPE") {
val udf = ScalaUDF(
(s: String) => s.toLowerCase(Locale.ROOT),
StringType,
Literal.create(null, StringType) :: Nil)
Literal.create(null, StringType) :: Nil,
true :: Nil)

val e1 = intercept[SparkException](udf.eval())
assert(e1.getMessage.contains("Failed to execute user defined function"))
@@ -50,7 +51,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

test("SPARK-22695: ScalaUDF should not use global variables") {
val ctx = new CodegenContext
ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx)
ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil).genCode(ctx)
assert(ctx.inlinedMutableStates.isEmpty)
}
}
@@ -564,7 +564,7 @@ class TreeNodeSuite extends SparkFunSuite {
}

test("toJSON should not throws java.lang.StackOverflowError") {
val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr))
val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr), true :: Nil)
// Should not throw java.lang.StackOverflowError
udf.toJSON
}