python/pyspark/sql/functions.py (2 changes: 1 addition, 1 deletion)
@@ -15654,7 +15654,7 @@ def udtf(
     +---+---+---+

     >>> _ = spark.udtf.register("test_udtf", TestUDTFWithKwargs)
-    >>> spark.sql("SELECT * FROM test_udtf(1, x=>'x', b=>'b')").show()
+    >>> spark.sql("SELECT * FROM test_udtf(1, x => 'x', b => 'b')").show()
     +---+---+---+
     |  a|  b|  x|
     +---+---+---+
python/pyspark/sql/tests/test_udtf.py (88 changes: 79 additions, 9 deletions)
@@ -1806,8 +1806,8 @@ def eval(self, a, b):

         for i, df in enumerate(
             [
-                self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
-                self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
+                self.spark.sql("SELECT * FROM test_udtf(a => 10, b => 'x')"),
+                self.spark.sql("SELECT * FROM test_udtf(b => 'x', a => 10)"),
                 TestUDTF(a=lit(10), b=lit("x")),
                 TestUDTF(b=lit("x"), a=lit(10)),
             ]
@@ -1827,15 +1827,15 @@ def eval(self, a, b):
             AnalysisException,
             "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
         ):
-            self.spark.sql("SELECT * FROM test_udtf(a=>10, a=>100)").show()
+            self.spark.sql("SELECT * FROM test_udtf(a => 10, a => 100)").show()

         with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
-            self.spark.sql("SELECT * FROM test_udtf(a=>10, 'x')").show()
+            self.spark.sql("SELECT * FROM test_udtf(a => 10, 'x')").show()

         with self.assertRaisesRegex(
             PythonException, r"eval\(\) got an unexpected keyword argument 'c'"
         ):
-            self.spark.sql("SELECT * FROM test_udtf(c=>'x')").show()
+            self.spark.sql("SELECT * FROM test_udtf(c => 'x')").show()

     def test_udtf_with_kwargs(self):
         @udtf(returnType="a: int, b: string")
@@ -1847,8 +1847,8 @@ def eval(self, **kwargs):

         for i, df in enumerate(
             [
-                self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
-                self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
+                self.spark.sql("SELECT * FROM test_udtf(a => 10, b => 'x')"),
+                self.spark.sql("SELECT * FROM test_udtf(b => 'x', a => 10)"),
                 TestUDTF(a=lit(10), b=lit("x")),
                 TestUDTF(b=lit("x"), a=lit(10)),
             ]
@@ -1874,15 +1874,85 @@ def eval(self, **kwargs):

         for i, df in enumerate(
             [
-                self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
-                self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
+                self.spark.sql("SELECT * FROM test_udtf(a => 10, b => 'x')"),
+                self.spark.sql("SELECT * FROM test_udtf(b => 'x', a => 10)"),
                 TestUDTF(a=lit(10), b=lit("x")),
                 TestUDTF(b=lit("x"), a=lit(10)),
             ]
         ):
             with self.subTest(query_no=i):
                 assertDataFrameEqual(df, [Row(a=10, b="x")])

+    def test_udtf_with_named_arguments_lateral_join(self):
+        @udtf
+        class TestUDTF:
+            @staticmethod
+            def analyze(a, b):
+                return AnalyzeResult(StructType().add("a", a.data_type))
+
+            def eval(self, a, b):
+                yield a,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        # lateral join
+        for i, df in enumerate(
+            [
+                self.spark.sql(
+                    "SELECT f.* FROM "
+                    "VALUES (0, 'x'), (1, 'y') t(a, b), LATERAL test_udtf(a => a, b => b) f"
+                ),
+                self.spark.sql(
+                    "SELECT f.* FROM "
+                    "VALUES (0, 'x'), (1, 'y') t(a, b), LATERAL test_udtf(b => b, a => a) f"
+                ),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                assertDataFrameEqual(df, [Row(a=0), Row(a=1)])
+
+    def test_udtf_with_named_arguments_and_defaults(self):
+        @udtf
+        class TestUDTF:
+            @staticmethod
+            def analyze(a, b=None):
Review comment (Contributor):
If b is a parameter with a default value (e.g. 100), do we need to also assign a default value (e.g. None) here in the analyze method? If we directly use def analyze(a, b), does it throw an exception?

Reply (Member Author):
Yes, if b here doesn't have a default value, it will raise an exception.

Both eval and analyze will be called with the same argument list: if eval takes a, analyze should also take a; if eval takes a, b, analyze should also take a, b; and so on.

The default can be an AnalyzeArgument with the correct parameters, or analyze can take **kwargs:

    def analyze(a, b=AnalyzeArgument(StringType(), None, False)): ...
    def analyze(a, **kwargs): ...

That said, taking an AnalyzeArgument as the default is not recommended if the data_type is a struct type, since a mutable default can cause unexpected behavior, just as using [] or {} as a default does in Python. (A standalone sketch of the **kwargs variant follows at the end of this diff.)

+                schema = StructType().add("a", a.data_type)
+                if b is None:
+                    return AnalyzeResult(schema.add("b", IntegerType()))
+                else:
+                    return AnalyzeResult(schema.add("b", b.data_type))
+
+            def eval(self, a, b=100):
+                yield a, b
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        # without "b"
+        for i, df in enumerate(
+            [
+                self.spark.sql("SELECT * FROM test_udtf(10)"),
+                self.spark.sql("SELECT * FROM test_udtf(a => 10)"),
+                TestUDTF(lit(10)),
+                TestUDTF(a=lit(10)),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                assertDataFrameEqual(df, [Row(a=10, b=100)])
+
+        # with "b"
+        for i, df in enumerate(
+            [
+                self.spark.sql("SELECT * FROM test_udtf(10, b => 'z')"),
+                self.spark.sql("SELECT * FROM test_udtf(a => 10, b => 'z')"),
+                self.spark.sql("SELECT * FROM test_udtf(b => 'z', a => 10)"),
+                TestUDTF(lit(10), b=lit("z")),
+                TestUDTF(a=lit(10), b=lit("z")),
+                TestUDTF(b=lit("z"), a=lit(10)),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                assertDataFrameEqual(df, [Row(a=10, b="z")])
+

 class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
     @classmethod
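
A minimal, self-contained sketch (an editor's illustration, not part of this PR) of the **kwargs variant of analyze() discussed in the review thread above. It assumes PySpark 3.5+, where AnalyzeArgument and AnalyzeResult live in pyspark.sql.udtf, and an active SparkSession named spark; the class and function names here are hypothetical:

    from pyspark.sql.functions import udtf
    from pyspark.sql.types import IntegerType, StructType
    from pyspark.sql.udtf import AnalyzeResult

    @udtf
    class EchoUDTF:  # hypothetical name, for illustration only
        @staticmethod
        def analyze(a, **kwargs):
            # analyze() mirrors eval()'s argument list: the optional "b"
            # arrives in **kwargs only when the caller actually passes it.
            schema = StructType().add("a", a.data_type)
            if "b" in kwargs:
                schema = schema.add("b", kwargs["b"].data_type)
            else:
                # eval() falls back to b=100, an int, so advertise IntegerType.
                schema = schema.add("b", IntegerType())
            return AnalyzeResult(schema)

        def eval(self, a, b=100):
            yield a, b

    # Usage (assuming an active SparkSession "spark"):
    # spark.udtf.register("echo_udtf", EchoUDTF)
    # spark.sql("SELECT * FROM echo_udtf(a => 10)").show()            # b defaults to 100
    # spark.sql("SELECT * FROM echo_udtf(a => 10, b => 'z')").show()  # b is a string here

Taking **kwargs sidesteps the mutable-default caveat mentioned in the reply, because no AnalyzeArgument instance is ever evaluated at function-definition time.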