Skip to content

Commit ff48b1b

Browse files
mgaido91 authored and gatorsmile committed
[SPARK-22901][PYTHON] Add deterministic flag to pyspark UDF
## What changes were proposed in this pull request?

In SPARK-20586 the flag `deterministic` was added to Scala UDF, but it is not available for Python UDF. This flag is useful for cases when the UDF's code can return different results for the same input. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked more times than it is present in the query. This can lead to unexpected behavior.

This PR adds the deterministic flag, via the `asNondeterministic` method, to let the user mark the function as non-deterministic and therefore avoid the optimizations which might lead to strange behaviors.

## How was this patch tested?

Manual tests:

```
>>> from pyspark.sql.functions import *
>>> from pyspark.sql.types import *
>>> df_br = spark.createDataFrame([{'name': 'hello'}])
>>> import random
>>> udf_random_col = udf(lambda: int(100*random.random()), IntegerType()).asNondeterministic()
>>> df_br = df_br.withColumn('RAND', udf_random_col())
>>> random.seed(1234)
>>> udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
>>> df_br.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).show()
+-----+----+-------------+
| name|RAND|RAND_PLUS_TEN|
+-----+----+-------------+
|hello|   3|           13|
+-----+----+-------------+
```

Author: Marco Gaido <[email protected]>
Author: Marco Gaido <[email protected]>

Closes #19929 from mgaido91/SPARK-22629.
1 parent eb386be commit ff48b1b

File tree

8 files changed

+48
-10
lines changed

8 files changed

+48
-10
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ private[spark] object PythonEvalType {
3939

4040
val SQL_PANDAS_SCALAR_UDF = 200
4141
val SQL_PANDAS_GROUP_MAP_UDF = 201
42+
43+
def toString(pythonEvalType: Int): String = pythonEvalType match {
44+
case NON_UDF => "NON_UDF"
45+
case SQL_BATCHED_UDF => "SQL_BATCHED_UDF"
46+
case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF"
47+
case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF"
48+
}
4249
}
4350

4451
/**

python/pyspark/sql/functions.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2093,9 +2093,14 @@ class PandasUDFType(object):
20932093
def udf(f=None, returnType=StringType()):
20942094
"""Creates a user defined function (UDF).
20952095
2096-
.. note:: The user-defined functions must be deterministic. Due to optimization,
2097-
duplicate invocations may be eliminated or the function may even be invoked more times than
2098-
it is present in the query.
2096+
.. note:: The user-defined functions are considered deterministic by default. Due to
2097+
optimization, duplicate invocations may be eliminated or the function may even be invoked
2098+
more times than it is present in the query. If your function is not deterministic, call
2099+
`asNondeterministic` on the user defined function. E.g.:
2100+
2101+
>>> from pyspark.sql.types import IntegerType
2102+
>>> import random
2103+
>>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()
20992104
21002105
.. note:: The user-defined functions do not support conditional expressions or short circuiting
21012106
in boolean expressions and it ends up with being executed all internally. If the functions

python/pyspark/sql/tests.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,15 @@ def test_udf_with_array_type(self):
435435
self.assertEqual(list(range(3)), l1)
436436
self.assertEqual(1, l2)
437437

438+
def test_nondeterministic_udf(self):
439+
from pyspark.sql.functions import udf
440+
import random
441+
udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
442+
df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
443+
udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
444+
[row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
445+
self.assertEqual(row[0] + 10, row[1])
446+
438447
def test_broadcast_in_udf(self):
439448
bar = {"a": "aa", "b": "bb", "c": "abc"}
440449
foo = self.sc.broadcast(bar)

python/pyspark/sql/udf.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def __init__(self, func,
9292
func.__name__ if hasattr(func, '__name__')
9393
else func.__class__.__name__)
9494
self.evalType = evalType
95+
self._deterministic = True
9596

9697
@property
9798
def returnType(self):
@@ -129,7 +130,7 @@ def _create_judf(self):
129130
wrapped_func = _wrap_function(sc, self.func, self.returnType)
130131
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
131132
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
132-
self._name, wrapped_func, jdt, self.evalType)
133+
self._name, wrapped_func, jdt, self.evalType, self._deterministic)
133134
return judf
134135

135136
def __call__(self, *cols):
@@ -161,5 +162,15 @@ def wrapper(*args):
161162
wrapper.func = self.func
162163
wrapper.returnType = self.returnType
163164
wrapper.evalType = self.evalType
165+
wrapper.asNondeterministic = self.asNondeterministic
164166

165167
return wrapper
168+
169+
def asNondeterministic(self):
170+
"""
171+
Updates UserDefinedFunction to nondeterministic.
172+
173+
.. versionadded:: 2.3
174+
"""
175+
self._deterministic = False
176+
return self

sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.TypeTag
2323
import scala.util.Try
2424

2525
import org.apache.spark.annotation.InterfaceStability
26+
import org.apache.spark.api.python.PythonEvalType
2627
import org.apache.spark.internal.Logging
2728
import org.apache.spark.sql.api.java._
2829
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
@@ -41,8 +42,6 @@ import org.apache.spark.util.Utils
4142
* spark.udf
4243
* }}}
4344
*
44-
* @note The user-defined functions must be deterministic.
45-
*
4645
* @since 1.3.0
4746
*/
4847
@InterfaceStability.Stable
@@ -58,6 +57,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
5857
| pythonIncludes: ${udf.func.pythonIncludes}
5958
| pythonExec: ${udf.func.pythonExec}
6059
| dataType: ${udf.dataType}
60+
| pythonEvalType: ${PythonEvalType.toString(udf.pythonEvalType)}
61+
| udfDeterministic: ${udf.udfDeterministic}
6162
""".stripMargin)
6263

6364
functionRegistry.createOrReplaceTempFunction(name, udf.builder)

sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,12 @@ case class PythonUDF(
2929
func: PythonFunction,
3030
dataType: DataType,
3131
children: Seq[Expression],
32-
evalType: Int)
32+
evalType: Int,
33+
udfDeterministic: Boolean)
3334
extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression {
3435

36+
override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
37+
3538
override def toString: String = s"$name(${children.mkString(", ")})"
3639

3740
override def nullable: Boolean = true

sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@ case class UserDefinedPythonFunction(
2929
name: String,
3030
func: PythonFunction,
3131
dataType: DataType,
32-
pythonEvalType: Int) {
32+
pythonEvalType: Int,
33+
udfDeterministic: Boolean) {
3334

3435
def builder(e: Seq[Expression]): PythonUDF = {
35-
PythonUDF(name, func, dataType, e, pythonEvalType)
36+
PythonUDF(name, func, dataType, e, pythonEvalType, udfDeterministic)
3637
}
3738

3839
/** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */

sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,4 +109,5 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction(
109109
name = "dummyUDF",
110110
func = new DummyUDF,
111111
dataType = BooleanType,
112-
pythonEvalType = PythonEvalType.SQL_BATCHED_UDF)
112+
pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
113+
udfDeterministic = true)

0 commit comments

Comments
 (0)