-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-44749][SQL][PYTHON] Support named arguments in Python UDTF #42422
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cfa7de4
df5c7fa
cdc1452
b9eab18
442c934
bf79e46
f1f8594
8700ab5
f10c90c
587a970
eb4a2dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1565,6 +1565,7 @@ def eval(self, **kwargs): | |
| expected = [Row(c1="hello", c2="world")] | ||
| assertDataFrameEqual(TestUDTF(), expected) | ||
| assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf()"), expected) | ||
| assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf(a=>1)"), expected) | ||
|
|
||
| with self.assertRaisesRegex( | ||
| AnalysisException, r"analyze\(\) takes 0 positional arguments but 1 was given" | ||
|
|
@@ -1795,6 +1796,93 @@ def terminate(self): | |
| assertSchemaEqual(df.schema, StructType().add("col1", IntegerType())) | ||
| assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)]) | ||
|
|
||
| def test_udtf_with_named_arguments(self): | ||
| @udtf(returnType="a: int") | ||
| class TestUDTF: | ||
| def eval(self, a, b): | ||
| yield a, | ||
|
|
||
| self.spark.udtf.register("test_udtf", TestUDTF) | ||
|
|
||
| for i, df in enumerate( | ||
| [ | ||
| self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"), | ||
| self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"), | ||
| TestUDTF(a=lit(10), b=lit("x")), | ||
| TestUDTF(b=lit("x"), a=lit(10)), | ||
| ] | ||
| ): | ||
| with self.subTest(query_no=i): | ||
| assertDataFrameEqual(df, [Row(a=10)]) | ||
|
|
||
| def test_udtf_with_named_arguments_negative(self): | ||
| @udtf(returnType="a: int") | ||
| class TestUDTF: | ||
| def eval(self, a, b): | ||
ueshin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| yield a, | ||
|
|
||
| self.spark.udtf.register("test_udtf", TestUDTF) | ||
|
|
||
| with self.assertRaisesRegex( | ||
| AnalysisException, | ||
| "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", | ||
| ): | ||
| self.spark.sql("SELECT * FROM test_udtf(a=>10, a=>100)").show() | ||
|
|
||
| with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"): | ||
| self.spark.sql("SELECT * FROM test_udtf(a=>10, 'x')").show() | ||
|
|
||
| with self.assertRaisesRegex( | ||
| PythonException, r"eval\(\) got an unexpected keyword argument 'c'" | ||
| ): | ||
| self.spark.sql("SELECT * FROM test_udtf(c=>'x')").show() | ||
|
|
||
| def test_udtf_with_kwargs(self): | ||
| @udtf(returnType="a: int, b: string") | ||
| class TestUDTF: | ||
| def eval(self, **kwargs): | ||
| yield kwargs["a"], kwargs["b"] | ||
|
|
||
| self.spark.udtf.register("test_udtf", TestUDTF) | ||
|
|
||
| for i, df in enumerate( | ||
| [ | ||
| self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"), | ||
| self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What would be the error message if the named argument is used incorrectly? For example
I am afraid that if we directly leverage Python's kwargs, the error messages wouldn't be as user-friendly as the SQL function ones.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's a good point. So far it just relies on Python's error. @dtenedor What's the error message like when applying named arguments with the above cases to other functions? Are there any examples we can follow here?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, I believe @learningchess2003 added these checks in [1]. They are currently in the [1] #42020
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Updated to raise the following errors:
It will be checked in the analysis phase and an error with the error class
It will be handled in Python runtime and an error will be raised.
It will be checked in the analysis phase and an error with the error class
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. LGTM! |
||
| TestUDTF(a=lit(10), b=lit("x")), | ||
| TestUDTF(b=lit("x"), a=lit(10)), | ||
| ] | ||
| ): | ||
| with self.subTest(query_no=i): | ||
| assertDataFrameEqual(df, [Row(a=10, b="x")]) | ||
|
|
||
| def test_udtf_with_analyze_kwargs(self): | ||
| @udtf | ||
| class TestUDTF: | ||
| @staticmethod | ||
| def analyze(**kwargs: AnalyzeArgument) -> AnalyzeResult: | ||
| return AnalyzeResult( | ||
| StructType( | ||
| [StructField(key, arg.data_type) for key, arg in sorted(kwargs.items())] | ||
| ) | ||
| ) | ||
|
|
||
| def eval(self, **kwargs): | ||
| yield tuple(value for _, value in sorted(kwargs.items())) | ||
|
|
||
| self.spark.udtf.register("test_udtf", TestUDTF) | ||
|
|
||
| for i, df in enumerate( | ||
| [ | ||
| self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"), | ||
| self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"), | ||
ueshin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| TestUDTF(a=lit(10), b=lit("x")), | ||
| TestUDTF(b=lit("x"), a=lit(10)), | ||
| ] | ||
| ): | ||
| with self.subTest(query_no=i): | ||
| assertDataFrameEqual(df, [Row(a=10, b="x")]) | ||
|
|
||
|
|
||
| class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase): | ||
| @classmethod | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.