
Commit e0421c6

ueshin authored and HyukjinKwon committed
[SPARK-23141][SQL][PYSPARK] Support data type string as a returnType for registerJavaFunction.
## What changes were proposed in this pull request?

Currently `UDFRegistration.registerJavaFunction` doesn't support data type string as a `returnType` whereas `UDFRegistration.register`, `udf`, or `pandas_udf` does. We can support it for `UDFRegistration.registerJavaFunction` as well.

## How was this patch tested?

Added a doctest and existing tests.

Author: Takuya UESHIN <[email protected]>

Closes #20307 from ueshin/issues/SPARK-23141.

(cherry picked from commit 5063b74)
Signed-off-by: hyukjinkwon <[email protected]>
1 parent 8a98274 · commit e0421c6
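In practice, the patch lets `registerJavaFunction` take the same DDL-formatted type strings that `register`, `udf`, and `pandas_udf` already accept. A minimal before/after sketch, assuming a running `SparkSession` named `spark` and the `test.org.apache.spark.sql.JavaStringLength` test class on the classpath (both taken from the doctests in the diff below):

```python
from pyspark.sql.types import IntegerType

# Previously, returnType had to be a DataType object:
spark.udf.registerJavaFunction(
    "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())

# With this change, a DDL-formatted type string works as well:
spark.udf.registerJavaFunction(
    "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer")

spark.sql("SELECT javaStringLength3('test')").collect()
# [Row(UDF:javaStringLength3(test)=4)]
```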

File tree: 2 files changed (+16, -4 lines)

python/pyspark/sql/functions.py
python/pyspark/sql/udf.py


python/pyspark/sql/functions.py

Lines changed: 4 additions & 2 deletions
@@ -2108,7 +2108,8 @@ def udf(f=None, returnType=StringType()):
     can fail on special rows, the workaround is to incorporate the condition into the functions.
 
     :param f: python function if used as a standalone function
-    :param returnType: a :class:`pyspark.sql.types.DataType` object
+    :param returnType: the return type of the user-defined function. The value can be either a
+        :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
 
     >>> from pyspark.sql.types import IntegerType
     >>> slen = udf(lambda s: len(s), IntegerType())
@@ -2148,7 +2149,8 @@ def pandas_udf(f=None, returnType=None, functionType=None):
     Creates a vectorized user defined function (UDF).
 
     :param f: user-defined function. A python function if used as a standalone function
-    :param returnType: a :class:`pyspark.sql.types.DataType` object
+    :param returnType: the return type of the user-defined function. The value can be either a
+        :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
     :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`.
                          Default: SCALAR.
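The two hunks above are documentation-only: `udf` and `pandas_udf` already accepted a DDL-formatted string for `returnType`, and the docstrings now say so. A brief sketch of that existing behavior, assuming a local `SparkSession` (the DataFrame and column names are illustrative):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

spark = SparkSession.builder.master("local[1]").appName("ddl-return-type").getOrCreate()

# "integer" is a DDL-formatted type string, equivalent to passing IntegerType().
slen = udf(lambda s: len(s), "integer")

df = spark.createDataFrame([("test",)], ["word"])
df.select(slen("word").alias("length")).show()
# +------+
# |length|
# +------+
# |     4|
# +------+
```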

python/pyspark/sql/udf.py

Lines changed: 12 additions & 2 deletions
@@ -206,7 +206,8 @@ def register(self, name, f, returnType=None):
         :param f: a Python function, or a user-defined function. The user-defined function can
             be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and
             :meth:`pyspark.sql.functions.pandas_udf`.
-        :param returnType: the return type of the registered user-defined function.
+        :param returnType: the return type of the registered user-defined function. The value can
+            be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
         :return: a user-defined function.
 
         `returnType` can be optionally specified when `f` is a Python function but not
@@ -303,21 +304,30 @@ def registerJavaFunction(self, name, javaClassName, returnType=None):
 
         :param name: name of the user-defined function
         :param javaClassName: fully qualified name of java class
-        :param returnType: a :class:`pyspark.sql.types.DataType` object
+        :param returnType: the return type of the registered Java function. The value can be either
+            a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
 
         >>> from pyspark.sql.types import IntegerType
         >>> spark.udf.registerJavaFunction(
         ...     "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
         >>> spark.sql("SELECT javaStringLength('test')").collect()
         [Row(UDF:javaStringLength(test)=4)]
+
         >>> spark.udf.registerJavaFunction(
         ...     "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
         >>> spark.sql("SELECT javaStringLength2('test')").collect()
         [Row(UDF:javaStringLength2(test)=4)]
+
+        >>> spark.udf.registerJavaFunction(
+        ...     "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer")
+        >>> spark.sql("SELECT javaStringLength3('test')").collect()
+        [Row(UDF:javaStringLength3(test)=4)]
         """
 
         jdt = None
         if returnType is not None:
+            if not isinstance(returnType, DataType):
+                returnType = _parse_datatype_string(returnType)
             jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
         self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)
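The added `isinstance` check is the whole mechanism: a non-`DataType` value for `returnType` is parsed into a `DataType` with `_parse_datatype_string` before being converted to a JVM data type. A standalone sketch of that normalization step, reusing the same helper from `pyspark.sql.types` (the function name `normalize_return_type` is illustrative, not part of the patch; an active `SparkSession` is needed because the string is parsed on the JVM):

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import DataType, _parse_datatype_string

# _parse_datatype_string delegates to Spark's SQL parser on the JVM,
# so an active session/context is required before calling it.
spark = SparkSession.builder.master("local[1]").appName("parse-ddl-type").getOrCreate()

def normalize_return_type(returnType):
    """Accept either a DataType instance or a DDL-formatted type string."""
    if not isinstance(returnType, DataType):
        returnType = _parse_datatype_string(returnType)
    return returnType

normalize_return_type("integer")                     # -> IntegerType
normalize_return_type("struct<a: int, b: string>")   # -> StructType with fields a and b
```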
