Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/pyspark/sql/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,14 +310,22 @@ def registerJavaFunction(self, name, javaClassName, returnType=None):
... "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, seems we need to fix :param returnType: across all other related APIs saying it takes DDL-formatted type string.

@ueshin, mind opening a minor PR for this - udf, pandas_udf, registerJavaFunction and register separately? If you are busy, will do it tonight. Doing this here is fine to me too, up to you.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll update them here soon.

>>> spark.sql("SELECT javaStringLength('test')").collect()
[Row(UDF:javaStringLength(test)=4)]

>>> spark.udf.registerJavaFunction(
... "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
>>> spark.sql("SELECT javaStringLength2('test')").collect()
[Row(UDF:javaStringLength2(test)=4)]

>>> spark.udf.registerJavaFunction(
... "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer")
>>> spark.sql("SELECT javaStringLength3('test')").collect()
[Row(UDF:javaStringLength3(test)=4)]
"""

jdt = None
if returnType is not None:
if not isinstance(returnType, DataType):
returnType = _parse_datatype_string(returnType)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The param doc needs to be modified too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, that's #20307 (comment) :).

jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)

Expand Down