@@ -386,6 +386,7 @@ def test_udf3(self):
386386 self .assertEqual (row [0 ], 5 )
387387
388388 def test_nondeterministic_udf (self ):
389+ # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
389390 from pyspark .sql .functions import udf
390391 import random
391392 udf_random_col = udf (lambda : int (100 * random .random ()), IntegerType ()).asNondeterministic ()
@@ -413,6 +414,18 @@ def test_nondeterministic_udf2(self):
413414 pydoc .render_doc (random_udf )
414415 pydoc .render_doc (random_udf1 )
415416
417+ def test_nondeterministic_udf_in_aggregate (self ):
418+ from pyspark .sql .functions import udf , sum
419+ import random
420+ udf_random_col = udf (lambda : int (100 * random .random ()), 'int' ).asNondeterministic ()
421+ df = self .spark .range (10 )
422+
423+ with QuietTest (self .sc ):
424+ with self .assertRaisesRegexp (AnalysisException , "nondeterministic" ):
425+ df .groupby ('id' ).agg (sum (udf_random_col ())).collect ()
426+ with self .assertRaisesRegexp (AnalysisException , "nondeterministic" ):
427+ df .agg (sum (udf_random_col ())).collect ()
428+
416429 def test_chained_udf (self ):
417430 self .spark .catalog .registerFunction ("double" , lambda x : x + x , IntegerType ())
418431 [row ] = self .spark .sql ("SELECT double(1)" ).collect ()
@@ -3567,6 +3580,18 @@ def tearDownClass(cls):
35673580 time .tzset ()
35683581 ReusedSQLTestCase .tearDownClass ()
35693582
3583+ @property
3584+ def random_udf (self ):
3585+ from pyspark .sql .functions import pandas_udf
3586+
3587+ @pandas_udf ('double' )
3588+ def random_udf (v ):
3589+ import pandas as pd
3590+ import numpy as np
3591+ return pd .Series (np .random .random (len (v )))
3592+ random_udf = random_udf .asNondeterministic ()
3593+ return random_udf
3594+
35703595 def test_vectorized_udf_basic (self ):
35713596 from pyspark .sql .functions import pandas_udf , col
35723597 df = self .spark .range (10 ).select (
@@ -3950,6 +3975,33 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self):
39503975 finally :
39513976 self .spark .conf .set ("spark.sql.session.timeZone" , orig_tz )
39523977
3978+ def test_nondeterministic_udf (self ):
3979+ # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
3980+ from pyspark .sql .functions import udf , pandas_udf , col
3981+
3982+ @pandas_udf ('double' )
3983+ def plus_ten (v ):
3984+ return v + 10
3985+ random_udf = self .random_udf
3986+
3987+ df = self .spark .range (10 ).withColumn ('rand' , random_udf (col ('id' )))
3988+ result1 = df .withColumn ('plus_ten(rand)' , plus_ten (df ['rand' ])).toPandas ()
3989+
3990+ self .assertEqual (random_udf .deterministic , False )
3991+ self .assertTrue (result1 ['plus_ten(rand)' ].equals (result1 ['rand' ] + 10 ))
3992+
3993+ def test_nondeterministic_udf_in_aggregate (self ):
3994+ from pyspark .sql .functions import pandas_udf , sum
3995+
3996+ df = self .spark .range (10 )
3997+ random_udf = self .random_udf
3998+
3999+ with QuietTest (self .sc ):
4000+ with self .assertRaisesRegexp (AnalysisException , 'nondeterministic' ):
4001+ df .groupby (df .id ).agg (sum (random_udf (df .id ))).collect ()
4002+ with self .assertRaisesRegexp (AnalysisException , 'nondeterministic' ):
4003+ df .agg (sum (random_udf (df .id ))).collect ()
4004+
39534005
39544006@unittest .skipIf (not _have_pandas or not _have_arrow , "Pandas or Arrow not installed" )
39554007class GroupbyApplyTests (ReusedSQLTestCase ):
0 commit comments