diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 9cc364cc1f8c..cf88932d1ec1 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -15654,7 +15654,7 @@ def udtf(
     +---+---+---+
 
     >>> _ = spark.udtf.register("test_udtf", TestUDTFWithKwargs)
-    >>> spark.sql("SELECT * FROM test_udtf(1, x=>'x', b=>'b')").show()
+    >>> spark.sql("SELECT * FROM test_udtf(1, x => 'x', b => 'b')").show()
     +---+---+---+
     |  a|  b|  x|
     +---+---+---+
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index cd0604ccacee..aa7df815e81b 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -1806,8 +1806,8 @@ def eval(self, a, b):
 
         for i, df in enumerate(
             [
-                self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
-                self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
+                self.spark.sql("SELECT * FROM test_udtf(a => 10, b => 'x')"),
+                self.spark.sql("SELECT * FROM test_udtf(b => 'x', a => 10)"),
                 TestUDTF(a=lit(10), b=lit("x")),
                 TestUDTF(b=lit("x"), a=lit(10)),
             ]
@@ -1827,15 +1827,15 @@ def eval(self, a, b):
             AnalysisException,
             "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
         ):
-            self.spark.sql("SELECT * FROM test_udtf(a=>10, a=>100)").show()
+            self.spark.sql("SELECT * FROM test_udtf(a => 10, a => 100)").show()
 
         with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
-            self.spark.sql("SELECT * FROM test_udtf(a=>10, 'x')").show()
+            self.spark.sql("SELECT * FROM test_udtf(a => 10, 'x')").show()
 
         with self.assertRaisesRegex(
             PythonException, r"eval\(\) got an unexpected keyword argument 'c'"
         ):
-            self.spark.sql("SELECT * FROM test_udtf(c=>'x')").show()
+            self.spark.sql("SELECT * FROM test_udtf(c => 'x')").show()
 
     def test_udtf_with_kwargs(self):
         @udtf(returnType="a: int, b: string")
@@ -1847,8 +1847,8 @@ def eval(self, **kwargs):
 
         for i, df in enumerate(
             [
-                self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
-                self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
+                self.spark.sql("SELECT * FROM test_udtf(a => 10, b => 'x')"),
+                self.spark.sql("SELECT * FROM test_udtf(b => 'x', a => 10)"),
                 TestUDTF(a=lit(10), b=lit("x")),
                 TestUDTF(b=lit("x"), a=lit(10)),
             ]
@@ -1874,8 +1874,8 @@ def eval(self, **kwargs):
 
         for i, df in enumerate(
            [
-                self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
-                self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
+                self.spark.sql("SELECT * FROM test_udtf(a => 10, b => 'x')"),
+                self.spark.sql("SELECT * FROM test_udtf(b => 'x', a => 10)"),
                 TestUDTF(a=lit(10), b=lit("x")),
                 TestUDTF(b=lit("x"), a=lit(10)),
             ]
@@ -1883,6 +1883,76 @@ def eval(self, **kwargs):
             with self.subTest(query_no=i):
                 assertDataFrameEqual(df, [Row(a=10, b="x")])
 
+    def test_udtf_with_named_arguments_lateral_join(self):
+        @udtf
+        class TestUDTF:
+            @staticmethod
+            def analyze(a, b):
+                return AnalyzeResult(StructType().add("a", a.data_type))
+
+            def eval(self, a, b):
+                yield a,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        # lateral join
+        for i, df in enumerate(
+            [
+                self.spark.sql(
+                    "SELECT f.* FROM "
+                    "VALUES (0, 'x'), (1, 'y') t(a, b), LATERAL test_udtf(a => a, b => b) f"
+                ),
+                self.spark.sql(
+                    "SELECT f.* FROM "
+                    "VALUES (0, 'x'), (1, 'y') t(a, b), LATERAL test_udtf(b => b, a => a) f"
+                ),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                assertDataFrameEqual(df, [Row(a=0), Row(a=1)])
+
+    def test_udtf_with_named_arguments_and_defaults(self):
+        @udtf
+        class TestUDTF:
+            @staticmethod
+            def analyze(a, b=None):
+                schema = StructType().add("a", a.data_type)
+                if b is None:
+                    return AnalyzeResult(schema.add("b", IntegerType()))
+                else:
+                    return AnalyzeResult(schema.add("b", b.data_type))
+
+            def eval(self, a, b=100):
+                yield a, b
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        # without "b"
+        for i, df in enumerate(
+            [
+                self.spark.sql("SELECT * FROM test_udtf(10)"),
+                self.spark.sql("SELECT * FROM test_udtf(a => 10)"),
+                TestUDTF(lit(10)),
+                TestUDTF(a=lit(10)),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                assertDataFrameEqual(df, [Row(a=10, b=100)])
+
+        # with "b"
+        for i, df in enumerate(
+            [
+                self.spark.sql("SELECT * FROM test_udtf(10, b => 'z')"),
+                self.spark.sql("SELECT * FROM test_udtf(a => 10, b => 'z')"),
+                self.spark.sql("SELECT * FROM test_udtf(b => 'z', a => 10)"),
+                TestUDTF(lit(10), b=lit("z")),
+                TestUDTF(a=lit(10), b=lit("z")),
+                TestUDTF(b=lit("z"), a=lit(10)),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                assertDataFrameEqual(df, [Row(a=10, b="z")])
+
 
 class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
     @classmethod