@@ -386,6 +386,8 @@ def test_udf3(self):
386386 self .assertEqual (row [0 ], 5 )
387387
388388 def test_nondeterministic_udf (self ):
389+ # Test that the result of nondeterministic UDFs are evaluated only once in
390+ # chained UDF evaluations
389391 from pyspark .sql .functions import udf
390392 import random
391393 udf_random_col = udf (lambda : int (100 * random .random ()), IntegerType ()).asNondeterministic ()
@@ -3975,19 +3977,21 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self):
39753977 self .spark .conf .set ("spark.sql.session.timeZone" , orig_tz )
39763978
39773979 def test_nondeterministic_udf (self ):
3978- # Non-deterministic UDFs should be allowed in select and withColumn
3979- from pyspark .sql .functions import pandas_udf , col
3980+ # Test that the result of nondeterministic UDFs are evaluated only once in
3981+ # chained UDF evaluations
3982+ from pandas .testing import assert_series_equal
3983+ from pyspark .sql .functions import udf , pandas_udf , col
39803984
3985+ @pandas_udf ('double' )
3986+ def plus_ten (v ):
3987+ return v + 10
39813988 random_udf = self .random_udf
3982- df = self .spark .range (10 )
39833989
3984- result1 = df .select (random_udf (col ('id' )).alias ('rand' )).collect ()
3985- result2 = df .withColumn ('rand' , random_udf (col ('id' ))).collect ()
3990 df = self .spark .range (10 ).withColumn ('rand' , random_udf (col ('id' )))
3991 result1 = df .withColumn ('plus_ten(rand)' , plus_ten (df ['rand' ])).toPandas ()
39863992
3987- for row in result1 :
3988- self .assertTrue (0.0 <= row .rand < 1.0 )
3989- for row in result2 :
3990- self .assertTrue (0.0 <= row .rand < 1.0 )
3993+ self .assertEqual (random_udf .deterministic , False )
3994+ assert_series_equal (result1 ['plus_ten(rand)' ], result1 ['rand' ] + 10 , check_names = False )
39913995
39923996 def test_nondeterministic_udf_in_aggregate (self ):
39933997 from pyspark .sql .functions import pandas_udf , sum
0 commit comments