Skip to content

Commit 09a1b89

Browse files
committed
fix.
1 parent 78e9b2c commit 09a1b89

File tree

3 files changed

+26
-18
lines changed

3 files changed

+26
-18
lines changed

python/pyspark/sql/catalog.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,15 +227,15 @@ def dropGlobalTempView(self, viewName):
227227
@ignore_unicode_prefix
228228
@since(2.0)
229229
def registerFunction(self, name, f, returnType=StringType()):
230-
"""Registers a python function (including lambda function) as a UDF
230+
"""Registers a Python function (including lambda function) or a wrapped/native UDF
231231
so it can be used in SQL statements.
232232
233233
In addition to a name and the function itself, the return type can be optionally specified.
234234
When the return type is not given it default to a string and conversion will automatically
235235
be done. For any other return type, the produced object must match the specified type.
236236
237237
:param name: name of the UDF
238-
:param f: python function
238+
:param f: a Python function, or a wrapped/native UserDefinedFunction
239239
:param returnType: a :class:`pyspark.sql.types.DataType` object
240240
:return: a wrapped :class:`UserDefinedFunction`
241241
@@ -260,14 +260,14 @@ def registerFunction(self, name, f, returnType=StringType()):
260260
>>> from pyspark.sql.functions import udf
261261
>>> from pyspark.sql.types import IntegerType, StringType
262262
>>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
263-
>>> newRandom_udf = spark.catalog.registerFunction(
264-
... "random_udf", random_udf, StringType()) # doctest: +SKIP
263+
>>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType())
265264
>>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP
266265
[Row(random_udf()=u'82')]
267266
>>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP
268267
[Row(random_udf()=u'62')]
269268
"""
270269

270+
# This is to check whether the input function is a wrapped/native UserDefinedFunction
271271
if hasattr(f, 'asNondeterministic'):
272272
udf = UserDefinedFunction(f.func, returnType=returnType, name=name,
273273
evalType=PythonEvalType.SQL_BATCHED_UDF,

python/pyspark/sql/context.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,15 +175,15 @@ def range(self, start, end=None, step=1, numPartitions=None):
175175
@ignore_unicode_prefix
176176
@since(1.2)
177177
def registerFunction(self, name, f, returnType=StringType()):
178-
"""Registers a python function (including lambda function) as a UDF
178+
"""Registers a Python function (including lambda function) or a wrapped/native UDF
179179
so it can be used in SQL statements.
180180
181181
In addition to a name and the function itself, the return type can be optionally specified.
182182
When the return type is not given it default to a string and conversion will automatically
183183
be done. For any other return type, the produced object must match the specified type.
184184
185185
:param name: name of the UDF
186-
:param f: python function
186+
:param f: a Python function, or a wrapped/native UserDefinedFunction
187187
:param returnType: a :class:`pyspark.sql.types.DataType` object
188188
:return: a wrapped :class:`UserDefinedFunction`
189189
@@ -208,8 +208,7 @@ def registerFunction(self, name, f, returnType=StringType()):
208208
>>> from pyspark.sql.functions import udf
209209
>>> from pyspark.sql.types import IntegerType, StringType
210210
>>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
211-
>>> newRandom_udf = sqlContext.registerFunction(
212-
... "random_udf", random_udf, StringType()) # doctest: +SKIP
211+
>>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType())
213212
>>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP
214213
[Row(random_udf()=u'82')]
215214
>>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP

python/pyspark/sql/tests.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,24 @@ def test_udf2(self):
378378
[res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
379379
self.assertEqual(4, res[0])
380380

381-
def test_non_deterministic_udf(self):
381+
def test_udf3(self):
382+
twoargs = self.spark.catalog.registerFunction(
383+
"twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType())
384+
self.assertEqual(twoargs.deterministic, True)
385+
[row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
386+
self.assertEqual(row[0], 5)
387+
388+
def test_nondeterministic_udf(self):
389+
from pyspark.sql.functions import udf
390+
import random
391+
udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
392+
self.assertEqual(udf_random_col.deterministic, False)
393+
df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
394+
udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
395+
[row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
396+
self.assertEqual(row[0] + 10, row[1])
397+
398+
def test_nondeterministic_udf2(self):
382399
import random
383400
from pyspark.sql.functions import udf
384401
random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
@@ -391,6 +408,7 @@ def test_non_deterministic_udf(self):
391408
self.assertEqual(row[0], "6")
392409
[row] = self.spark.range(1).select(random_udf()).collect()
393410
self.assertEqual(row[0], 6)
411+
# render_doc() reproduces the help() exception without printing output
394412
pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
395413
pydoc.render_doc(random_udf)
396414
pydoc.render_doc(random_udf1)
@@ -452,15 +470,6 @@ def test_udf_with_array_type(self):
452470
self.assertEqual(list(range(3)), l1)
453471
self.assertEqual(1, l2)
454472

455-
def test_nondeterministic_udf(self):
456-
from pyspark.sql.functions import udf
457-
import random
458-
udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
459-
df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
460-
udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
461-
[row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
462-
self.assertEqual(row[0] + 10, row[1])
463-
464473
def test_broadcast_in_udf(self):
465474
bar = {"a": "aa", "b": "bb", "c": "abc"}
466475
foo = self.sc.broadcast(bar)

0 commit comments

Comments
 (0)