Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion python/pyspark/sql/pandas/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
Supports Spark Connect.

.. versionchanged:: 4.0.0
Supports keyword-arguments in SCALAR type.
Supports keyword-arguments in SCALAR and GROUPED_AGG type.

Parameters
----------
Expand Down Expand Up @@ -267,6 +267,24 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
| 2| 6.0|
+---+-----------+

This type of Pandas UDF can use keyword arguments:

>>> @pandas_udf("double")
... def weighted_mean_udf(v: pd.Series, w: pd.Series) -> float:
... import numpy as np
... return np.average(v, weights=w)
...
>>> df = spark.createDataFrame(
... [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
... ("id", "v", "w"))
>>> df.groupby("id").agg(weighted_mean_udf(w=df["w"], v=df["v"])).show()
+---+---------------------------------+
| id|weighted_mean_udf(w => w, v => v)|
+---+---------------------------------+
| 1| 1.6666666666666667|
| 2| 7.166666666666667|
+---+---------------------------------+

This UDF can also be used as window functions as below:

>>> from pyspark.sql import Window
Expand Down
147 changes: 145 additions & 2 deletions python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@
PandasUDFType,
)
from pyspark.sql.types import ArrayType, YearMonthIntervalType
from pyspark.errors import AnalysisException, PySparkNotImplementedError
from pyspark.errors import AnalysisException, PySparkNotImplementedError, PythonException
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pandas,
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
)
from pyspark.testing.utils import QuietTest
from pyspark.testing.utils import QuietTest, assertDataFrameEqual


if have_pandas:
Expand Down Expand Up @@ -575,6 +575,149 @@ def mean(x):

assert filtered.collect()[0]["mean"] == 42.0

# Checks that the grouped-aggregate weighted-mean pandas UDF produces the same
# result whether its arguments are passed positionally or by name, through
# both the DataFrame API and SQL.
def test_named_arguments(self):
df = self.data
weighted_mean = self.pandas_agg_weighted_mean_udf

# Expose the data as temp view "v" and register the UDF under the SQL name
# "weighted_mean" so the SQL variants below can reference them.
with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)

# Each query runs as its own subTest, identified by its index (query_no).
for i, aggregated in enumerate(
[
# DataFrame API: positional + named, all named, named in reverse order.
df.groupby("id").agg(weighted_mean(df.v, w=df.w).alias("wm")),
df.groupby("id").agg(weighted_mean(v=df.v, w=df.w).alias("wm")),
df.groupby("id").agg(weighted_mean(w=df.w, v=df.v).alias("wm")),
# SQL: the same three call shapes using the `name => value` syntax.
self.spark.sql("SELECT id, weighted_mean(v, w => w) as wm FROM v GROUP BY id"),
self.spark.sql(
"SELECT id, weighted_mean(v => v, w => w) as wm FROM v GROUP BY id"
),
self.spark.sql(
"SELECT id, weighted_mean(w => w, v => v) as wm FROM v GROUP BY id"
),
]
):
with self.subTest(query_no=i):
# Every variant is compared against the plain mean of v.
# NOTE(review): presumably the "w" column of self.data is constant, so the
# weighted mean equals the unweighted mean - confirm against the fixture.
assertDataFrameEqual(aggregated, df.groupby("id").agg(mean(df.v).alias("wm")))

def test_named_arguments_negative(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please also port these negative tests to the 'kwargs' Pandas UDF case? It would be good to make sure we do the same checks there. The same applies to the window function tests.

df = self.data
weighted_mean = self.pandas_agg_weighted_mean_udf

with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)

with self.assertRaisesRegex(
AnalysisException,
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
):
self.spark.sql(
"SELECT id, weighted_mean(v => v, v => w) as wm FROM v GROUP BY id"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this test! In addition, could we add a negative test case where we provide a positional argument matching the first parameter of the function, followed by a named argument with the same name as that first parameter? It should raise an error in that case. The same applies to the window function tests.

).show()

with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
self.spark.sql(
"SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY id"
).show()

with self.assertRaisesRegex(
PythonException, r"weighted_mean\(\) got an unexpected keyword argument 'x'"
):
self.spark.sql(
"SELECT id, weighted_mean(v => v, x => w) as wm FROM v GROUP BY id"
).show()

with self.assertRaisesRegex(
PythonException, r"weighted_mean\(\) got multiple values for argument 'v'"
):
self.spark.sql(
"SELECT id, weighted_mean(v, v => w) as wm FROM v GROUP BY id"
).show()

# Checks a GROUPED_AGG pandas UDF whose only parameter is **kwargs: named
# arguments must reach it by name from both the DataFrame API and SQL, and
# malformed SQL argument lists must be rejected.
def test_kwargs(self):
df = self.data

# The UDF takes no positional parameters; it reads "v" and "w" from kwargs.
@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def weighted_mean(**kwargs):
import numpy as np

return np.average(kwargs["v"], weights=kwargs["w"])

# Expose the data as temp view "v" and register the UDF under the SQL name
# "weighted_mean" so the SQL queries below can reference them.
with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)

# Positive cases: both argument orders, via the DataFrame API and via SQL.
for i, aggregated in enumerate(
[
df.groupby("id").agg(weighted_mean(v=df.v, w=df.w).alias("wm")),
df.groupby("id").agg(weighted_mean(w=df.w, v=df.v).alias("wm")),
self.spark.sql(
"SELECT id, weighted_mean(v => v, w => w) as wm FROM v GROUP BY id"
),
self.spark.sql(
"SELECT id, weighted_mean(w => w, v => v) as wm FROM v GROUP BY id"
),
]
):
with self.subTest(query_no=i):
# Every variant is compared against the plain mean of v.
# NOTE(review): presumably the "w" column of self.data is constant, so the
# weighted mean equals the unweighted mean - confirm against the fixture.
assertDataFrameEqual(aggregated, df.groupby("id").agg(mean(df.v).alias("wm")))

# negative: the same named argument given twice is expected to raise
# AnalysisException with the DOUBLE_NAMED_ARGUMENT_REFERENCE error class.
with self.assertRaisesRegex(
AnalysisException,
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
):
self.spark.sql(
"SELECT id, weighted_mean(v => v, v => w) as wm FROM v GROUP BY id"
).show()

# negative: a positional argument placed after a named one is expected to
# raise AnalysisException with UNEXPECTED_POSITIONAL_ARGUMENT.
with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
self.spark.sql(
"SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY id"
).show()

# Checks that a GROUPED_AGG pandas UDF with a defaulted parameter works when
# that parameter is omitted and when it is supplied, positionally or by name,
# through both the DataFrame API and SQL.
def test_named_arguments_and_defaults(self):
df = self.data

# Returns v.sum() plus w.sum() when w is given, or plus 100 when w is left
# at its default of None.
@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def biased_sum(v, w=None):
return v.sum() + (w.sum() if w is not None else 100)

# Expose the data as temp view "v" and register the UDF under the SQL name
# "biased_sum" so the SQL queries below can reference them.
with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("biased_sum", biased_sum)

# without "w": the default applies, so the expected value is sum(v) + 100
for i, aggregated in enumerate(
[
df.groupby("id").agg(biased_sum(df.v).alias("s")),
df.groupby("id").agg(biased_sum(v=df.v).alias("s")),
self.spark.sql("SELECT id, biased_sum(v) as s FROM v GROUP BY id"),
self.spark.sql("SELECT id, biased_sum(v => v) as s FROM v GROUP BY id"),
]
):
with self.subTest(with_w=False, query_no=i):
# NOTE(review): `sum` is applied to a Column here, so it is presumably the
# Spark SQL function imported elsewhere in this file, not the builtin - confirm.
assertDataFrameEqual(
aggregated, df.groupby("id").agg((sum(df.v) + lit(100)).alias("s"))
)

# with "w": the expected value is sum(v) + sum(w), regardless of argument order
for i, aggregated in enumerate(
[
df.groupby("id").agg(biased_sum(df.v, w=df.w).alias("s")),
df.groupby("id").agg(biased_sum(v=df.v, w=df.w).alias("s")),
df.groupby("id").agg(biased_sum(w=df.w, v=df.v).alias("s")),
self.spark.sql("SELECT id, biased_sum(v, w => w) as s FROM v GROUP BY id"),
self.spark.sql("SELECT id, biased_sum(v => v, w => w) as s FROM v GROUP BY id"),
self.spark.sql("SELECT id, biased_sum(w => w, v => v) as s FROM v GROUP BY id"),
]
):
with self.subTest(with_w=True, query_no=i):
assertDataFrameEqual(
aggregated, df.groupby("id").agg((sum(df.v) + sum(df.w)).alias("s"))
)


# Concrete test case: combines GroupedAggPandasUDFTestsMixin with
# ReusedSQLTestCase and adds no members of its own.
class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase):
pass
Expand Down
Loading