Skip to content

Commit 0d8d943

Browse files
committed
Fix test_nondeterministic_udf
1 parent 46c6ad7 commit 0d8d943

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

python/pyspark/sql/tests.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,8 @@ def test_udf3(self):
386386
self.assertEqual(row[0], 5)
387387

388388
def test_nondeterministic_udf(self):
389+
# Test that the result of nondeterministic UDFs are evaluated only once in
390+
# chained UDF evaluations
389391
from pyspark.sql.functions import udf
390392
import random
391393
udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
@@ -3975,19 +3977,21 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self):
39753977
self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
39763978

39773979
def test_nondeterministic_udf(self):
3978-
# Non-deterministic UDFs should be allowed in select and withColumn
3979-
from pyspark.sql.functions import pandas_udf, col
3980+
# Test that the result of nondeterministic UDFs are evaluated only once in
3981+
# chained UDF evaluations
3982+
from pandas.testing import assert_series_equal
3983+
from pyspark.sql.functions import udf, pandas_udf, col
39803984

3985+
@pandas_udf('double')
3986+
def plus_ten(v):
3987+
return v + 10
39813988
random_udf = self.random_udf
3982-
df = self.spark.range(10)
39833989

3984-
result1 = df.select(random_udf(col('id')).alias('rand')).collect()
3985-
result2 = df.withColumn('rand', random_udf(col('id'))).collect()
3990+
df = self.spark.range(10).withColumn('rand', random_udf(col('id')))
3991+
result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas()
39863992

3987-
for row in result1:
3988-
self.assertTrue(0.0 <= row.rand < 1.0)
3989-
for row in result2:
3990-
self.assertTrue(0.0 <= row.rand < 1.0)
3993+
self.assertEqual(random_udf.deterministic, False)
3994+
assert_series_equal(result1['plus_ten(rand)'], result1['rand'] + 10, check_names=False)
39913995

39923996
def test_nondeterministic_udf_in_aggregate(self):
39933997
from pyspark.sql.functions import pandas_udf, sum

0 commit comments

Comments
 (0)