diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 2689b9c33d576..d673f7c15918f 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -21,6 +21,8 @@
 import tempfile
 import unittest
 
+import py4j
+
 from pyspark import SparkContext
 from pyspark.sql import SparkSession, Column, Row
 from pyspark.sql.functions import UserDefinedFunction, udf
@@ -357,6 +359,32 @@ def test_udf_registration_returns_udf(self):
             df.select(add_four("id").alias("plus_four")).collect()
         )
 
+    @unittest.skipIf(not test_compiled, test_not_compiled_message)
+    def test_register_java_function(self):
+        self.spark.udf.registerJavaFunction(
+            "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
+        [value] = self.spark.sql("SELECT javaStringLength('test')").first()
+        self.assertEqual(value, 4)
+
+        self.spark.udf.registerJavaFunction(
+            "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
+        [value] = self.spark.sql("SELECT javaStringLength2('test')").first()
+        self.assertEqual(value, 4)
+
+        self.spark.udf.registerJavaFunction(
+            "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer")
+        [value] = self.spark.sql("SELECT javaStringLength3('test')").first()
+        self.assertEqual(value, 4)
+
+    @unittest.skipIf(not test_compiled, test_not_compiled_message)
+    def test_register_java_udaf(self):
+        self.spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
+        df = self.spark.createDataFrame([(1, "a"), (2, "b"), (3, "a")], ["id", "name"])
+        df.createOrReplaceTempView("df")
+        row = self.spark.sql(
+            "SELECT name, javaUDAF(id) as avg from df group by name order by name desc").first()
+        self.assertEqual(row.asDict(), Row(name='b', avg=102.0).asDict())
+
     def test_non_existed_udf(self):
         spark = self.spark
         self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 10546ecacc57f..da68583b04e1c 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -365,17 +365,20 @@ def registerJavaFunction(self, name, javaClassName, returnType=None):
         >>> from pyspark.sql.types import IntegerType
         >>> spark.udf.registerJavaFunction(
         ...     "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
-        >>> spark.sql("SELECT javaStringLength('test')").collect()
+        ... # doctest: +SKIP
+        >>> spark.sql("SELECT javaStringLength('test')").collect()  # doctest: +SKIP
         [Row(javaStringLength(test)=4)]
 
         >>> spark.udf.registerJavaFunction(
         ...     "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
-        >>> spark.sql("SELECT javaStringLength2('test')").collect()
+        ... # doctest: +SKIP
+        >>> spark.sql("SELECT javaStringLength2('test')").collect()  # doctest: +SKIP
         [Row(javaStringLength2(test)=4)]
 
         >>> spark.udf.registerJavaFunction(
         ...     "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer")
-        >>> spark.sql("SELECT javaStringLength3('test')").collect()
+        ... # doctest: +SKIP
+        >>> spark.sql("SELECT javaStringLength3('test')").collect()  # doctest: +SKIP
         [Row(javaStringLength3(test)=4)]
         """
 
@@ -395,10 +398,11 @@ def registerJavaUDAF(self, name, javaClassName):
         :param javaClassName: fully qualified name of java class
 
         >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
+        ... # doctest: +SKIP
         >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
         >>> df.createOrReplaceTempView("df")
-        >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name order by name desc") \
-                .collect()
+        >>> q = "SELECT name, javaUDAF(id) as avg from df group by name order by name desc"
+        >>> spark.sql(q).collect()  # doctest: +SKIP
         [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
         """
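
For reference, a minimal sketch (not part of the patch) of how the Java UDF and UDAF exercised by these tests would be used from a live PySpark session. It assumes the Spark SQL test JAR providing test.org.apache.spark.sql.JavaStringLength and test.org.apache.spark.sql.MyDoubleAvg is on the classpath, which is exactly the condition the new test_compiled skip guard checks for:

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

# Assumes the Spark SQL test JAR with the two Java test classes is on the classpath;
# otherwise registration fails with AnalysisException ("Can not load class ...").
spark = SparkSession.builder.getOrCreate()

# Java UDF: the return type may be a DataType, a DDL type string, or omitted (inferred).
spark.udf.registerJavaFunction(
    "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
spark.sql("SELECT javaStringLength('test')").show()  # expected value: 4

# Java UDAF, used in an aggregate query over a temp view.
spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "a")], ["id", "name"])
df.createOrReplaceTempView("df")
spark.sql("SELECT name, javaUDAF(id) AS avg FROM df GROUP BY name ORDER BY name DESC").show()
```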