python/pyspark/sql/pandas/functions.py (19 additions, 1 deletion)
@@ -57,7 +57,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
Supports Spark Connect.

.. versionchanged:: 4.0.0
-        Supports keyword-arguments in SCALAR type.
+        Supports keyword-arguments in SCALAR and GROUPED_AGG type.

Parameters
----------
@@ -267,6 +267,24 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
| 2| 6.0|
+---+-----------+

This type of Pandas UDF can use keyword arguments:

>>> @pandas_udf("double")
... def weighted_mean_udf(v: pd.Series, w: pd.Series) -> float:
... import numpy as np
... return np.average(v, weights=w)
...
>>> df = spark.createDataFrame(
... [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
... ("id", "v", "w"))
>>> df.groupby("id").agg(weighted_mean_udf(w=df["w"], v=df["v"])).show()
+---+---------------------------------+
| id|weighted_mean_udf(w => w, v => v)|
+---+---------------------------------+
| 1| 1.6666666666666667|
| 2| 7.166666666666667|
+---+---------------------------------+

This UDF can also be used as a window function, as shown below:

>>> from pyspark.sql import Window
python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py (124 additions, 2 deletions)
@@ -32,15 +32,15 @@
PandasUDFType,
)
from pyspark.sql.types import ArrayType, YearMonthIntervalType
-from pyspark.errors import AnalysisException, PySparkNotImplementedError
+from pyspark.errors import AnalysisException, PySparkNotImplementedError, PythonException
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pandas,
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
)
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual


if have_pandas:
@@ -575,6 +575,128 @@ def mean(x):

assert filtered.collect()[0]["mean"] == 42.0

def test_named_arguments(self):
df = self.data
weighted_mean = self.pandas_agg_weighted_mean_udf

with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)

for i, aggregated in enumerate(
[
df.groupby("id").agg(weighted_mean(df.v, w=df.w).alias("wm")),
df.groupby("id").agg(weighted_mean(v=df.v, w=df.w).alias("wm")),
df.groupby("id").agg(weighted_mean(w=df.w, v=df.v).alias("wm")),
self.spark.sql("SELECT id, weighted_mean(v, w => w) as wm FROM v GROUP BY id"),
self.spark.sql(
"SELECT id, weighted_mean(v => v, w => w) as wm FROM v GROUP BY id"
),
self.spark.sql(
"SELECT id, weighted_mean(w => w, v => v) as wm FROM v GROUP BY id"
),
]
):
with self.subTest(query_no=i):
assertDataFrameEqual(aggregated, df.groupby("id").agg(mean(df.v).alias("wm")))

def test_named_arguments_negative(self):
Review comment (Contributor):
Can you please also port these negative tests to the 'kwargs' Pandas UDF case? It would be good to make sure we run the same checks there. Same for the window function tests. (A sketch of such a port is added after test_kwargs below.)

df = self.data
weighted_mean = self.pandas_agg_weighted_mean_udf

with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)

with self.assertRaisesRegex(
AnalysisException,
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
):
self.spark.sql(
"SELECT id, weighted_mean(v => v, v => w) as wm FROM v GROUP BY id"
Review comment (Contributor):
Thanks for this test! In addition, could we add a negative case that passes a positional argument for the first parameter of the function and then a named argument with that same parameter's name? That should return an error. Same for the window function tests. (Sketched at the end of this test.)

).show()

with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
self.spark.sql(
"SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY id"
).show()

with self.assertRaisesRegex(
PythonException, r"weighted_mean\(\) got an unexpected keyword argument 'x'"
):
self.spark.sql(
"SELECT id, weighted_mean(v => v, x => w) as wm FROM v GROUP BY id"
).show()
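
            # Extra negative case requested in review: a positional argument and
            # a named argument that assign the same parameter. A minimal sketch;
            # the BOTH_POSITIONAL_AND_NAMED sub-class name is an assumption
            # modeled on the error class already checked above.
            with self.assertRaisesRegex(
                AnalysisException,
                "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.BOTH_POSITIONAL_AND_NAMED",
            ):
                self.spark.sql(
                    "SELECT id, weighted_mean(v, v => w) as wm FROM v GROUP BY id"
                ).show()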

def test_kwargs(self):
df = self.data

@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def weighted_mean(**kwargs):
import numpy as np

return np.average(kwargs["v"], weights=kwargs["w"])

with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)

for i, aggregated in enumerate(
[
df.groupby("id").agg(weighted_mean(v=df.v, w=df.w).alias("wm")),
df.groupby("id").agg(weighted_mean(w=df.w, v=df.v).alias("wm")),
self.spark.sql(
"SELECT id, weighted_mean(v => v, w => w) as wm FROM v GROUP BY id"
),
self.spark.sql(
"SELECT id, weighted_mean(w => w, v => v) as wm FROM v GROUP BY id"
),
]
):
with self.subTest(query_no=i):
assertDataFrameEqual(aggregated, df.groupby("id").agg(mean(df.v).alias("wm")))
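
    def test_kwargs_negative(self):
        # Sketch porting the analyzer-side negative checks to the **kwargs UDF,
        # as requested in review. Only the analysis-time errors are exercised
        # here, since a **kwargs function accepts any keyword name in Python.
        df = self.data

        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
        def weighted_mean(**kwargs):
            import numpy as np

            return np.average(kwargs["v"], weights=kwargs["w"])

        with self.tempView("v"):
            df.createOrReplaceTempView("v")
            self.spark.udf.register("weighted_mean", weighted_mean)

            with self.assertRaisesRegex(
                AnalysisException,
                "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
            ):
                self.spark.sql(
                    "SELECT id, weighted_mean(v => v, v => w) as wm FROM v GROUP BY id"
                ).show()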

def test_named_arguments_and_defaults(self):
df = self.data

@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def biased_sum(v, w=None):
return v.sum() + (w.sum() if w is not None else 100)

with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("biased_sum", biased_sum)

# without "w"
for i, aggregated in enumerate(
[
df.groupby("id").agg(biased_sum(df.v).alias("s")),
df.groupby("id").agg(biased_sum(v=df.v).alias("s")),
self.spark.sql("SELECT id, biased_sum(v) as s FROM v GROUP BY id"),
self.spark.sql("SELECT id, biased_sum(v => v) as s FROM v GROUP BY id"),
]
):
with self.subTest(with_w=False, query_no=i):
assertDataFrameEqual(
aggregated, df.groupby("id").agg((sum(df.v) + lit(100)).alias("s"))
)

# with "w"
for i, aggregated in enumerate(
[
df.groupby("id").agg(biased_sum(df.v, w=df.w).alias("s")),
df.groupby("id").agg(biased_sum(v=df.v, w=df.w).alias("s")),
df.groupby("id").agg(biased_sum(w=df.w, v=df.v).alias("s")),
self.spark.sql("SELECT id, biased_sum(v, w => w) as s FROM v GROUP BY id"),
self.spark.sql("SELECT id, biased_sum(v => v, w => w) as s FROM v GROUP BY id"),
self.spark.sql("SELECT id, biased_sum(w => w, v => v) as s FROM v GROUP BY id"),
]
):
with self.subTest(with_w=True, query_no=i):
assertDataFrameEqual(
aggregated, df.groupby("id").agg((sum(df.v) + sum(df.w)).alias("s"))
)


class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase):
pass
python/pyspark/sql/tests/pandas/test_pandas_udf_window.py (137 additions, 2 deletions)
@@ -18,7 +18,7 @@
import unittest
from typing import cast

-from pyspark.errors import AnalysisException
+from pyspark.errors import AnalysisException, PythonException
from pyspark.sql.functions import (
array,
explode,
@@ -40,7 +40,7 @@
pandas_requirement_message,
pyarrow_requirement_message,
)
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual

if have_pandas:
from pandas.testing import assert_frame_equal
@@ -107,6 +107,16 @@ def min(v):

return min

@property
def pandas_agg_weighted_mean_udf(self):
import numpy as np

@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def weighted_mean(v, w):
return np.average(v, weights=w)

return weighted_mean

@property
def unbounded_window(self):
return (
@@ -394,6 +404,131 @@ def test_bounded_mixed(self):

assert_frame_equal(expected1.toPandas(), result1.toPandas())

def test_named_arguments(self):
df = self.data
weighted_mean = self.pandas_agg_weighted_mean_udf

for w, bound in [(self.sliding_row_window, True), (self.unbounded_window, False)]:
for i, windowed in enumerate(
[
df.withColumn("wm", weighted_mean(df.v, w=df.w).over(w)),
df.withColumn("wm", weighted_mean(v=df.v, w=df.w).over(w)),
df.withColumn("wm", weighted_mean(w=df.w, v=df.v).over(w)),
]
):
with self.subTest(bound=bound, query_no=i):
assertDataFrameEqual(windowed, df.withColumn("wm", mean(df.v).over(w)))

with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)

for w in [
"ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING",
"ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING",
]:
window_spec = f"PARTITION BY id ORDER BY v {w}"
for i, func_call in enumerate(
[
"weighted_mean(v, w => w)",
"weighted_mean(v => v, w => w)",
"weighted_mean(w => w, v => v)",
]
):
with self.subTest(window_spec=window_spec, query_no=i):
assertDataFrameEqual(
self.spark.sql(
f"SELECT id, {func_call} OVER ({window_spec}) as wm FROM v"
),
self.spark.sql(f"SELECT id, mean(v) OVER ({window_spec}) as wm FROM v"),
)

def test_named_arguments_negative(self):
df = self.data
weighted_mean = self.pandas_agg_weighted_mean_udf

with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)

base_sql = "SELECT id, {func_call} OVER ({window_spec}) as wm FROM v"

for w in [
"ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING",
"ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING",
]:
window_spec = f"PARTITION BY id ORDER BY v {w}"
with self.subTest(window_spec=window_spec):
with self.assertRaisesRegex(
AnalysisException,
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
):
self.spark.sql(
base_sql.format(
func_call="weighted_mean(v => v, v => w)", window_spec=window_spec
)
).show()

with self.assertRaisesRegex(
AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"
):
self.spark.sql(
base_sql.format(
func_call="weighted_mean(v => v, w)", window_spec=window_spec
)
).show()

with self.assertRaisesRegex(
PythonException, r"weighted_mean\(\) got an unexpected keyword argument 'x'"
):
self.spark.sql(
base_sql.format(
func_call="weighted_mean(v => v, x => w)", window_spec=window_spec
)
).show()
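
                # Extra negative case requested in review: a positional argument
                # plus a named argument assigning the same parameter. A minimal
                # sketch; the BOTH_POSITIONAL_AND_NAMED sub-class name is an
                # assumption modeled on the error class already checked above.
                with self.assertRaisesRegex(
                    AnalysisException,
                    "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.BOTH_POSITIONAL_AND_NAMED",
                ):
                    self.spark.sql(
                        base_sql.format(
                            func_call="weighted_mean(v, v => w)", window_spec=window_spec
                        )
                    ).show()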

def test_kwargs(self):
df = self.data

@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def weighted_mean(**kwargs):
import numpy as np

return np.average(kwargs["v"], weights=kwargs["w"])

for w, bound in [(self.sliding_row_window, True), (self.unbounded_window, False)]:
for i, windowed in enumerate(
[
df.withColumn("wm", weighted_mean(v=df.v, w=df.w).over(w)),
df.withColumn("wm", weighted_mean(w=df.w, v=df.v).over(w)),
]
):
with self.subTest(bound=bound, query_no=i):
assertDataFrameEqual(windowed, df.withColumn("wm", mean(df.v).over(w)))

with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)

for w in [
"ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING",
"ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING",
]:
window_spec = f"PARTITION BY id ORDER BY v {w}"
for i, func_call in enumerate(
[
"weighted_mean(v => v, w => w)",
"weighted_mean(w => w, v => v)",
]
):
with self.subTest(window_spec=window_spec, query_no=i):
assertDataFrameEqual(
self.spark.sql(
f"SELECT id, {func_call} OVER ({window_spec}) as wm FROM v"
),
self.spark.sql(f"SELECT id, mean(v) OVER ({window_spec}) as wm FROM v"),
)
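
    def test_kwargs_negative(self):
        # Sketch porting the analyzer-side negative check to the **kwargs UDF
        # over a window, as requested in review; the duplicate named argument
        # should fail at analysis time regardless of the Python signature.
        df = self.data

        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
        def weighted_mean(**kwargs):
            import numpy as np

            return np.average(kwargs["v"], weights=kwargs["w"])

        with self.tempView("v"):
            df.createOrReplaceTempView("v")
            self.spark.udf.register("weighted_mean", weighted_mean)

            window_spec = (
                "PARTITION BY id ORDER BY v ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING"
            )
            with self.assertRaisesRegex(
                AnalysisException,
                "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
            ):
                self.spark.sql(
                    "SELECT id, weighted_mean(v => v, v => w) "
                    f"OVER ({window_spec}) as wm FROM v"
                ).show()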


class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase):
pass