@@ -378,7 +378,24 @@ def test_udf2(self):
378378 [res ] = self .spark .sql ("SELECT strlen(a) FROM test WHERE strlen(a) > 1" ).collect ()
379379 self .assertEqual (4 , res [0 ])
380380
381- def test_non_deterministic_udf (self ):
def test_udf3(self):
    """Registering a UserDefinedFunction via the catalog keeps it
    deterministic by default and makes it callable from SQL."""
    two_args = UserDefinedFunction(lambda x, y: len(x) + y)
    registered = self.spark.catalog.registerFunction(
        "twoArgs", two_args, IntegerType())
    # registerFunction hands back the registered UDF; the default
    # (no asNondeterministic() call) must be deterministic.
    self.assertEqual(True, registered.deterministic)
    [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
    # len('test') + 1 == 5
    self.assertEqual(5, row[0])
def test_nondeterministic_udf(self):
    """asNondeterministic() must clear the deterministic flag, and a
    nondeterministic column must still compose with a downstream
    deterministic UDF."""
    import random
    from pyspark.sql.functions import udf
    rand_col = udf(
        lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
    self.assertEqual(False, rand_col.deterministic)
    base = self.spark.createDataFrame([Row(1)]).select(rand_col().alias('RAND'))
    add_ten = udf(lambda rand: rand + 10, IntegerType())
    [row] = base.withColumn('RAND_PLUS_TEN', add_ten('RAND')).collect()
    # Whatever random value came out, the derived column is exactly +10.
    self.assertEqual(row[1], row[0] + 10)
398+ def test_nondeterministic_udf2 (self ):
382399 import random
383400 from pyspark .sql .functions import udf
384401 random_udf = udf (lambda : random .randint (6 , 6 ), IntegerType ()).asNondeterministic ()
@@ -391,6 +408,7 @@ def test_non_deterministic_udf(self):
391408 self .assertEqual (row [0 ], "6" )
392409 [row ] = self .spark .range (1 ).select (random_udf ()).collect ()
393410 self .assertEqual (row [0 ], 6 )
411+ # render_doc() reproduces the help() exception without printing output
394412 pydoc .render_doc (udf (lambda : random .randint (6 , 6 ), IntegerType ()))
395413 pydoc .render_doc (random_udf )
396414 pydoc .render_doc (random_udf1 )
@@ -452,15 +470,6 @@ def test_udf_with_array_type(self):
452470 self .assertEqual (list (range (3 )), l1 )
453471 self .assertEqual (1 , l2 )
454472
455- def test_nondeterministic_udf (self ):
456- from pyspark .sql .functions import udf
457- import random
458- udf_random_col = udf (lambda : int (100 * random .random ()), IntegerType ()).asNondeterministic ()
459- df = self .spark .createDataFrame ([Row (1 )]).select (udf_random_col ().alias ('RAND' ))
460- udf_add_ten = udf (lambda rand : rand + 10 , IntegerType ())
461- [row ] = df .withColumn ('RAND_PLUS_TEN' , udf_add_ten ('RAND' )).collect ()
462- self .assertEqual (row [0 ] + 10 , row [1 ])
463-
464473 def test_broadcast_in_udf (self ):
465474 bar = {"a" : "aa" , "b" : "bb" , "c" : "abc" }
466475 foo = self .sc .broadcast (bar )
0 commit comments