diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 06985ac85b70e..02d5a1f27aa7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -116,6 +116,10 @@ trait PredicateHelper { // non-correlated subquery will be replaced as literal e.children.isEmpty case a: AttributeReference => true + // PythonUDF will be executed by a dedicated physical operator later. + // For PythonUDFs that can't be evaluated in join condition, `PullOutPythonUDFInJoinCondition` + // will pull them out later. + case _: PythonUDF => true case e: Unevaluable => false case e => e.children.forall(canEvaluateWithinJoin) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 2db4667fd0561..3ec8d18bc871d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -24,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{BooleanType, IntegerType} import org.apache.spark.unsafe.types.CalendarInterval class 
FilterPushdownSuite extends PlanTest { @@ -41,9 +42,14 @@ class FilterPushdownSuite extends PlanTest { CollapseProject) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val attrA = 'a.int + val attrB = 'b.int + val attrC = 'c.int + val attrD = 'd.int - val testRelation1 = LocalRelation('d.int) + val testRelation = LocalRelation(attrA, attrB, attrC) + + val testRelation1 = LocalRelation(attrD) // This test already passes. test("eliminate subqueries") { @@ -1202,4 +1208,26 @@ class FilterPushdownSuite extends PlanTest { comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze, checkAnalysis = false) } + + test("SPARK-28345: PythonUDF predicate should be able to pushdown to join") { + val pythonUDFJoinCond = { + val pythonUDF = PythonUDF("pythonUDF", null, + IntegerType, + Seq(attrA), + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + pythonUDF === attrD + } + + val query = testRelation.join( + testRelation1, + joinType = Cross).where(pythonUDFJoinCond) + + val expected = testRelation.join( + testRelation1, + joinType = Cross, + condition = Some(pythonUDFJoinCond)).analyze + + comparePlans(Optimize.execute(query.analyze), expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 32cddc94166b7..503168ec69d54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -26,7 +26,8 @@ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder} -import org.apache.spark.sql.execution.{BinaryExecNode, SortExec} +import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, SortExec} 
import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.python.BatchEvalPythonExec import org.apache.spark.sql.internal.SQLConf @@ -994,4 +995,26 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(1, 2, 1, 2) :: Nil) } + + test("SPARK-28345: PythonUDF predicate should be able to pushdown to join") { + import IntegratedUDFTestUtils._ + + assume(shouldTestPythonUDFs) + + val pythonTestUDF = TestPythonUDF(name = "udf") + + val left = Seq((1, 2), (2, 3)).toDF("a", "b") + val right = Seq((1, 2), (3, 4)).toDF("c", "d") + val df = left.crossJoin(right).where(pythonTestUDF($"a") === pythonTestUDF($"c")) + + // Before optimization, there is a logical Filter operator. + val filterInAnalysis = df.queryExecution.analyzed.find(_.isInstanceOf[Filter]) + assert(filterInAnalysis.isDefined) + + // Filter predicate was pushed down as join condition, so there is no FilterExec operator. + val filterExec = df.queryExecution.executedPlan.find(_.isInstanceOf[FilterExec]) + assert(filterExec.isEmpty) + + checkAnswer(df, Row(1, 2, 1, 2) :: Nil) + } }