Commit 1f5e354

[SPARK-22939][PYSPARK] Support Spark UDF in registerFunction
## What changes were proposed in this pull request?

```Python
import random
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, StringType

random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()
spark.catalog.registerFunction("random_udf", random_udf, StringType())
spark.sql("SELECT random_udf()").collect()
```

Running the above currently fails with the following error:

```
Py4JError: An error occurred while calling o29.__getnewargs__. Trace:
py4j.Py4JException: Method __getnewargs__([]) does not exist
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326)
    at py4j.Gateway.invoke(Gateway.java:274)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:214)
    at java.lang.Thread.run(Thread.java:745)
```

This PR adds support for passing a wrapped or native `UserDefinedFunction` (i.e. a Spark UDF object) to `registerFunction`.

## How was this patch tested?

WIP

Author: gatorsmile <[email protected]>

Closes #20137 from gatorsmile/registerFunction.

(cherry picked from commit 5aadbc9)
Signed-off-by: gatorsmile <[email protected]>
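For reference, a minimal sketch of the usage this change enables, mirroring the doctest added to catalog.py below. It assumes an active SparkSession bound to `spark`; since the UDF is nondeterministic, the exact values in the output comments will vary.

```Python
import random

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, StringType

# A nondeterministic Python UDF object; previously only plain callables could be
# passed to registerFunction, and passing this object raised the Py4JError above.
random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()

# After this change the UDF object can be registered directly for SQL use; the
# return value is a wrapped UDF that also works from the DataFrame API.
new_random_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType())

spark.sql("SELECT random_udf()").collect()          # e.g. [Row(random_udf()=u'82')]
spark.range(1).select(new_random_udf()).collect()   # e.g. [Row(random_udf()=u'62')]
```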
1 parent eb99b8a commit 1f5e354

File tree

4 files changed: +84 -29 lines changed

python/pyspark/sql/catalog.py

Lines changed: 22 additions & 5 deletions
@@ -227,15 +227,15 @@ def dropGlobalTempView(self, viewName):
     @ignore_unicode_prefix
     @since(2.0)
     def registerFunction(self, name, f, returnType=StringType()):
-        """Registers a python function (including lambda function) as a UDF
-        so it can be used in SQL statements.
+        """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
+        as a UDF. The registered UDF can be used in SQL statement.

         In addition to a name and the function itself, the return type can be optionally specified.
         When the return type is not given it default to a string and conversion will automatically
         be done. For any other return type, the produced object must match the specified type.

         :param name: name of the UDF
-        :param f: python function
+        :param f: a Python function, or a wrapped/native UserDefinedFunction
         :param returnType: a :class:`pyspark.sql.types.DataType` object
         :return: a wrapped :class:`UserDefinedFunction`

@@ -255,9 +255,26 @@ def registerFunction(self, name, f, returnType=StringType()):
         >>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
         >>> spark.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
+
+        >>> import random
+        >>> from pyspark.sql.functions import udf
+        >>> from pyspark.sql.types import IntegerType, StringType
+        >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
+        >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType())
+        >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP
+        [Row(random_udf()=u'82')]
+        >>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP
+        [Row(random_udf()=u'62')]
         """
-        udf = UserDefinedFunction(f, returnType=returnType, name=name,
-                                  evalType=PythonEvalType.SQL_BATCHED_UDF)
+
+        # This is to check whether the input function is a wrapped/native UserDefinedFunction
+        if hasattr(f, 'asNondeterministic'):
+            udf = UserDefinedFunction(f.func, returnType=returnType, name=name,
+                                      evalType=PythonEvalType.SQL_BATCHED_UDF,
+                                      deterministic=f.deterministic)
+        else:
+            udf = UserDefinedFunction(f, returnType=returnType, name=name,
+                                      evalType=PythonEvalType.SQL_BATCHED_UDF)
         self._jsparkSession.udf().registerPython(name, udf._judf)
         return udf._wrapped()

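The branch above keys off `hasattr(f, 'asNondeterministic')`: both the native `UserDefinedFunction` object and the wrapper returned by `udf(...)` expose that attribute (plus `.func` and `.deterministic` after this patch), while a plain Python callable does not. A standalone sketch of that dispatch; `describe_input` is purely illustrative and not part of PySpark:

```Python
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

def describe_input(f):
    # Wrapped/native UDFs expose .func and .deterministic, so registration can
    # rebuild the UserDefinedFunction while preserving its nondeterminism flag.
    if hasattr(f, 'asNondeterministic'):
        return "udf object, deterministic=%s" % f.deterministic
    # Plain callables take the old path and default to deterministic=True.
    return "plain function, deterministic=True"

print(describe_input(lambda x: x + 1))                                     # plain function, deterministic=True
print(describe_input(udf(lambda x: x + 1, IntegerType())))                 # udf object, deterministic=True
print(describe_input(udf(lambda: 1, IntegerType()).asNondeterministic()))  # udf object, deterministic=False
```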

python/pyspark/sql/context.py

Lines changed: 13 additions & 3 deletions
@@ -175,15 +175,15 @@ def range(self, start, end=None, step=1, numPartitions=None):
     @ignore_unicode_prefix
     @since(1.2)
     def registerFunction(self, name, f, returnType=StringType()):
-        """Registers a python function (including lambda function) as a UDF
-        so it can be used in SQL statements.
+        """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
+        as a UDF. The registered UDF can be used in SQL statement.

         In addition to a name and the function itself, the return type can be optionally specified.
         When the return type is not given it default to a string and conversion will automatically
         be done. For any other return type, the produced object must match the specified type.

         :param name: name of the UDF
-        :param f: python function
+        :param f: a Python function, or a wrapped/native UserDefinedFunction
         :param returnType: a :class:`pyspark.sql.types.DataType` object
         :return: a wrapped :class:`UserDefinedFunction`

@@ -203,6 +203,16 @@ def registerFunction(self, name, f, returnType=StringType()):
         >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
+
+        >>> import random
+        >>> from pyspark.sql.functions import udf
+        >>> from pyspark.sql.types import IntegerType, StringType
+        >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
+        >>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType())
+        >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP
+        [Row(random_udf()=u'82')]
+        >>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP
+        [Row(random_udf()=u'62')]
         """
         return self.sparkSession.catalog.registerFunction(name, f, returnType)

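The SQLContext path changes only in its docstring and doctest; behaviorally it remains a thin delegation to the session catalog, so a UDF object passed here takes the same new code path as `spark.catalog.registerFunction`. A quick check, assuming a SparkSession-backed SQLContext bound to `sqlContext`; `len_udf` is a hypothetical name:

```Python
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# SQLContext.registerFunction forwards to sparkSession.catalog.registerFunction.
len_udf = udf(lambda s: len(s), IntegerType())
wrapped = sqlContext.registerFunction("len_udf", len_udf, IntegerType())
assert wrapped.deterministic is True
sqlContext.sql("SELECT len_udf('spark')").collect()  # e.g. [Row(len_udf(spark)=5)]
```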

python/pyspark/sql/tests.py

Lines changed: 35 additions & 14 deletions
@@ -378,6 +378,41 @@ def test_udf2(self):
         [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
         self.assertEqual(4, res[0])

+    def test_udf3(self):
+        twoargs = self.spark.catalog.registerFunction(
+            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType())
+        self.assertEqual(twoargs.deterministic, True)
+        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+        self.assertEqual(row[0], 5)
+
+    def test_nondeterministic_udf(self):
+        from pyspark.sql.functions import udf
+        import random
+        udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
+        self.assertEqual(udf_random_col.deterministic, False)
+        df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
+        udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
+        [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
+        self.assertEqual(row[0] + 10, row[1])
+
+    def test_nondeterministic_udf2(self):
+        import random
+        from pyspark.sql.functions import udf
+        random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
+        self.assertEqual(random_udf.deterministic, False)
+        random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType())
+        self.assertEqual(random_udf1.deterministic, False)
+        [row] = self.spark.sql("SELECT randInt()").collect()
+        self.assertEqual(row[0], "6")
+        [row] = self.spark.range(1).select(random_udf1()).collect()
+        self.assertEqual(row[0], "6")
+        [row] = self.spark.range(1).select(random_udf()).collect()
+        self.assertEqual(row[0], 6)
+        # render_doc() reproduces the help() exception without printing output
+        pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
+        pydoc.render_doc(random_udf)
+        pydoc.render_doc(random_udf1)
+
     def test_chained_udf(self):
         self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
         [row] = self.spark.sql("SELECT double(1)").collect()
@@ -435,15 +470,6 @@ def test_udf_with_array_type(self):
         self.assertEqual(list(range(3)), l1)
         self.assertEqual(1, l2)

-    def test_nondeterministic_udf(self):
-        from pyspark.sql.functions import udf
-        import random
-        udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
-        df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
-        udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
-        [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
-        self.assertEqual(row[0] + 10, row[1])
-
     def test_broadcast_in_udf(self):
         bar = {"a": "aa", "b": "bb", "c": "abc"}
         foo = self.sc.broadcast(bar)
@@ -567,15 +593,13 @@ def test_read_multiple_orc_file(self):

     def test_udf_with_input_file_name(self):
         from pyspark.sql.functions import udf, input_file_name
-        from pyspark.sql.types import StringType
         sourceFile = udf(lambda path: path, StringType())
         filePath = "python/test_support/sql/people1.json"
         row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
         self.assertTrue(row[0].find("people1.json") != -1)

     def test_udf_with_input_file_name_for_hadooprdd(self):
         from pyspark.sql.functions import udf, input_file_name
-        from pyspark.sql.types import StringType

         def filename(path):
             return path
@@ -635,7 +659,6 @@ def test_udf_with_string_return_type(self):

     def test_udf_shouldnt_accept_noncallable_object(self):
         from pyspark.sql.functions import UserDefinedFunction
-        from pyspark.sql.types import StringType

         non_callable = None
         self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
@@ -1299,7 +1322,6 @@ def test_between_function(self):
                          df.filter(df.a.between(df.b, df.c)).collect())

     def test_struct_type(self):
-        from pyspark.sql.types import StructType, StringType, StructField
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         struct2 = StructType([StructField("f1", StringType(), True),
                               StructField("f2", StringType(), True, None)])
@@ -1368,7 +1390,6 @@ def test_parse_datatype_string(self):
             _parse_datatype_string("a INT, c DOUBLE"))

     def test_metadata_null(self):
-        from pyspark.sql.types import StructType, StringType, StructField
         schema = StructType([StructField("f1", StringType(), True, None),
                              StructField("f2", StringType(), True, {'a': None})])
         rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])

python/pyspark/sql/udf.py

Lines changed: 14 additions & 7 deletions
@@ -56,7 +56,8 @@ def _create_udf(f, returnType, evalType):
            )

    # Set the name of the UserDefinedFunction object to be the name of function f
-    udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType)
+    udf_obj = UserDefinedFunction(
+        f, returnType=returnType, name=None, evalType=evalType, deterministic=True)
    return udf_obj._wrapped()


@@ -67,8 +68,10 @@ class UserDefinedFunction(object):
    .. versionadded:: 1.3
    """
    def __init__(self, func,
-                 returnType=StringType(), name=None,
-                 evalType=PythonEvalType.SQL_BATCHED_UDF):
+                 returnType=StringType(),
+                 name=None,
+                 evalType=PythonEvalType.SQL_BATCHED_UDF,
+                 deterministic=True):
        if not callable(func):
            raise TypeError(
                "Invalid function: not a function or callable (__call__ is not defined): "
@@ -92,7 +95,7 @@ def __init__(self, func,
            func.__name__ if hasattr(func, '__name__')
            else func.__class__.__name__)
        self.evalType = evalType
-        self._deterministic = True
+        self.deterministic = deterministic

    @property
    def returnType(self):
@@ -130,14 +133,17 @@ def _create_judf(self):
        wrapped_func = _wrap_function(sc, self.func, self.returnType)
        jdt = spark._jsparkSession.parseDataType(self.returnType.json())
        judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            self._name, wrapped_func, jdt, self.evalType, self._deterministic)
+            self._name, wrapped_func, jdt, self.evalType, self.deterministic)
        return judf

    def __call__(self, *cols):
        judf = self._judf
        sc = SparkContext._active_spark_context
        return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))

+    # This function is for improving the online help system in the interactive interpreter.
+    # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and
+    # argument annotation. (See: SPARK-19161)
    def _wrapped(self):
        """
        Wrap this udf with a function and attach docstring from func
@@ -162,7 +168,8 @@ def wrapper(*args):
        wrapper.func = self.func
        wrapper.returnType = self.returnType
        wrapper.evalType = self.evalType
-        wrapper.asNondeterministic = self.asNondeterministic
+        wrapper.deterministic = self.deterministic
+        wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped()

        return wrapper

@@ -172,5 +179,5 @@ def asNondeterministic(self):

        .. versionadded:: 2.3
        """
-        self._deterministic = False
+        self.deterministic = False
        return self
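Two consequences of the udf.py changes are worth spelling out: `deterministic` becomes a public attribute that the wrapper also exposes, and `asNondeterministic()` called on a wrapped UDF now returns another wrapped UDF rather than the bare `UserDefinedFunction`, so the result stays usable as a callable in the DataFrame API and as an input to `registerFunction`; replacing the bound-method reference with a lambda also keeps `help()`/pydoc from touching JVM-backed state (see the SPARK-19161 note and the `pydoc.render_doc` calls in the tests). A small check of that behavior, grounded in the diff above and requiring only pyspark on the path (no active SparkSession):

```Python
import random

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

wrapped = udf(lambda: random.randint(0, 100), IntegerType())
nd = wrapped.asNondeterministic()

# asNondeterministic() on the wrapper returns another wrapped UDF, i.e. a plain
# Python function carrying the UDF metadata...
assert callable(nd) and nd.deterministic is False

# ...which is exactly what the hasattr check in registerFunction relies on.
assert hasattr(nd, 'asNondeterministic') and hasattr(nd, 'func')
```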
