Skip to content

Commit 1567b6c

Browse files
committed
0-args Java UDF should not be called only once
1 parent 7021588 commit 1567b6c

File tree

3 files changed

+19
-10
lines changed

3 files changed

+19
-10
lines changed

sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,16 +142,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
142142
val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]"
143143
val anyParams = (1 to i).map(_ => "_: Any").mkString(", ")
144144
val version = if (i == 0) "2.3.0" else "1.3.0"
145-
val funcCall = if (i == 0) "() => func" else "func"
145+
val funcCall = if (i == 0) s"() => f$anyCast.call($anyParams)" else s"f$anyCast.call($anyParams)"
146146
println(s"""
147147
|/**
148148
| * Register a deterministic Java UDF$i instance as user-defined function (UDF).
149149
| * @since $version
150150
| */
151151
|def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = {
152-
| val func = f$anyCast.call($anyParams)
152+
| val func = $funcCall
153153
| def builder(e: Seq[Expression]) = if (e.length == $i) {
154-
| ScalaUDF($funcCall, returnType, e, e.map(_ => false), udfName = Some(name))
154+
| ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
155155
| } else {
156156
| throw new AnalysisException("Invalid number of arguments for function " + name +
157157
| ". Expected: $i; Found: " + e.length)
@@ -717,9 +717,9 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
717717
* @since 2.3.0
718718
*/
719719
def register(name: String, f: UDF0[_], returnType: DataType): Unit = {
720-
val func = f.asInstanceOf[UDF0[Any]].call()
720+
val func = () => f.asInstanceOf[UDF0[Any]].call()
721721
def builder(e: Seq[Expression]) = if (e.length == 0) {
722-
ScalaUDF(() => func, returnType, e, e.map(_ => false), udfName = Some(name))
722+
ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
723723
} else {
724724
throw new AnalysisException("Invalid number of arguments for function " + name +
725725
". Expected: 0; Found: " + e.length)

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3932,7 +3932,7 @@ object functions {
39323932
val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ")
39333933
val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]"
39343934
val anyParams = (1 to i).map(_ => "_: Any").mkString(", ")
3935-
val funcCall = if (i == 0) "() => func" else "func"
3935+
val funcCall = if (i == 0) s"() => f$anyCast.call($anyParams)" else s"f$anyCast.call($anyParams)"
39363936
println(s"""
39373937
|/**
39383938
| * Defines a Java UDF$i instance as user-defined function (UDF).
@@ -3944,8 +3944,8 @@ object functions {
39443944
| * @since 2.3.0
39453945
| */
39463946
|def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = {
3947-
| val func = f$anyCast.call($anyParams)
3948-
| SparkUserDefinedFunction($funcCall, returnType, inputSchemas = Seq.fill($i)(None))
3947+
| val func = $funcCall
3948+
| SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill($i)(None))
39493949
|}""".stripMargin)
39503950
}
39513951
@@ -4145,8 +4145,8 @@ object functions {
41454145
* @since 2.3.0
41464146
*/
41474147
def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = {
4148-
val func = f.asInstanceOf[UDF0[Any]].call()
4149-
SparkUserDefinedFunction(() => func, returnType, inputSchemas = Seq.fill(0)(None))
4148+
val func = () => f.asInstanceOf[UDF0[Any]].call()
4149+
SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(0)(None))
41504150
}
41514151

41524152
/**

sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,4 +514,13 @@ class UDFSuite extends QueryTest with SharedSQLContext {
514514
assert(df.collect().toSeq === Seq(Row(expected)))
515515
}
516516
}
517+
518+
test("SPARK-28321 0-args Java UDF should not be called only once") {
519+
val nonDeterministicJavaUDF = udf(
520+
new UDF0[Int] {
521+
override def call(): Int = scala.util.Random.nextInt()
522+
}, IntegerType).asNondeterministic()
523+
524+
assert(spark.range(2).select(nonDeterministicJavaUDF()).distinct().count() == 2)
525+
}
517526
}

0 commit comments

Comments
 (0)