diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 0dafa18743fa..803d471c8c6f 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -197,6 +197,8 @@ def test_udf_in_join_condition(self):
         left = self.spark.createDataFrame([Row(a=1)])
         right = self.spark.createDataFrame([Row(b=1)])
         f = udf(lambda a, b: a == b, BooleanType())
+        # The udf uses attributes from both sides of the join, so it is pulled out
+        # as a Filter over a Cross join.
         df = left.join(right, f("a", "b"))
         with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
             df.collect()
@@ -243,6 +245,14 @@ def runWithJoinType(join_type, type_string):
         runWithJoinType("leftanti", "LeftAnti")
         runWithJoinType("leftsemi", "LeftSemi")
 
+    def test_udf_as_join_condition(self):
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+        f = udf(lambda a: a, IntegerType())
+
+        df = left.join(right, [f("a") == f("b"), left.a1 == right.b1])
+        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+
     def test_udf_without_arguments(self):
         self.spark.catalog.registerFunction("foo", lambda: "bar")
         [row] = self.spark.sql("SELECT foo()").collect()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 58fe7d5cbc0f..fc4ded376bf1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -179,7 +179,7 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
       validUdfs.forall(PythonUDF.isScalarPythonUDF),
       "Can only extract scalar vectorized udf or sql batch udf")
 
-    val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
+    val resultAttrs = validUdfs.zipWithIndex.map { case (u, i) =>
       AttributeReference(s"pythonUDF$i", u.dataType)()
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 38c634e1107b..32cddc94166b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
 import org.apache.spark.sql.execution.{BinaryExecNode, SortExec}
 import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.execution.python.BatchEvalPythonExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.StructType
@@ -969,4 +970,28 @@ class JoinSuite extends QueryTest with SharedSQLContext {
         Seq(Row(0.0d, 0.0/0.0)))))
     }
   }
+
+  test("SPARK-28323: PythonUDF should be usable in a join condition") {
+    import IntegratedUDFTestUtils._
+
+    assume(shouldTestPythonUDFs)
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+
+    val left = Seq((1, 2), (2, 3)).toDF("a", "b")
+    val right = Seq((1, 2), (3, 4)).toDF("c", "d")
+    val df = left.join(right, pythonTestUDF($"a") === pythonTestUDF($"c"))
+
+    val joinNode = df.queryExecution.executedPlan.find(_.isInstanceOf[BroadcastHashJoinExec])
+    assert(joinNode.isDefined)
+
+    // The two PythonUDFs each reference attributes from only one side of the join,
+    // so both should be evaluated below the join operator, one on each side.
+    val pythonEvals = joinNode.get.collect {
+      case p: BatchEvalPythonExec => p
+    }
+    assert(pythonEvals.size == 2)
+
+    checkAnswer(df, Row(1, 2, 1, 2) :: Nil)
+  }
 }
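
Why the one-line change in ExtractPythonUDFs.scala matters: as I read the hunk, resultAttrs was built by zipping over udfs (every Python UDF collected from the expressions) while the rest of the rule works with validUdfs (only the UDFs that can actually be extracted at this node). Whenever the two sequences differ in length, as with a join condition that puts one UDF on each side, the UDF-to-attribute pairing goes out of sync and planning fails; building resultAttrs from validUdfs keeps the pairing one-to-one. Below is a minimal PySpark sketch of the now-working pattern, mirroring the new JoinSuite test; the session setup, the ident UDF, and the data are illustrative assumptions, not part of the patch.

    # Minimal sketch (assumes a local Spark session; column names and data
    # are illustrative, not taken from the patch).
    from pyspark.sql import Row, SparkSession
    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    left = spark.createDataFrame([Row(a=1), Row(a=2)])
    right = spark.createDataFrame([Row(b=1), Row(b=3)])

    # One UDF call per side: each call references attributes from only one
    # child of the join, so each can be evaluated below the join and the
    # comparison of the two results becomes an ordinary equi-join key.
    ident = udf(lambda x: x, IntegerType())
    df = left.join(right, ident(left.a) == ident(right.b))

    print(df.collect())  # expected: [Row(a=1, b=1)]
    spark.stop()

Note the contrast with the first test_udf.py hunk: there a single UDF references attributes from both sides of the join, so it cannot be pushed below the join and is instead pulled out as a Filter over a Cross join, which the implicit-cartesian-product check then rejects.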