Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions python/pyspark/sql/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import tempfile
import unittest

import py4j

from pyspark import SparkContext
from pyspark.sql import SparkSession, Column, Row
from pyspark.sql.functions import UserDefinedFunction, udf
Expand Down Expand Up @@ -357,6 +359,32 @@ def test_udf_registration_returns_udf(self):
df.select(add_four("id").alias("plus_four")).collect()
)

@unittest.skipIf(not test_compiled, test_not_compiled_message)
def test_register_java_function(self):
    """Register the same Java UDF under three names and call each from SQL.

    The three registrations differ only in how the return type is given:
    as a ``DataType`` instance, omitted entirely, and as a type-name string.
    Every registered function is expected to return 4 for the input 'test'.

    Skipped when ``test_compiled`` is falsy.
    """
    java_class = "test.org.apache.spark.sql.JavaStringLength"

    # (registered name, SQL query, extra positional args for the return type)
    cases = [
        ("javaStringLength", "SELECT javaStringLength('test')", (IntegerType(),)),
        ("javaStringLength2", "SELECT javaStringLength2('test')", ()),
        ("javaStringLength3", "SELECT javaStringLength3('test')", ("integer",)),
    ]

    for udf_name, query, type_args in cases:
        self.spark.udf.registerJavaFunction(udf_name, java_class, *type_args)
        first_row = self.spark.sql(query).first()
        # The query selects a single column, so the row unpacks to one value.
        [length] = first_row
        self.assertEqual(length, 4)

@unittest.skipIf(not test_compiled, test_not_compiled_message)
def test_register_java_udaf(self):
    """Register a Java UDAF and check its first grouped result from SQL.

    A three-row DataFrame is exposed as the temp view ``df``, aggregated per
    ``name`` with the registered ``javaUDAF``, and ordered by ``name``
    descending; only the first resulting row is checked.

    Skipped when ``test_compiled`` is falsy.
    """
    spark = self.spark
    spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")

    records = [(1, "a"), (2, "b"), (3, "a")]
    spark.createDataFrame(records, ["id", "name"]).createOrReplaceTempView("df")

    query = "SELECT name, javaUDAF(id) as avg from df group by name order by name desc"
    first_row = spark.sql(query).first()

    # Rows are compared through asDict(); per the review discussion this is
    # meant to guard against issues such as SPARK-29748.
    expected = Row(name='b', avg=102.0)
    self.assertEqual(first_row.asDict(), expected.asDict())
Copy link
Member

@dongjoon-hyun dongjoon-hyun Jun 11, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, why don't we compare against [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could compare them as they are. This is just to prevent an issue such as SPARK-29748, or similar issues, in the future.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it~ No problem~


def test_non_existed_udf(self):
spark = self.spark
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
Expand Down
14 changes: 9 additions & 5 deletions python/pyspark/sql/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,17 +365,20 @@ def registerJavaFunction(self, name, javaClassName, returnType=None):
>>> from pyspark.sql.types import IntegerType
>>> spark.udf.registerJavaFunction(
... "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
>>> spark.sql("SELECT javaStringLength('test')").collect()
... # doctest: +SKIP
>>> spark.sql("SELECT javaStringLength('test')").collect() # doctest: +SKIP
[Row(javaStringLength(test)=4)]

>>> spark.udf.registerJavaFunction(
... "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
>>> spark.sql("SELECT javaStringLength2('test')").collect()
... # doctest: +SKIP
>>> spark.sql("SELECT javaStringLength2('test')").collect() # doctest: +SKIP
[Row(javaStringLength2(test)=4)]

>>> spark.udf.registerJavaFunction(
... "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer")
>>> spark.sql("SELECT javaStringLength3('test')").collect()
... # doctest: +SKIP
>>> spark.sql("SELECT javaStringLength3('test')").collect() # doctest: +SKIP
[Row(javaStringLength3(test)=4)]
"""

Expand All @@ -395,10 +398,11 @@ def registerJavaUDAF(self, name, javaClassName):
:param javaClassName: fully qualified name of java class

>>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
... # doctest: +SKIP
>>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
>>> df.createOrReplaceTempView("df")
>>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name order by name desc") \
.collect()
>>> q = "SELECT name, javaUDAF(id) as avg from df group by name order by name desc"
>>> spark.sql(q).collect() # doctest: +SKIP
[Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
"""

Expand Down