Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2108,7 +2108,8 @@ def udf(f=None, returnType=StringType()):
can fail on special rows, the workaround is to incorporate the condition into the functions.

:param f: python function if used as a standalone function
:param returnType: a :class:`pyspark.sql.types.DataType` object
:param returnType: the return type of the registered user-defined function. The value can be
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems typo: the return type of the registered user-defined function. -> the return type of the user-defined function.?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, I'll fix it. Thanks!

either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.

>>> from pyspark.sql.types import IntegerType
>>> slen = udf(lambda s: len(s), IntegerType())
Expand Down Expand Up @@ -2148,7 +2149,8 @@ def pandas_udf(f=None, returnType=None, functionType=None):
Creates a vectorized user defined function (UDF).

:param f: user-defined function. A python function if used as a standalone function
:param returnType: a :class:`pyspark.sql.types.DataType` object
:param returnType: the return type of the registered user-defined function. The value can be
either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
:param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`.
Default: SCALAR.

Expand Down
14 changes: 12 additions & 2 deletions python/pyspark/sql/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ def register(self, name, f, returnType=None):
:param f: a Python function, or a user-defined function. The user-defined function can
be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and
:meth:`pyspark.sql.functions.pandas_udf`.
:param returnType: the return type of the registered user-defined function.
:param returnType: the return type of the registered user-defined function. The value can
be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
:return: a user-defined function.

`returnType` can be optionally specified when `f` is a Python function but not
Expand Down Expand Up @@ -303,21 +304,30 @@ def registerJavaFunction(self, name, javaClassName, returnType=None):

:param name: name of the user-defined function
:param javaClassName: fully qualified name of java class
:param returnType: a :class:`pyspark.sql.types.DataType` object
:param returnType: the return type of the registered Java function. The value can be either
a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.

>>> from pyspark.sql.types import IntegerType
>>> spark.udf.registerJavaFunction(
... "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, seems we need to fix :param returnType: across all other related APIs saying it takes DDL-formatted type string.

@ueshin, mind opening a minor PR for this - udf, pandas_udf, registerJavaFunction and register separately? If you are busy, will do it tonight. Doing this here is fine to me too, up to you.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll update them here soon.

>>> spark.sql("SELECT javaStringLength('test')").collect()
[Row(UDF:javaStringLength(test)=4)]

>>> spark.udf.registerJavaFunction(
... "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
>>> spark.sql("SELECT javaStringLength2('test')").collect()
[Row(UDF:javaStringLength2(test)=4)]

>>> spark.udf.registerJavaFunction(
... "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer")
>>> spark.sql("SELECT javaStringLength3('test')").collect()
[Row(UDF:javaStringLength3(test)=4)]
"""

jdt = None
if returnType is not None:
if not isinstance(returnType, DataType):
returnType = _parse_datatype_string(returnType)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The param doc needs to be modified too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, that's #20307 (comment) :).

jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)

Expand Down