[SPARK-22939] [PySpark] Support Spark UDF in registerFunction #20137
python/pyspark/sql/catalog.py
```diff
@@ -255,9 +255,26 @@ def registerFunction(self, name, f, returnType=StringType()):
         >>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
         >>> spark.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]

+        >>> import random
+        >>> from pyspark.sql.functions import udf
+        >>> from pyspark.sql.types import IntegerType, StringType
+        >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
+        >>> newRandom_udf = spark.catalog.registerFunction(
+        ...     "random_udf", random_udf, StringType())  # doctest: +SKIP
+        >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
+        [Row(random_udf()=u'82')]
+        >>> spark.range(1).select(newRandom_udf()).collect()  # doctest: +SKIP
+        [Row(random_udf()=u'62')]
         """
-        udf = UserDefinedFunction(f, returnType=returnType, name=name,
-                                  evalType=PythonEvalType.SQL_BATCHED_UDF)
+        if hasattr(f, 'asNondeterministic'):
```
|
Member:
Actually, this one made me suggest it. So, here `f` can be a wrapped function or a `UserDefinedFunction`. Could we at least leave some comments saying that `f` can be both a wrapped function and a `UserDefinedFunction`?

Member (Author):
will add a comment.
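To make the review point concrete, here is a minimal sketch (assuming only a local pyspark install; no active session is needed to build the wrapper) of why the `hasattr` check distinguishes the two shapes of `f`: both a `UserDefinedFunction` and the wrapped function returned by `_wrapped()` carry an `asNondeterministic` attribute, while a plain callable does not.

```python
# Illustrative sketch of the duck-typing check; not code from this PR.
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

plain = lambda x: x + 1                 # a plain Python callable
wrapped = udf(plain, IntegerType())     # the wrapped function _create_udf returns

print(hasattr(plain, 'asNondeterministic'))    # False -> plain-function branch
print(hasattr(wrapped, 'asNondeterministic'))  # True  -> unwrap via f.func
print(wrapped.func is plain)                   # True  -> registerFunction reuses it
```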
```diff
+            udf = UserDefinedFunction(f.func, returnType=returnType, name=name,
+                                      evalType=PythonEvalType.SQL_BATCHED_UDF,
```
Contributor:
cc @ueshin @icexelloss, shall we support registering pandas UDFs here too?

Contributor:
Seems we can support it by just changing the hard-coded `evalType` here.

Member:
+1, but I think there's no way to use a group map UDF in SQL syntax, if I understood correctly. I think we can safely fail fast for now as well.

Contributor:
SGTM

Member (Author):
Will support the pandas UDF in a separate PR.

Contributor:
+1 too
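For readers following along, a hedged sketch of the direction agreed above, which was deferred to a separate PR: pass the incoming UDF's own `evalType` through instead of hard-coding `SQL_BATCHED_UDF`, and fail fast for group map UDFs, which cannot be invoked from SQL. The eval-type constant name below is an assumption about this era's codebase, not code from this change.

```python
# Hypothetical follow-up (NOT part of this PR): respect f's own evalType so a
# scalar pandas UDF keeps its vectorized execution when registered for SQL.
# The SQL_PANDAS_GROUP_MAP_UDF constant name is an assumption.
from pyspark.rdd import PythonEvalType
from pyspark.sql.udf import UserDefinedFunction

def register_udf_like(name, f, returnType):
    """Sketch of registerFunction's UDF branch with evalType passed through."""
    if f.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF:
        # Group map UDFs have no SQL invocation syntax; fail fast.
        raise ValueError("group map pandas UDFs cannot be registered for SQL use")
    return UserDefinedFunction(f.func, returnType=returnType, name=name,
                               evalType=f.evalType,
                               deterministic=f.deterministic)
```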
```diff
+                                      deterministic=f.deterministic)
+        else:
+            udf = UserDefinedFunction(f, returnType=returnType, name=name,
+                                      evalType=PythonEvalType.SQL_BATCHED_UDF)
         self._jsparkSession.udf().registerPython(name, udf._judf)
         return udf._wrapped()
```
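Putting the two branches together, a short usage sketch of what the new path enables (assumes an active `SparkSession` named `spark`; drawn values are illustrative):

```python
# Assumes an active SparkSession `spark`; random values are illustrative.
import random
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()

# New branch: f is UDF-like, so its underlying func and deterministic flag are reused.
registered = spark.catalog.registerFunction("rand_int", random_udf, IntegerType())
print(registered.deterministic)  # False -- the flag survives registration

# Old branch: a plain callable still works exactly as before.
spark.catalog.registerFunction("strlen", lambda s: len(s), IntegerType())
print(spark.sql("SELECT strlen('test'), rand_int()").collect())
```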
python/pyspark/sql/tests.py
```diff
@@ -378,6 +378,23 @@ def test_udf2(self):
         [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
         self.assertEqual(4, res[0])

+    def test_non_deterministic_udf(self):
+        import random
+        from pyspark.sql.functions import udf
+        random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
+        self.assertEqual(random_udf.deterministic, False)
+        random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType())
+        self.assertEqual(random_udf1.deterministic, False)
+        [row] = self.spark.sql("SELECT randInt()").collect()
+        self.assertEqual(row[0], "6")
+        [row] = self.spark.range(1).select(random_udf1()).collect()
+        self.assertEqual(row[0], "6")
+        [row] = self.spark.range(1).select(random_udf()).collect()
+        self.assertEqual(row[0], 6)
+        pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
```
Contributor:
what does it do?

Member (Author):
This is to test a help function. See https://github.com/gatorsmile/spark/blob/85f11bfbfb564acb670097ff4ce520bfbc79b855/python/pyspark/sql/tests.py#L1681-L1688

Member:
Can we put these tests there, or make them separate from `test_non_deterministic_udf`?

Member (Author):
will add a comment.
```diff
+        pydoc.render_doc(random_udf)
+        pydoc.render_doc(random_udf1)

     def test_chained_udf(self):
         self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
         [row] = self.spark.sql("SELECT double(1)").collect()
```
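For context on the `pydoc.render_doc` calls above: they force pydoc to build the full help text for the wrapped UDF, so the test fails if the wrapper's metadata cannot be introspected. A standalone illustration using only the standard library:

```python
import pydoc

def double(x):
    """Return twice the input."""
    return x + x

# Same text help(double) would print; raises if pydoc cannot introspect
# the object, which is exactly what the test above relies on.
text = pydoc.render_doc(double)
assert "Return twice the input." in text
```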
|
|
```diff
@@ -567,15 +584,13 @@ def test_read_multiple_orc_file(self):

     def test_udf_with_input_file_name(self):
         from pyspark.sql.functions import udf, input_file_name
-        from pyspark.sql.types import StringType
         sourceFile = udf(lambda path: path, StringType())
         filePath = "python/test_support/sql/people1.json"
         row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
         self.assertTrue(row[0].find("people1.json") != -1)

     def test_udf_with_input_file_name_for_hadooprdd(self):
         from pyspark.sql.functions import udf, input_file_name
-        from pyspark.sql.types import StringType

         def filename(path):
             return path
```
|
|
```diff
@@ -635,7 +650,6 @@ def test_udf_with_string_return_type(self):

     def test_udf_shouldnt_accept_noncallable_object(self):
         from pyspark.sql.functions import UserDefinedFunction
-        from pyspark.sql.types import StringType

         non_callable = None
         self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
```
|
|
```diff
@@ -1299,7 +1313,6 @@ def test_between_function(self):
             df.filter(df.a.between(df.b, df.c)).collect())

     def test_struct_type(self):
-        from pyspark.sql.types import StructType, StringType, StructField
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         struct2 = StructType([StructField("f1", StringType(), True),
                               StructField("f2", StringType(), True, None)])
```
|
|
```diff
@@ -1368,7 +1381,6 @@ def test_parse_datatype_string(self):
             _parse_datatype_string("a INT, c DOUBLE"))

     def test_metadata_null(self):
-        from pyspark.sql.types import StructType, StringType, StructField
         schema = StructType([StructField("f1", StringType(), True, None),
                              StructField("f2", StringType(), True, {'a': None})])
         rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
```
|
|
python/pyspark/sql/udf.py
```diff
@@ -56,7 +56,8 @@ def _create_udf(f, returnType, evalType):
     )

     # Set the name of the UserDefinedFunction object to be the name of function f
-    udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType)
+    udf_obj = UserDefinedFunction(
+        f, returnType=returnType, name=None, evalType=evalType, deterministic=True)
     return udf_obj._wrapped()
```
|
|
```diff
@@ -67,8 +68,10 @@ class UserDefinedFunction(object):
     .. versionadded:: 1.3
     """
     def __init__(self, func,
-                 returnType=StringType(), name=None,
-                 evalType=PythonEvalType.SQL_BATCHED_UDF):
+                 returnType=StringType(),
+                 name=None,
+                 evalType=PythonEvalType.SQL_BATCHED_UDF,
+                 deterministic=True):
         if not callable(func):
             raise TypeError(
                 "Invalid function: not a function or callable (__call__ is not defined): "
```
|
|
```diff
@@ -92,7 +95,7 @@ def __init__(self, func,
             func.__name__ if hasattr(func, '__name__')
             else func.__class__.__name__)
         self.evalType = evalType
-        self._deterministic = True
+        self.deterministic = deterministic

     @property
     def returnType(self):
```
|
|
```diff
@@ -130,7 +133,7 @@ def _create_judf(self):
         wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            self._name, wrapped_func, jdt, self.evalType, self._deterministic)
+            self._name, wrapped_func, jdt, self.evalType, self.deterministic)
         return judf

     def __call__(self, *cols):
```
|
|
```diff
@@ -162,7 +165,8 @@ def wrapper(*args):
         wrapper.func = self.func
         wrapper.returnType = self.returnType
         wrapper.evalType = self.evalType
-        wrapper.asNondeterministic = self.asNondeterministic
+        wrapper.deterministic = self.deterministic
+        wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped()
```
|
Member:
Can we do:

```python
wrapper.asNondeterministic = functools.wraps(
    self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped())
```

so that it can produce a proper pydoc when we run `help()` on it?

Member (Author):
good to know the difference

Member (Author):
I will leave this unchanged. Maybe you can submit a follow-up PR to address it?

Member:
Definitely. Will give it a try within the following week tho ...
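A standalone illustration (pure Python, no Spark required) of what the suggested `functools.wraps` buys here: the lambda picks up the wrapped method's name and docstring, so `help()` and pydoc render something meaningful instead of an anonymous `<lambda>`.

```python
import functools

def asNondeterministic():
    """Updates this UDF to be nondeterministic."""

bare = lambda: None
wrapped = functools.wraps(asNondeterministic)(lambda: None)

print(bare.__name__, bare.__doc__)        # <lambda> None
print(wrapped.__name__, wrapped.__doc__)  # asNondeterministic Updates this UDF ...
```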
```diff

         return wrapper
```
|
|
```diff
@@ -172,5 +176,5 @@ def asNondeterministic(self):

         .. versionadded:: 2.3
         """
-        self._deterministic = False
+        self.deterministic = False
```
|
Member:
Can we call it `udfDeterministic`, matching the name used on the Scala side? (See spark/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala, line 33 in ff48b1b.) The opposite (renaming the Scala side) works fine to me too, if that's possible in any way.
```diff
         return self
```
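Returning `self` is what makes the fluent style in the doctest possible: a UDF can be created and marked nondeterministic in one expression. A minimal sketch (assumes a local pyspark install):

```python
# asNondeterministic mutates the flag and returns the same object,
# so construction and marking chain in a single expression.
import random
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
print(random_udf.deterministic)  # False
```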
Member:
Let's fix the doc for this too. It says `:param f: python function`, but we could describe that it takes a Python native function, a wrapped function, and a `UserDefinedFunction` too.

Member (Author):
ok
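One possible wording for that doc fix, sketched from the comment above (illustrative only; the exact text that landed may differ):

```python
# Hypothetical docstring tweak; not the merged wording.
def registerFunction(self, name, f, returnType=StringType()):
    """Registers a Python function (including lambda function) or a
    user-defined function as a SQL function.

    :param name: name of the UDF in SQL statements.
    :param f: a Python native function, a wrapped function, or a
        UserDefinedFunction.
    :param returnType: the return type of the registered UDF.
    """
```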