Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5642,8 +5642,9 @@ def test_register_grouped_map_udf(self):

foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
with QuietTest(self.sc):
with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or '
'SQL_SCALAR_PANDAS_UDF'):
with self.assertRaisesRegexp(
ValueError,
'f.*SQL_BATCHED_UDF.*SQL_SCALAR_PANDAS_UDF.*SQL_GROUPED_AGG_PANDAS_UDF.*'):
self.spark.catalog.registerFunction("foo_udf", foo_udf)

def test_decorator(self):
Expand Down Expand Up @@ -6459,6 +6460,21 @@ def test_invalid_args(self):
'mixture.*aggregate function.*group aggregate pandas UDF'):
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()

def test_register_vectorized_udf_basic(self):
from pyspark.sql.functions import pandas_udf
from pyspark.rdd import PythonEvalType

sum_pandas_udf = pandas_udf(
lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)

self.assertEqual(sum_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
group_agg_pandas_udf = self.spark.udf.register("sum_pandas_udf", sum_pandas_udf)
self.assertEqual(group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
q = "SELECT sum_pandas_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect()))
expected = [1, 5]
self.assertEqual(actual, expected)


@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
Expand Down
15 changes: 13 additions & 2 deletions python/pyspark/sql/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,15 @@ def register(self, name, f, returnType=None):
>>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP
[Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]

>>> @pandas_udf("integer", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
... def sum_udf(v):
... return v.sum()
...
>>> _ = spark.udf.register("sum_udf", sum_udf) # doctest: +SKIP
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the "_ =" thing here?

Copy link
Member Author

@HyukjinKwon HyukjinKwon Oct 3, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hides the output like ...

>>> spark.udf.register("sum_udf", sum_udf)
<function sum_udf at 0x103ff18c0>

in the doctest.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ha. I see..

>>> q = "SELECT sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
>>> spark.sql(q).collect() # doctest: +SKIP
[Row(sum_udf(v1)=1), Row(sum_udf(v1)=5)]

.. note:: Registration for a user-defined function (case 2.) was added from
Spark 2.3.0.
"""
Expand All @@ -310,9 +319,11 @@ def register(self, name, f, returnType=None):
"Invalid returnType: data type can not be specified when f is"
"a user-defined function, but got %s." % returnType)
if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_UDF]:
PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about SQL_WINDOW_AGG_PANDAS_UDF?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need it here:

Users specify GROUPED_AGG only. GROUPED_AGG is turned to WINDOW_AGG eval type in WindowInPandasExec.

Admittedly, there is a bit confusion here we can improve. We just haven't got a user specified udf type that maps to multiple evalType before WINDOW_AGG.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These need to be clearly defined in Apache Spark 3.0 release; otherwise, it might be confusing to both developers and end users. :-)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I opened https://issues.apache.org/jira/browse/SPARK-25640 to track this.

To be clear, this is transparent to end users, but I agree it can be confusing to developers.

raise ValueError(
"Invalid f: f must be either SQL_BATCHED_UDF or SQL_SCALAR_PANDAS_UDF")
"Invalid f: f must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF or "
"SQL_GROUPED_AGG_PANDAS_UDF")
register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
evalType=f.evalType,
deterministic=f.deterministic)
Expand Down