[SPARK-22939] [PySpark] Support Spark UDF in registerFunction #20137
Changes from 12 commits
8216b6b
0ecdf63
e8d0a4c
35e6a4a
b89b720
3208136
f099261
d1ba703
6ac25e6
85f11bf
78e9b2c
09a1b89
2482e6b
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -227,15 +227,15 @@ def dropGlobalTempView(self, viewName): | |
| @ignore_unicode_prefix | ||
| @since(2.0) | ||
| def registerFunction(self, name, f, returnType=StringType()): | ||
| """Registers a python function (including lambda function) as a UDF | ||
| """Registers a Python function (including lambda function) or a wrapped/native UDF | ||
| so it can be used in SQL statements. | ||
|
|
||
| In addition to a name and the function itself, the return type can be optionally specified. | ||
| When the return type is not given it defaults to a string and conversion will automatically | ||
| be done. For any other return type, the produced object must match the specified type. | ||
|
|
||
| :param name: name of the UDF | ||
| :param f: python function | ||
| :param f: a Python function, or a wrapped/native UserDefinedFunction | ||
| :param returnType: a :class:`pyspark.sql.types.DataType` object | ||
| :return: a wrapped :class:`UserDefinedFunction` | ||
|
|
||
|
|
@@ -255,9 +255,26 @@ def registerFunction(self, name, f, returnType=StringType()): | |
| >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) | ||
| >>> spark.sql("SELECT stringLengthInt('test')").collect() | ||
| [Row(stringLengthInt(test)=4)] | ||
|
Member
Let's fix the doc for this too. It says
Member
Author
ok |
||
|
|
||
| >>> import random | ||
| >>> from pyspark.sql.functions import udf | ||
| >>> from pyspark.sql.types import IntegerType, StringType | ||
| >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() | ||
| >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType()) | ||
| >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP | ||
| [Row(random_udf()=u'82')] | ||
| >>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP | ||
| [Row(random_udf()=u'62')] | ||
| """ | ||
| udf = UserDefinedFunction(f, returnType=returnType, name=name, | ||
| evalType=PythonEvalType.SQL_BATCHED_UDF) | ||
|
|
||
| # This is to check whether the input function is a wrapped/native UserDefinedFunction | ||
| if hasattr(f, 'asNondeterministic'): | ||
|
Member
Actually, this one made me suggest ... So, here this can be a wrapped function or ... Could we at least leave some comments saying that this can be both a wrapped function for
Member
Author
will add a comment. |
||
| udf = UserDefinedFunction(f.func, returnType=returnType, name=name, | ||
| evalType=PythonEvalType.SQL_BATCHED_UDF, | ||
|
Contributor
cc @ueshin @icexelloss, shall we support registering pandas UDFs here too?
Contributor
seems we can support it by just changing
Member
+1 but I think there's no way to use a group map UDF in SQL syntax if I understood correctly. I think we can safely fail fast for now as well.
Contributor
SGTM
Member
Author
Will support the pandas UDF as a separate PR.
Contributor
+1 too |
||
| deterministic=f.deterministic) | ||
| else: | ||
| udf = UserDefinedFunction(f, returnType=returnType, name=name, | ||
| evalType=PythonEvalType.SQL_BATCHED_UDF) | ||
| self._jsparkSession.udf().registerPython(name, udf._judf) | ||
| return udf._wrapped() | ||
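For readers following the branch above, here is a rough sketch of the two call paths it distinguishes. It assumes a live SparkSession named `spark`, and the UDF names (`plain_len`, `wrapped_rand`) are invented for illustration; this is not part of the diff.

```python
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# 1) A plain Python function: takes the `else` branch and is wrapped as a
#    deterministic SQL_BATCHED_UDF.
spark.catalog.registerFunction("plain_len", lambda s: len(s), IntegerType())

# 2) A wrapped UDF created by udf(): it exposes `asNondeterministic`, so the
#    `hasattr(f, 'asNondeterministic')` branch reuses f.func and preserves
#    f.deterministic instead of resetting it.
wrapped_rand = udf(lambda: 42, IntegerType()).asNondeterministic()
spark.catalog.registerFunction("wrapped_rand", wrapped_rand, IntegerType())
```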
|
|
||
|
|
||
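The thread above defers pandas UDF registration to a follow-up PR. For context, these are the two pandas UDF flavors being discussed, written against the standard `pandas_udf` API (requires pandas and pyarrow; the function names and logic are illustrative, not from this change):

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

# Scalar pandas UDF: operates on a pandas Series and is the flavor that could,
# in principle, be called from SQL once registration supports it.
@pandas_udf('long')
def plus_one(s):
    return s + 1

# Grouped map pandas UDF: applied via groupBy().apply(); there is no SQL syntax
# for it, hence the suggestion above to fail fast if one is passed to register.
@pandas_udf('id long, v double', PandasUDFType.GROUPED_MAP)
def normalize(pdf):
    return pdf.assign(v=(pdf.v - pdf.v.mean()) / pdf.v.std())
```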
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -378,6 +378,41 @@ def test_udf2(self): | |
| [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() | ||
| self.assertEqual(4, res[0]) | ||
|
|
||
| def test_udf3(self): | ||
| twoargs = self.spark.catalog.registerFunction( | ||
| "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType()) | ||
| self.assertEqual(twoargs.deterministic, True) | ||
| [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() | ||
| self.assertEqual(row[0], 5) | ||
|
|
||
| def test_nondeterministic_udf(self): | ||
| from pyspark.sql.functions import udf | ||
| import random | ||
| udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() | ||
| self.assertEqual(udf_random_col.deterministic, False) | ||
| df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) | ||
| udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) | ||
| [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() | ||
| self.assertEqual(row[0] + 10, row[1]) | ||
|
|
||
| def test_nondeterministic_udf2(self): | ||
| import random | ||
| from pyspark.sql.functions import udf | ||
| random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() | ||
| self.assertEqual(random_udf.deterministic, False) | ||
| random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType()) | ||
| self.assertEqual(random_udf1.deterministic, False) | ||
| [row] = self.spark.sql("SELECT randInt()").collect() | ||
| self.assertEqual(row[0], "6") | ||
| [row] = self.spark.range(1).select(random_udf1()).collect() | ||
| self.assertEqual(row[0], "6") | ||
| [row] = self.spark.range(1).select(random_udf()).collect() | ||
| self.assertEqual(row[0], 6) | ||
| # render_doc() reproduces the help() exception without printing output | ||
| pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) | ||
|
Contributor
what does it do?
Member
Author
This is to test a help function. See https://github.com/gatorsmile/spark/blob/85f11bfbfb564acb670097ff4ce520bfbc79b855/python/pyspark/sql/tests.py#L1681-L1688
Member
Can we put these tests there or make this separate from
Member
Author
will add a comment. |
||
| pydoc.render_doc(random_udf) | ||
| pydoc.render_doc(random_udf1) | ||
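The `pydoc.render_doc` calls above exercise the interactive help path guarded by SPARK-19161. A small illustration of what they check, assuming a local SparkSession and an invented UDF name:

```python
import pydoc
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

add_one = udf(lambda x: x + 1, IntegerType())
# render_doc() is what help() uses under the hood, but it returns the text
# instead of printing it; the test only needs it to run without raising.
doc_text = pydoc.render_doc(add_one)
print(doc_text.splitlines()[0])
```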
|
|
||
| def test_chained_udf(self): | ||
| self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) | ||
| [row] = self.spark.sql("SELECT double(1)").collect() | ||
|
|
@@ -435,15 +470,6 @@ def test_udf_with_array_type(self): | |
| self.assertEqual(list(range(3)), l1) | ||
| self.assertEqual(1, l2) | ||
|
|
||
| def test_nondeterministic_udf(self): | ||
| from pyspark.sql.functions import udf | ||
| import random | ||
| udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() | ||
| df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) | ||
| udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) | ||
| [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() | ||
| self.assertEqual(row[0] + 10, row[1]) | ||
|
|
||
| def test_broadcast_in_udf(self): | ||
| bar = {"a": "aa", "b": "bb", "c": "abc"} | ||
| foo = self.sc.broadcast(bar) | ||
|
|
@@ -567,15 +593,13 @@ def test_read_multiple_orc_file(self): | |
|
|
||
| def test_udf_with_input_file_name(self): | ||
| from pyspark.sql.functions import udf, input_file_name | ||
| from pyspark.sql.types import StringType | ||
| sourceFile = udf(lambda path: path, StringType()) | ||
| filePath = "python/test_support/sql/people1.json" | ||
| row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first() | ||
| self.assertTrue(row[0].find("people1.json") != -1) | ||
|
|
||
| def test_udf_with_input_file_name_for_hadooprdd(self): | ||
| from pyspark.sql.functions import udf, input_file_name | ||
| from pyspark.sql.types import StringType | ||
|
|
||
| def filename(path): | ||
| return path | ||
|
|
@@ -635,7 +659,6 @@ def test_udf_with_string_return_type(self): | |
|
|
||
| def test_udf_shouldnt_accept_noncallable_object(self): | ||
| from pyspark.sql.functions import UserDefinedFunction | ||
| from pyspark.sql.types import StringType | ||
|
|
||
| non_callable = None | ||
| self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType()) | ||
|
|
@@ -1299,7 +1322,6 @@ def test_between_function(self): | |
| df.filter(df.a.between(df.b, df.c)).collect()) | ||
|
|
||
| def test_struct_type(self): | ||
| from pyspark.sql.types import StructType, StringType, StructField | ||
| struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) | ||
| struct2 = StructType([StructField("f1", StringType(), True), | ||
| StructField("f2", StringType(), True, None)]) | ||
|
|
@@ -1368,7 +1390,6 @@ def test_parse_datatype_string(self): | |
| _parse_datatype_string("a INT, c DOUBLE")) | ||
|
|
||
| def test_metadata_null(self): | ||
| from pyspark.sql.types import StructType, StringType, StructField | ||
| schema = StructType([StructField("f1", StringType(), True, None), | ||
| StructField("f2", StringType(), True, {'a': None})]) | ||
| rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -56,7 +56,8 @@ def _create_udf(f, returnType, evalType): | |||
| ) | ||||
|
|
||||
| # Set the name of the UserDefinedFunction object to be the name of function f | ||||
| udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType) | ||||
| udf_obj = UserDefinedFunction( | ||||
| f, returnType=returnType, name=None, evalType=evalType, deterministic=True) | ||||
| return udf_obj._wrapped() | ||||
|
|
||||
|
|
||||
|
|
@@ -67,8 +68,10 @@ class UserDefinedFunction(object): | |||
| .. versionadded:: 1.3 | ||||
| """ | ||||
| def __init__(self, func, | ||||
| returnType=StringType(), name=None, | ||||
| evalType=PythonEvalType.SQL_BATCHED_UDF): | ||||
| returnType=StringType(), | ||||
| name=None, | ||||
| evalType=PythonEvalType.SQL_BATCHED_UDF, | ||||
| deterministic=True): | ||||
| if not callable(func): | ||||
| raise TypeError( | ||||
| "Invalid function: not a function or callable (__call__ is not defined): " | ||||
|
|
@@ -92,7 +95,7 @@ def __init__(self, func, | |||
| func.__name__ if hasattr(func, '__name__') | ||||
| else func.__class__.__name__) | ||||
| self.evalType = evalType | ||||
| self._deterministic = True | ||||
| self.deterministic = deterministic | ||||
|
|
||||
| @property | ||||
| def returnType(self): | ||||
|
|
@@ -130,14 +133,17 @@ def _create_judf(self): | |||
| wrapped_func = _wrap_function(sc, self.func, self.returnType) | ||||
| jdt = spark._jsparkSession.parseDataType(self.returnType.json()) | ||||
| judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( | ||||
| self._name, wrapped_func, jdt, self.evalType, self._deterministic) | ||||
| self._name, wrapped_func, jdt, self.evalType, self.deterministic) | ||||
| return judf | ||||
|
|
||||
| def __call__(self, *cols): | ||||
| judf = self._judf | ||||
| sc = SparkContext._active_spark_context | ||||
| return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) | ||||
|
|
||||
| # This function is for improving the online help system in the interactive interpreter. | ||||
| # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and | ||||
| # argument annotation. (See: SPARK-19161) | ||||
|
Member
I think we can put this in the docstring of
Member
Author
I do not want to expose these comments to the doc. |
||||
| def _wrapped(self): | ||||
| """ | ||||
| Wrap this udf with a function and attach docstring from func | ||||
|
|
@@ -162,7 +168,8 @@ def wrapper(*args): | |||
| wrapper.func = self.func | ||||
| wrapper.returnType = self.returnType | ||||
| wrapper.evalType = self.evalType | ||||
| wrapper.asNondeterministic = self.asNondeterministic | ||||
| wrapper.deterministic = self.deterministic | ||||
| wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped() | ||||
|
Member
Can we do:
`wrapper.asNondeterministic = functools.wraps(self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped())`
so that it can produce a proper pydoc when we do
Member
Author
good to know the difference
Member
Author
I will leave this unchanged. Maybe you can submit a follow-up PR to address it?
Member
Definitely. Will give a try within the following week tho ... |
||||
|
|
||||
| return wrapper | ||||
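As a plain-Python aside on the `functools.wraps` suggestion above (nothing Spark-specific; the class and names below are invented), wrapping the reassigned lambda copies the original method's name and docstring, which is what makes `pydoc`/`help()` output readable:

```python
import functools
import pydoc

class Example(object):
    def asNondeterministic(self):
        """Updates this example object to be nondeterministic."""
        return self

obj = Example()

# Bare reassignment: pydoc sees an anonymous <lambda> with no docstring.
bare = lambda: obj.asNondeterministic()

# With functools.wraps, the lambda inherits __name__ and __doc__ from the
# method it stands in for, so help()/pydoc render something meaningful.
wrapped = functools.wraps(obj.asNondeterministic)(lambda: obj.asNondeterministic())

print(bare.__name__)     # '<lambda>'
print(wrapped.__name__)  # 'asNondeterministic'
print(pydoc.render_doc(wrapped).splitlines()[0])
```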
|
|
||||
|
|
@@ -172,5 +179,5 @@ def asNondeterministic(self): | |||
|
|
||||
| .. versionadded:: 2.3 | ||||
| """ | ||||
| self._deterministic = False | ||||
| self.deterministic = False | ||||
|
Member
Can we call it
spark/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala Line 33 in ff48b1b
The opposite works fine to me too if that's possible in any way.
Member
Author
|
||||
| return self | ||||
I'm really confused when reading this document; it would be much clearer to me if we could just say
This wrapping logic was added in #16534. Is it really worth it?
It indeed added some complexity. However, I believe nothing is blocked by #16534 now if I understand correctly.
The changes in #16534 are quite nice because IMHO Python guys probably use help() and dir() more frequently than reading the API doc on the website. For a set of UDFs provided as a library, I think that's quite worth keeping.
How about leaving this wrapper logic as is for now, and then bringing this discussion back when something is actually blocked (or becomes too complicated) by this?
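To illustrate the point about help() and dir(), here is a minimal sketch assuming a running SparkSession; `slen` is an invented example, not code from this PR:

```python
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

def slen(s):
    """Returns the length of the given string."""
    return len(s)

# The wrapper returned by udf() carries slen's docstring, so interactive users
# see useful documentation instead of raw UserDefinedFunction internals.
slen_udf = udf(slen, IntegerType())
print(slen_udf.__doc__)  # Returns the length of the given string.
help(slen_udf)
```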
Another idea just in case it helps:
BTW, to be honest, I remember giving several quick tries at the time to get rid of the wrapper while keeping the docstring correct, but I failed to make a good alternative.
Might be good to see if there is a clever way to get rid of the wrapper but keep the doc.
SGTM