[SPARK-44952][SQL][PYTHON] Support named arguments in aggregate Pandas UDFs #42663
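In short, this change lets grouped-aggregate Pandas UDFs accept named arguments, both as keyword arguments in the DataFrame API and via the SQL `=>` syntax. A minimal sketch distilled from the tests below (assuming a running `spark` session and a `df` with `id`, `v`, and `w` columns, as in the test fixtures):

```python
import numpy as np
from pyspark.sql.functions import pandas_udf, PandasUDFType

# A grouped-aggregate Pandas UDF with two parameters.
@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def weighted_mean(v, w):
    return np.average(v, weights=w)

# DataFrame API: arguments may now be passed by keyword, in any order.
df.groupby("id").agg(weighted_mean(w=df.w, v=df.v).alias("wm"))

# SQL: the same UDF accepts named arguments via `=>`.
df.createOrReplaceTempView("v")
spark.udf.register("weighted_mean", weighted_mean)
spark.sql("SELECT id, weighted_mean(w => w, v => v) AS wm FROM v GROUP BY id")
```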
Changes from all commits
```diff
@@ -32,15 +32,15 @@
     PandasUDFType,
 )
 from pyspark.sql.types import ArrayType, YearMonthIntervalType
-from pyspark.errors import AnalysisException, PySparkNotImplementedError
+from pyspark.errors import AnalysisException, PySparkNotImplementedError, PythonException
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     have_pandas,
     have_pyarrow,
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual

 if have_pandas:
```
```diff
@@ -575,6 +575,149 @@ def mean(x):
 
         assert filtered.collect()[0]["mean"] == 42.0
 
+    def test_named_arguments(self):
+        df = self.data
+        weighted_mean = self.pandas_agg_weighted_mean_udf
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("weighted_mean", weighted_mean)
+
+            for i, aggregated in enumerate(
+                [
+                    df.groupby("id").agg(weighted_mean(df.v, w=df.w).alias("wm")),
+                    df.groupby("id").agg(weighted_mean(v=df.v, w=df.w).alias("wm")),
+                    df.groupby("id").agg(weighted_mean(w=df.w, v=df.v).alias("wm")),
+                    self.spark.sql("SELECT id, weighted_mean(v, w => w) as wm FROM v GROUP BY id"),
+                    self.spark.sql(
+                        "SELECT id, weighted_mean(v => v, w => w) as wm FROM v GROUP BY id"
+                    ),
+                    self.spark.sql(
+                        "SELECT id, weighted_mean(w => w, v => v) as wm FROM v GROUP BY id"
+                    ),
+                ]
+            ):
+                with self.subTest(query_no=i):
+                    assertDataFrameEqual(aggregated, df.groupby("id").agg(mean(df.v).alias("wm")))
+
+    def test_named_arguments_negative(self):
+        df = self.data
+        weighted_mean = self.pandas_agg_weighted_mean_udf
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("weighted_mean", weighted_mean)
+
+            with self.assertRaisesRegex(
+                AnalysisException,
+                "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+            ):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v => v, v => w) as wm FROM v GROUP BY id"
+                ).show()
```

Contributor: Thanks for this test! In addition, could we add a negative test case where we provide a positional argument matching the first parameter of the function, then a named argument with the same name as that first parameter? It should return an error in that case. Same for the window functions testing.
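For reference, the requested check appears further down in this diff as the `got multiple values for argument 'v'` assertion; isolated, it looks like the following sketch (same `weighted_mean` UDF and `v` temp view as above):

```python
# A positional `v` followed by the named argument `v` must be rejected
# when the UDF is invoked.
with self.assertRaisesRegex(
    PythonException, r"weighted_mean\(\) got multiple values for argument 'v'"
):
    self.spark.sql(
        "SELECT id, weighted_mean(v, v => w) as wm FROM v GROUP BY id"
    ).show()
```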
```diff
+            with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY id"
+                ).show()
+
+            with self.assertRaisesRegex(
+                PythonException, r"weighted_mean\(\) got an unexpected keyword argument 'x'"
+            ):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v => v, x => w) as wm FROM v GROUP BY id"
+                ).show()
+
+            with self.assertRaisesRegex(
+                PythonException, r"weighted_mean\(\) got multiple values for argument 'v'"
+            ):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v, v => w) as wm FROM v GROUP BY id"
+                ).show()
+
+    def test_kwargs(self):
+        df = self.data
+
+        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+        def weighted_mean(**kwargs):
+            import numpy as np
+
+            return np.average(kwargs["v"], weights=kwargs["w"])
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("weighted_mean", weighted_mean)
+
+            for i, aggregated in enumerate(
+                [
+                    df.groupby("id").agg(weighted_mean(v=df.v, w=df.w).alias("wm")),
+                    df.groupby("id").agg(weighted_mean(w=df.w, v=df.v).alias("wm")),
+                    self.spark.sql(
+                        "SELECT id, weighted_mean(v => v, w => w) as wm FROM v GROUP BY id"
+                    ),
+                    self.spark.sql(
+                        "SELECT id, weighted_mean(w => w, v => v) as wm FROM v GROUP BY id"
+                    ),
+                ]
+            ):
+                with self.subTest(query_no=i):
+                    assertDataFrameEqual(aggregated, df.groupby("id").agg(mean(df.v).alias("wm")))
+
+            # negative
+            with self.assertRaisesRegex(
+                AnalysisException,
+                "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+            ):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v => v, v => w) as wm FROM v GROUP BY id"
+                ).show()
+
+            with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY id"
+                ).show()
+
+    def test_named_arguments_and_defaults(self):
+        df = self.data
+
+        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+        def biased_sum(v, w=None):
+            return v.sum() + (w.sum() if w is not None else 100)
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("biased_sum", biased_sum)
+
+            # without "w"
+            for i, aggregated in enumerate(
+                [
+                    df.groupby("id").agg(biased_sum(df.v).alias("s")),
+                    df.groupby("id").agg(biased_sum(v=df.v).alias("s")),
+                    self.spark.sql("SELECT id, biased_sum(v) as s FROM v GROUP BY id"),
+                    self.spark.sql("SELECT id, biased_sum(v => v) as s FROM v GROUP BY id"),
+                ]
+            ):
+                with self.subTest(with_w=False, query_no=i):
+                    assertDataFrameEqual(
+                        aggregated, df.groupby("id").agg((sum(df.v) + lit(100)).alias("s"))
+                    )
+
+            # with "w"
+            for i, aggregated in enumerate(
+                [
+                    df.groupby("id").agg(biased_sum(df.v, w=df.w).alias("s")),
+                    df.groupby("id").agg(biased_sum(v=df.v, w=df.w).alias("s")),
+                    df.groupby("id").agg(biased_sum(w=df.w, v=df.v).alias("s")),
+                    self.spark.sql("SELECT id, biased_sum(v, w => w) as s FROM v GROUP BY id"),
+                    self.spark.sql("SELECT id, biased_sum(v => v, w => w) as s FROM v GROUP BY id"),
+                    self.spark.sql("SELECT id, biased_sum(w => w, v => v) as s FROM v GROUP BY id"),
+                ]
+            ):
+                with self.subTest(with_w=True, query_no=i):
+                    assertDataFrameEqual(
+                        aggregated, df.groupby("id").agg((sum(df.v) + sum(df.w)).alias("s"))
+                    )
 
 
 class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase):
     pass
```
Contributor: Can you please also port these negative tests to the 'kwargs' Pandas UDF case as well? It would be good to make sure we do the same checks there. Same for the window functions testing.
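A rough sketch of one such port for the `**kwargs` UDF in `test_kwargs` above. The two `AnalysisException` checks are already present there; the execution-time check below is an assumption, since a `**kwargs`-only function rejects positional arguments with a different `TypeError` message than a fixed-signature one:

```python
# Hypothetical addition to test_kwargs (inside the tempView("v") block).
# Assumption: a positional argument that reaches the **kwargs-only UDF fails
# with Python's standard "takes 0 positional arguments" TypeError, surfaced
# by Spark as a PythonException; the exact message is not confirmed here.
with self.assertRaisesRegex(
    PythonException, r"weighted_mean\(\) takes 0 positional arguments"
):
    self.spark.sql(
        "SELECT id, weighted_mean(v, v => w) as wm FROM v GROUP BY id"
    ).show()
```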