python/pyspark/sql/connect/udf.py (24 changes: 14 additions & 10 deletions)
@@ -25,13 +25,15 @@
 import functools
 import warnings
 from inspect import getfullargspec
-from typing import cast, Callable, Any, TYPE_CHECKING, Optional, Union
+from typing import cast, Callable, Any, List, TYPE_CHECKING, Optional, Union
 
 from pyspark.rdd import PythonEvalType
 from pyspark.sql.connect.expressions import (
     ColumnReference,
-    PythonUDF,
     CommonInlineUserDefinedFunction,
+    Expression,
+    NamedArgumentExpression,
+    PythonUDF,
 )
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.types import UnparsedDataType
@@ -155,12 +157,14 @@ def __init__(
         self.deterministic = deterministic
 
     def _build_common_inline_user_defined_function(
-        self, *cols: "ColumnOrName"
+        self, *args: "ColumnOrName", **kwargs: "ColumnOrName"
     ) -> CommonInlineUserDefinedFunction:
-        arg_cols = [
-            col if isinstance(col, Column) else Column(ColumnReference(col)) for col in cols
+        def to_expr(col: "ColumnOrName") -> Expression:
+            return col._expr if isinstance(col, Column) else ColumnReference(col)
+
+        arg_exprs: List[Expression] = [to_expr(arg) for arg in args] + [
+            NamedArgumentExpression(key, to_expr(value)) for key, value in kwargs.items()
         ]
-        arg_exprs = [col._expr for col in arg_cols]
 
         py_udf = PythonUDF(
             output_type=self.returnType,
@@ -175,8 +179,8 @@ def _build_common_inline_user_defined_function(
             arguments=arg_exprs,
         )
 
-    def __call__(self, *cols: "ColumnOrName") -> Column:
-        return Column(self._build_common_inline_user_defined_function(*cols))
+    def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column:
+        return Column(self._build_common_inline_user_defined_function(*args, **kwargs))
 
     # This function is for improving the online help system in the interactive interpreter.
     # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and
@@ -196,8 +200,8 @@ def _wrapped(self) -> "UserDefinedFunctionLike":
         )
 
         @functools.wraps(self.func, assigned=assignments)
-        def wrapper(*args: "ColumnOrName") -> Column:
-            return self(*args)
+        def wrapper(*args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column:
+            return self(*args, **kwargs)
 
         wrapper.__name__ = self._name
         wrapper.__module__ = (
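For reference, the argument mapping that the new _build_common_inline_user_defined_function performs can be read in isolation. A minimal sketch with hypothetical stand-in classes, not the real Spark Connect types (those live in pyspark.sql.connect.expressions); only the shape of the logic mirrors the diff above:

    # Stand-ins for illustration only.
    class Expression:
        pass

    class ColumnReference(Expression):
        def __init__(self, name: str):
            self.name = name

    class NamedArgumentExpression(Expression):
        def __init__(self, key: str, value: Expression):
            self.key, self.value = key, value

    def build_arguments(*args, **kwargs):
        # Positional arguments become plain expressions; keyword arguments are
        # wrapped with their parameter name so the server can bind them the way
        # SQL named-argument syntax (name => expr) does.
        def to_expr(col):
            return col if isinstance(col, Expression) else ColumnReference(col)

        return [to_expr(a) for a in args] + [
            NamedArgumentExpression(k, to_expr(v)) for k, v in kwargs.items()
        ]

    exprs = build_arguments("id", b="other")
    # -> [ColumnReference("id"), NamedArgumentExpression("b", ColumnReference("other"))]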
17 changes: 17 additions & 0 deletions python/pyspark/sql/functions.py
@@ -15375,6 +15375,9 @@ def udf(
     .. versionchanged:: 3.4.0
         Supports Spark Connect.
 
+    .. versionchanged:: 4.0.0
+        Supports keyword arguments.
+
     Parameters
     ----------
     f : function
@@ -15408,6 +15411,20 @@
     |         8|      JOHN DOE|          22|
     +----------+--------------+------------+
 
+    UDFs can use keyword arguments:
+
+    >>> @udf(returnType=IntegerType())
+    ... def calc(a, b):
+    ...     return a + 10 * b
+    ...
+    >>> spark.range(2).select(calc(b=col("id") * 10, a=col("id"))).show()
+    +-----------------------------+
+    |calc(b => (id * 10), a => id)|
+    +-----------------------------+
+    |                            0|
+    |                          101|
+    +-----------------------------+
+
     Notes
     -----
     The user-defined functions are considered deterministic by default. Due to
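The new doctest has a direct SQL counterpart through named-argument syntax once the UDF is registered. A minimal sketch, assuming an active SparkSession bound to spark (the same pattern the test suites below exercise):

    from pyspark.sql.functions import col, udf

    @udf("int")
    def calc(a, b):
        return a + 10 * b

    # DataFrame API: plain Python keyword arguments.
    spark.range(2).select(calc(b=col("id") * 10, a=col("id"))).show()

    # SQL: register the UDF, then pass named arguments with `name => expr`.
    spark.udf.register("calc", calc)
    spark.sql("SELECT calc(b => id * 10, a => id) FROM range(2)").show()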
python/pyspark/sql/pandas/functions.py (17 changes: 17 additions & 0 deletions)
@@ -56,6 +56,9 @@ def pandas_udf(f=None, returnType=None, functionType=None):
     .. versionchanged:: 3.4.0
         Supports Spark Connect.
 
+    .. versionchanged:: 4.0.0
+        Supports keyword arguments in SCALAR type.
+
     Parameters
     ----------
     f : function, optional
@@ -153,6 +156,20 @@ def pandas_udf(f=None, returnType=None, functionType=None):
     |       [John, Doe]|
     +------------------+
 
+    This type of Pandas UDF can use keyword arguments:
+
+    >>> @pandas_udf(returnType=IntegerType())
+    ... def calc(a: pd.Series, b: pd.Series) -> pd.Series:
+    ...     return a + 10 * b
+    ...
+    >>> spark.range(2).select(calc(b=col("id") * 10, a=col("id"))).show()
+    +-----------------------------+
+    |calc(b => (id * 10), a => id)|
+    +-----------------------------+
+    |                            0|
+    |                          101|
+    +-----------------------------+
+
     .. note:: The length of the input is not that of the whole input column, but is the
         length of an internal batch used for each call to the function.
 
python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py (94 changes: 93 additions & 1 deletion)
@@ -50,7 +50,7 @@
     BinaryType,
     YearMonthIntervalType,
 )
-from pyspark.errors import AnalysisException
+from pyspark.errors import AnalysisException, PythonException
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     test_compiled,
@@ -1467,6 +1467,98 @@ def udf(x):
             finally:
                 shutil.rmtree(path)
 
+    def test_named_arguments(self):
+        @pandas_udf("int")
+        def test_udf(a, b):
+            return a + 10 * b
+
+        self.spark.udf.register("test_udf", test_udf)
+
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(col("id"), b=col("id") * 10)),
+                self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)),
+                self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))),
+                self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                assertDataFrameEqual(df, [Row(0), Row(101)])
+
+    def test_named_arguments_negative(self):
+        @pandas_udf("int")
+        def test_udf(a, b):
+            return a + b
+
+        self.spark.udf.register("test_udf", test_udf)
+
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+        ):
+            self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
+
+        with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+            self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
+
+        with self.assertRaisesRegex(
+            PythonException, r"test_udf\(\) got an unexpected keyword argument 'c'"
+        ):
+            self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
+
+    def test_kwargs(self):
+        @pandas_udf("int")
+        def test_udf(a, **kwargs):
+            return a + 10 * kwargs["b"]
+
+        self.spark.udf.register("test_udf", test_udf)
+
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)),
+                self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))),
+                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                assertDataFrameEqual(df, [Row(0), Row(101)])
+
+    def test_named_arguments_and_defaults(self):
+        @pandas_udf("int")
+        def test_udf(a, b=0):
+            return a + 10 * b
+
+        self.spark.udf.register("test_udf", test_udf)
+
+        # without "b"
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(col("id"))),
+                self.spark.range(2).select(test_udf(a=col("id"))),
+                self.spark.sql("SELECT test_udf(id) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
+            ]
+        ):
+            with self.subTest(with_b=False, query_no=i):
+                assertDataFrameEqual(df, [Row(0), Row(1)])
+
+        # with "b"
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(col("id"), b=col("id") * 10)),
+                self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)),
+                self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))),
+                self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+            ]
+        ):
+            with self.subTest(with_b=True, query_no=i):
+                assertDataFrameEqual(df, [Row(0), Row(101)])
+
 
 class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
python/pyspark/sql/tests/test_udf.py (98 changes: 95 additions & 3 deletions)
@@ -24,7 +24,7 @@
 
 from pyspark import SparkContext, SQLContext
 from pyspark.sql import SparkSession, Column, Row
-from pyspark.sql.functions import udf, assert_true, lit, rand
+from pyspark.sql.functions import col, udf, assert_true, lit, rand
 from pyspark.sql.udf import UserDefinedFunction
 from pyspark.sql.types import (
     StringType,
@@ -38,9 +38,9 @@
     TimestampNTZType,
     DayTimeIntervalType,
 )
-from pyspark.errors import AnalysisException, PySparkTypeError
+from pyspark.errors import AnalysisException, PythonException, PySparkTypeError
 from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual
 
 
 class BaseUDFTestsMixin(object):
@@ -898,6 +898,98 @@ def test_complex_return_types(self):
         self.assertEquals(row[1], {"a": "b"})
         self.assertEquals(row[2], Row(col1=1, col2=2))
 
+    def test_named_arguments(self):
+        @udf("int")
+        def test_udf(a, b):
+            return a + 10 * b
+
+        self.spark.udf.register("test_udf", test_udf)
+
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(col("id"), b=col("id") * 10)),
+                self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)),
+                self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))),
+                self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                assertDataFrameEqual(df, [Row(0), Row(101)])
+
+    def test_named_arguments_negative(self):
+        @udf("int")
+        def test_udf(a, b):
+            return a + b
+
+        self.spark.udf.register("test_udf", test_udf)
+
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+        ):
+            self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
+
+        with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+            self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
+
+        with self.assertRaisesRegex(
+            PythonException, r"test_udf\(\) got an unexpected keyword argument 'c'"
+        ):
+            self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
+
+    def test_kwargs(self):
+        @udf("int")
+        def test_udf(**kwargs):
+            return kwargs["a"] + 10 * kwargs["b"]
+
+        self.spark.udf.register("test_udf", test_udf)
+
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)),
+                self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))),
+                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                assertDataFrameEqual(df, [Row(0), Row(101)])
+
+    def test_named_arguments_and_defaults(self):
+        @udf("int")
+        def test_udf(a, b=0):
+            return a + 10 * b
+
+        self.spark.udf.register("test_udf", test_udf)
+
+        # without "b"
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(col("id"))),
+                self.spark.range(2).select(test_udf(a=col("id"))),
+                self.spark.sql("SELECT test_udf(id) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
+            ]
+        ):
+            with self.subTest(with_b=False, query_no=i):
+                assertDataFrameEqual(df, [Row(0), Row(1)])
+
+        # with "b"
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(col("id"), b=col("id") * 10)),
+                self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)),
+                self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))),
+                self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+            ]
+        ):
+            with self.subTest(with_b=True, query_no=i):
+                assertDataFrameEqual(df, [Row(0), Row(101)])
+
 
 class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
python/pyspark/sql/tests/test_udtf.py (4 changes: 2 additions & 2 deletions)
@@ -1947,7 +1947,7 @@ def eval(self, a, b=100):
                 TestUDTF(a=lit(10)),
             ]
         ):
-            with self.subTest(query_no=i):
+            with self.subTest(with_b=False, query_no=i):
                 assertDataFrameEqual(df, [Row(a=10, b=100)])
 
         # with "b"
@@ -1961,7 +1961,7 @@ def eval(self, a, b=100):
                 TestUDTF(b=lit("z"), a=lit(10)),
             ]
         ):
-            with self.subTest(query_no=i):
+            with self.subTest(with_b=True, query_no=i):
                 assertDataFrameEqual(df, [Row(a=10, b="z")])
 
     def test_udtf_with_table_argument_and_partition_by(self):