From eaf740a1b373e3ccd22c53cc1ff9f7f943e7043d Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Tue, 6 Dec 2016 23:45:01 -0800
Subject: [PATCH 1/7] fix.

---
 python/pyspark/sql/tests.py                   |  9 ++
 .../execution/python/ExtractPythonUDFs.scala  | 49 ++++++++-
 .../python/BatchEvalPythonExecSuite.scala     | 99 +++++++++++++++++++
 3 files changed, 155 insertions(+), 2 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 0aff9cebe91b..857bbb5a3746 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -360,6 +360,15 @@ def test_broadcast_in_udf(self):
         [res] = self.spark.sql("SELECT MYUDF('')").collect()
         self.assertEqual("", res[0])
 
+    def test_udf_with_filter_function(self):
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        from pyspark.sql.functions import udf, col
+        from pyspark.sql.types import BooleanType
+
+        my_filter = udf(lambda a: a < 2, BooleanType())
+        sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
+        self.assertEqual(sel.collect(), [Row(key=1, value='1')])
+
     def test_udf_with_aggregate_function(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         from pyspark.sql.functions import udf, col, sum
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 16e44845d528..5dfecba0eac7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
 
 
 /**
@@ -111,7 +111,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] {
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan transformUp {
-    case plan: SparkPlan => extract(plan)
+    case plan: SparkPlan =>
+      val newPlan = extract(plan)
+      if (newPlan != plan) {
+        // Found and built a BatchEvalPythonExec; now push FilterExec
+        // through BatchEvalPythonExec
+        PushPredicateThroughBatchEvalPython.apply(newPlan)
+      } else {
+        plan
+      }
   }
 
   /**
@@ -166,3 +174,40 @@ object ExtractPythonUDFs extends Rule[SparkPlan] {
     }
   }
 }
+
+// This rule is to push deterministic predicates through BatchEvalPythonExec
+object PushPredicateThroughBatchEvalPython extends Rule[SparkPlan] with PredicateHelper {
+  def apply(plan: SparkPlan): SparkPlan = plan transform {
+    case filter @ FilterExec(_, child: BatchEvalPythonExec)
+        if child.expressions.forall(_.deterministic) =>
+      pushDownPredicate(filter, child.child) { predicate =>
+        child.withNewChildren(Seq(FilterExec(predicate, child.child)))
+      }
+  }
+
+  private def pushDownPredicate(
+      filter: FilterExec,
+      grandchild: SparkPlan)(insertFilter: Expression => SparkPlan): SparkPlan = {
+    // Only push down the predicates that are deterministic and all the referenced attributes
+    // come from grandchild.
+    val (candidates, containingNonDeterministic) =
+        splitConjunctivePredicates(filter.condition).span(_.deterministic)
+
+    val (pushDown, rest) = candidates.partition { cond =>
+      cond.references.subsetOf(grandchild.outputSet)
+    }
+
+    val stayUp = rest ++ containingNonDeterministic
+
+    if (pushDown.nonEmpty) {
+      val newChild = insertFilter(pushDown.reduceLeft(And))
+      if (stayUp.nonEmpty) {
+        FilterExec(stayUp.reduceLeft(And), newChild)
+      } else {
+        newChild
+      }
+    } else {
+      filter
+    }
+  }
+}
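
A minimal standalone sketch of the splitting logic in pushDownPredicate above, using plain
Scala collections and a toy Pred type in place of Catalyst expressions (Pred and its fields
are illustrative only, not Spark API). span keeps only the longest all-deterministic prefix
of the conjunction, because evaluating a predicate ahead of a non-deterministic one that
originally preceded it could change which rows the non-deterministic expression sees;
partition then picks out, from that prefix, the conjuncts whose references all come from
the grandchild:

    object SplitSketch {
      final case class Pred(sql: String, deterministic: Boolean, refsFromGrandchild: Boolean)

      def main(args: Array[String]): Unit = {
        // models: udf(a) AND b > 1 AND rand() > 0.5 AND a in (3, 4)
        val conjuncts = Seq(
          Pred("udf(a)", deterministic = true, refsFromGrandchild = false),
          Pred("b > 1", deterministic = true, refsFromGrandchild = true),
          Pred("rand() > 0.5", deterministic = false, refsFromGrandchild = true),
          Pred("a in (3, 4)", deterministic = true, refsFromGrandchild = true))

        // candidates = the all-deterministic prefix; everything from the first
        // non-deterministic conjunct onwards stays above the Python evaluation,
        // even "a in (3, 4)", which would otherwise be pushable
        val (candidates, containingNonDeterministic) = conjuncts.span(_.deterministic)
        val (pushDown, rest) = candidates.partition(_.refsFromGrandchild)

        println(pushDown.map(_.sql))                              // List(b > 1)
        println((rest ++ containingNonDeterministic).map(_.sql))  // List(udf(a), rand() > 0.5, a in (3, 4))
      }
    }
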
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
new file mode 100644
index 000000000000..c383810bae24
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.api.python.PythonFunction
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, In}
+import org.apache.spark.sql.execution.{FilterExec, SparkPlanTest}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.BooleanType
+
+class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
+    import testImplicits.newProductEncoder
+    import testImplicits.localSeqToDatasetHolder
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.udf.registerPython("dummyPythonUDF", new MyDummyPythonUDF)
+  }
+
+  override def afterAll(): Unit = {
+    spark.sessionState.functionRegistry.dropFunction("dummyPythonUDF")
+    super.afterAll()
+  }
+
+  test("Python UDF: push down deterministic FilterExec predicates") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+      .where("dummyPythonUDF(b) and dummyPythonUDF(a) and a in (3, 4)")
+    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
+      case f @ FilterExec(And(_: AttributeReference, _: AttributeReference), _) => f
+      case b: BatchEvalPythonExec => b
+      case f @ FilterExec(_: In, _) => f
+    }
+    assert(qualifiedPlanNodes.size == 3)
+  }
+
+  test("Nested Python UDF: push down deterministic FilterExec predicates") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+      .where("dummyPythonUDF(a, dummyPythonUDF(a, b)) and a in (3, 4)")
+    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
+      case f @ FilterExec(_: AttributeReference, _) => f
+      case b: BatchEvalPythonExec => b
+      case f @ FilterExec(_: In, _) => f
+    }
+    assert(qualifiedPlanNodes.size == 4)
+  }
+
+  test("Python UDF: no push down on non-deterministic FilterExec predicates") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+      .where("dummyPythonUDF(a) and rand() > 3")
+    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
+      case f: FilterExec => f
+      case b: BatchEvalPythonExec => b
+    }
+    assert(qualifiedPlanNodes.size == 2)
+  }
+
+  test("Python UDF refers to the attributes from more than one child") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+    val df2 = Seq(("Hello", 4)).toDF("c", "d")
+    val joinDF = df.join(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")
+
+    val e = intercept[RuntimeException] {
+      joinDF.queryExecution.executedPlan
+    }.getMessage
+    assert(Seq("Invalid PythonUDF dummyUDF", "requires attributes from more than one child")
+      .forall(e.contains))
+  }
+}
+
+// This Python UDF is a dummy and just for testing; it cannot actually be executed.
+class DummyUDF extends PythonFunction(
+  command = Array[Byte](),
+  envVars = Map("" -> "").asJava,
+  pythonIncludes = ArrayBuffer("").asJava,
+  pythonExec = "",
+  pythonVer = "",
+  broadcastVars = null,
+  accumulator = null)
+
+class MyDummyPythonUDF
+  extends UserDefinedPythonFunction(name = "dummyUDF", func = new DummyUDF, dataType = BooleanType)
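
A note on the test scaffolding above: DummyUDF wraps a PythonFunction whose command is an
empty byte array rather than pickled Python code, so the registered dummyPythonUDF can be
planned but never actually evaluated. That is enough here because every test only forces
physical planning. A sketch of the pattern the suite relies on (written as it would appear
inside the suite, where the toDF implicits are in scope; collect() would fail because there
is no real Python function to ship to a worker):

    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(b) and a in (3, 4)")
    df.queryExecution.executedPlan   // planning only: safe with the dummy UDF
    // df.collect()                  // would fail: no Python worker can run an empty command
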
From 2c3b91738fae8286525cabb24c386503a570448b Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Wed, 7 Dec 2016 11:06:43 -0800
Subject: [PATCH 2/7] fix the indents.

---
 .../apache/spark/sql/execution/python/ExtractPythonUDFs.scala | 2 +-
 .../spark/sql/execution/python/BatchEvalPythonExecSuite.scala | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 5dfecba0eac7..8644e52ac3ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -191,7 +191,7 @@ object PushPredicateThroughBatchEvalPython extends Rule[SparkPlan] with Predicat
     // Only push down the predicates that are deterministic and all the referenced attributes
     // come from grandchild.
     val (candidates, containingNonDeterministic) =
-        splitConjunctivePredicates(filter.condition).span(_.deterministic)
+      splitConjunctivePredicates(filter.condition).span(_.deterministic)
 
     val (pushDown, rest) = candidates.partition { cond =>
       cond.references.subsetOf(grandchild.outputSet)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
index c383810bae24..48f027298579 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -27,8 +27,8 @@ import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.BooleanType
 
 class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
-    import testImplicits.newProductEncoder
-    import testImplicits.localSeqToDatasetHolder
+  import testImplicits.newProductEncoder
+  import testImplicits.localSeqToDatasetHolder
 
   override def beforeAll(): Unit = {
     super.beforeAll()

From 3d9ba67593ab7b3804b213f024a5bc7393f6b026 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Thu, 8 Dec 2016 12:51:41 -0800
Subject: [PATCH 3/7] address comments.

---
 .../execution/python/ExtractPythonUDFs.scala | 74 ++++++++-----------
 1 file changed, 29 insertions(+), 45 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 8644e52ac3ce..b8f484ed1ea5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -90,7 +90,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
  * This has the limitation that the input to the Python UDF is not allowed include attributes from
  * multiple child operators.
  */
-object ExtractPythonUDFs extends Rule[SparkPlan] {
+object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
 
   private def hasPythonUDF(e: Expression): Boolean = {
     e.find(_.isInstanceOf[PythonUDF]).isDefined
@@ -111,15 +111,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] {
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan transformUp {
-    case plan: SparkPlan =>
-      val newPlan = extract(plan)
-      if (newPlan != plan) {
-        // Found and built a BatchEvalPythonExec; now push FilterExec
-        // through BatchEvalPythonExec
-        PushPredicateThroughBatchEvalPython.apply(newPlan)
-      } else {
-        plan
-      }
+    case plan: SparkPlan => extract(plan)
   }
 
   /**
@@ -134,8 +126,10 @@ object ExtractPythonUDFs extends Rule[SparkPlan] {
       plan
     } else {
       val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+      val splittedFilter = trySplitFilter(plan)
       // Rewrite the child that has the input required for the UDF
-      val newChildren = plan.children.map { child =>
+      val newChildren =
+        splittedFilter.children.map { child =>
         // Pick the UDF we are going to evaluate
@@ -158,7 +152,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] {
         sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
       }
 
-      val rewritten = plan.withNewChildren(newChildren).transformExpressions {
+      val rewritten = splittedFilter.withNewChildren(newChildren).transformExpressions {
         case p: PythonUDF if attributeMap.contains(p) => attributeMap(p)
       }
@@ -173,41 +167,31 @@ object ExtractPythonUDFs extends Rule[SparkPlan] {
       }
     }
   }
-}
-
-// This rule is to push deterministic predicates through BatchEvalPythonExec
-object PushPredicateThroughBatchEvalPython extends Rule[SparkPlan] with PredicateHelper {
-  def apply(plan: SparkPlan): SparkPlan = plan transform {
-    case filter @ FilterExec(_, child: BatchEvalPythonExec)
-        if child.expressions.forall(_.deterministic) =>
-      pushDownPredicate(filter, child.child) { predicate =>
-        child.withNewChildren(Seq(FilterExec(predicate, child.child)))
-      }
-  }
-
-  private def pushDownPredicate(
-      filter: FilterExec,
-      grandchild: SparkPlan)(insertFilter: Expression => SparkPlan): SparkPlan = {
-    // Only push down the predicates that are deterministic and all the referenced attributes
-    // come from grandchild.
-    val (candidates, containingNonDeterministic) =
-      splitConjunctivePredicates(filter.condition).span(_.deterministic)
-
-    val (pushDown, rest) = candidates.partition { cond =>
-      cond.references.subsetOf(grandchild.outputSet)
-    }
-
-    val stayUp = rest ++ containingNonDeterministic
+  // Split the original FilterExec to two FilterExecs. The upper FilterExec only contains
+  // Python UDF and non-deterministic predicates.
+  private def trySplitFilter(plan: SparkPlan): SparkPlan = {
+    plan match {
+      case filter: FilterExec =>
+        // Only push down the predicates that are deterministic and all the referenced attributes
+        // come from child.
+        val (candidates, containingNonDeterministic) =
+          splitConjunctivePredicates(filter.condition).span(_.deterministic)
+        val (pushDown, rest) = candidates.partition(!hasPythonUDF(_))
+        val stayUp = rest ++ containingNonDeterministic
+
+        if (pushDown.nonEmpty) {
+          val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
+          if (stayUp.nonEmpty) {
+            FilterExec(stayUp.reduceLeft(And), newChild)
+          } else {
+            newChild
+          }
+        } else {
+          filter
+        }
 
-    if (pushDown.nonEmpty) {
-      val newChild = insertFilter(pushDown.reduceLeft(And))
-      if (stayUp.nonEmpty) {
-        FilterExec(stayUp.reduceLeft(And), newChild)
-      } else {
-        newChild
-      }
-    } else {
-      filter
+      case o => o
     }
   }
 }
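
The net effect of trySplitFilter is easiest to see on a concrete plan. A hypothetical
walk-through, assuming the dummyPythonUDF registered by the suite added in [PATCH 1/7]
(the plan rendering is schematic, not verbatim explain() output; a trailing Project that
drops the temporary pythonUDF0 column is omitted):

    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(a) and b > 1")
    df.explain()
    // Filter pythonUDF0: boolean          <- upper filter: the UDF-result predicate stays here
    // +- BatchEvalPythonExec [dummyPythonUDF(a)]
    //    +- Filter (b > 1)                <- lower filter: evaluated before the Python round-trip
    //       +- LocalTableScan [a, b]

The win is the same as in [PATCH 1/7]: deterministic, UDF-free predicates drop rows before
they are serialized to the Python worker. This commit just achieves it by splitting the
FilterExec up front, inside extract, instead of building BatchEvalPythonExec first and then
pushing a FilterExec through it with a separate rule.
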
From 6586c901c573bc8d63056e3b34d6e2834f04a0e9 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Thu, 8 Dec 2016 15:33:52 -0800
Subject: [PATCH 4/7] cleanup

---
 .../sql/execution/python/ExtractPythonUDFs.scala | 10 ++--------
 .../python/BatchEvalPythonExecSuite.scala        | 14 ++++++++++++--
 2 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index b8f484ed1ea5..a1ab3dc08858 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -177,20 +177,14 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
         // come from child.
         val (candidates, containingNonDeterministic) =
           splitConjunctivePredicates(filter.condition).span(_.deterministic)
+        // Python UDF is always deterministic
         val (pushDown, rest) = candidates.partition(!hasPythonUDF(_))
-        val stayUp = rest ++ containingNonDeterministic
-
         if (pushDown.nonEmpty) {
           val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
-          if (stayUp.nonEmpty) {
-            FilterExec(stayUp.reduceLeft(And), newChild)
-          } else {
-            newChild
-          }
+          FilterExec((rest ++ containingNonDeterministic).reduceLeft(And), newChild)
         } else {
           filter
         }
-
       case o => o
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
index 48f027298579..8730de245728 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -62,9 +62,19 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
     assert(qualifiedPlanNodes.size == 4)
   }
 
-  test("Python UDF: no push down on non-deterministic FilterExec predicates") {
+  test("Python UDF: no push down on non-deterministic") {
     val df = Seq(("Hello", 4)).toDF("a", "b")
-      .where("dummyPythonUDF(a) and rand() > 3")
+      .where("b > 4 and dummyPythonUDF(a) and rand() > 3")
+    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
+      case f: FilterExec => f
+      case b: BatchEvalPythonExec => b
+    }
+    assert(qualifiedPlanNodes.size == 3)
+  }
+
+  test("Python UDF: no push down on predicates starting from the first non-deterministic") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+      .where("dummyPythonUDF(a) and rand() > 3 and b > 4")
     val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
       case f: FilterExec => f
       case b: BatchEvalPythonExec => b
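
Dropping the stayUp.nonEmpty branch makes (rest ++ containingNonDeterministic).reduceLeft(And)
unconditional, which is safe only if that sequence can never be empty while pushDown is
non-empty. It cannot be: trySplitFilter runs only once extract has found a Python UDF among
the plan's expressions, and for a FilterExec those expressions are exactly the filter
condition. The conjunct containing the UDF is either deterministic, in which case it reaches
candidates but is diverted into rest by !hasPythonUDF(_) (Python UDFs themselves count as
deterministic, per the comment added above), or it sits at or after the first
non-deterministic conjunct and lands in containingNonDeterministic. A toy rendering of that
invariant, in the same spirit as the sketch under [PATCH 1/7] (Pred is illustrative, not
Spark API):

    final case class Pred(deterministic: Boolean, hasUDF: Boolean)

    // true for every conjunct list in which at least one Pred has hasUDF = true,
    // which is exactly the situation in which trySplitFilter is invoked
    def reduceLeftIsSafe(conjuncts: Seq[Pred]): Boolean = {
      val (candidates, containingNonDeterministic) = conjuncts.span(_.deterministic)
      val (pushDown, rest) = candidates.partition(!_.hasUDF)
      pushDown.isEmpty || (rest ++ containingNonDeterministic).nonEmpty
    }
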
From b60f7bbe299f597c06c6a6c6a0e498a4384c2108 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Fri, 9 Dec 2016 11:27:09 -0800
Subject: [PATCH 5/7] update the comments.

---
 .../apache/spark/sql/execution/python/ExtractPythonUDFs.scala | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index a1ab3dc08858..da0b0f38e732 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -173,11 +173,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
   private def trySplitFilter(plan: SparkPlan): SparkPlan = {
     plan match {
       case filter: FilterExec =>
-        // Only push down the predicates that are deterministic and all the referenced attributes
-        // come from child.
+        // Only push down the first few predicates that are all deterministic
         val (candidates, containingNonDeterministic) =
           splitConjunctivePredicates(filter.condition).span(_.deterministic)
-        // Python UDF is always deterministic
         val (pushDown, rest) = candidates.partition(!hasPythonUDF(_))
         if (pushDown.nonEmpty) {
           val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)

From 04b0e9c739a32319d84a86e77ab8ef89ba81d959 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Fri, 9 Dec 2016 11:30:59 -0800
Subject: [PATCH 6/7] update the comments.

---
 .../spark/sql/execution/python/ExtractPythonUDFs.scala | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index da0b0f38e732..8ce6ee20aaf1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -168,12 +168,11 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
     }
   }
 
-  // Split the original FilterExec to two FilterExecs. The upper FilterExec only contains
-  // Python UDF and non-deterministic predicates.
+  // Split the original FilterExec to two FilterExecs. Only push down the first few predicates
+  // that are all deterministic.
   private def trySplitFilter(plan: SparkPlan): SparkPlan = {
     plan match {
       case filter: FilterExec =>
-        // Only push down the first few predicates that are all deterministic
         val (candidates, containingNonDeterministic) =
           splitConjunctivePredicates(filter.condition).span(_.deterministic)
         val (pushDown, rest) = candidates.partition(!hasPythonUDF(_))
+
         if (pushDown.nonEmpty) {
           val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
From 2c8e593a2a705f536a284581f33c469574695015 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Fri, 9 Dec 2016 23:50:47 -0800
Subject: [PATCH 7/7] improve the test cases

---
 .../execution/python/ExtractPythonUDFs.scala |  9 +++--
 .../python/BatchEvalPythonExecSuite.scala    | 33 ++++++++++---------
 2 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 8ce6ee20aaf1..69b4b7bb07de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -126,12 +126,11 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
       plan
     } else {
       val attributeMap = mutable.HashMap[PythonUDF, Expression]()
-      val splittedFilter = trySplitFilter(plan)
+      val splitFilter = trySplitFilter(plan)
       // Rewrite the child that has the input required for the UDF
-      val newChildren =
-        splittedFilter.children.map { child =>
+      val newChildren = splitFilter.children.map { child =>
         // Pick the UDF we are going to evaluate
-        val validUdfs = udfs.filter { case udf =>
+        val validUdfs = udfs.filter { udf =>
           // Check to make sure that the UDF can be evaluated with only the input of this child.
           udf.references.subsetOf(child.outputSet)
         }.toArray  // Turn it into an array since iterators cannot be serialized in Scala 2.10
@@ -152,7 +151,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
         sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
       }
 
-      val rewritten = splittedFilter.withNewChildren(newChildren).transformExpressions {
+      val rewritten = splitFilter.withNewChildren(newChildren).transformExpressions {
         case p: PythonUDF if attributeMap.contains(p) => attributeMap(p)
       }
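
The test rewrites below are driven by whole-stage code generation. In the executed plan, the
pushed-down lower FilterExec is compiled into a WholeStageCodegenExec that feeds
BatchEvalPythonExec, and the upper FilterExec reads the UDF output through an InputAdapter.
A bare case f: FilterExec matches a filter anywhere in the tree, so a node count alone
cannot prove that one filter landed below the Python evaluation and one above; the tightened
patterns pin down the exact shape. The matcher idiom, as a sketch (assuming whole-stage
codegen is enabled, as in the default test configuration):

    df.queryExecution.executedPlan.collect {
      // the upper filter must sit directly on the Python evaluation, via an InputAdapter
      case f @ FilterExec(_, InputAdapter(_: BatchEvalPythonExec)) => f
      // the pushed-down filter must be the codegened child of the Python evaluation
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
    }
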
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
index 8730de245728..81bea2fef8bd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -21,8 +21,8 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.api.python.PythonFunction
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, In}
-import org.apache.spark.sql.execution.{FilterExec, SparkPlanTest}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In}
+import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.BooleanType
 
@@ -44,42 +44,43 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
     val df = Seq(("Hello", 4)).toDF("a", "b")
       .where("dummyPythonUDF(b) and dummyPythonUDF(a) and a in (3, 4)")
     val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
-      case f @ FilterExec(And(_: AttributeReference, _: AttributeReference), _) => f
-      case b: BatchEvalPythonExec => b
-      case f @ FilterExec(_: In, _) => f
+      case f @ FilterExec(
+          And(_: AttributeReference, _: AttributeReference),
+          InputAdapter(_: BatchEvalPythonExec)) => f
+      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
     }
-    assert(qualifiedPlanNodes.size == 3)
+    assert(qualifiedPlanNodes.size == 2)
   }
 
   test("Nested Python UDF: push down deterministic FilterExec predicates") {
     val df = Seq(("Hello", 4)).toDF("a", "b")
       .where("dummyPythonUDF(a, dummyPythonUDF(a, b)) and a in (3, 4)")
     val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
-      case f @ FilterExec(_: AttributeReference, _) => f
-      case b: BatchEvalPythonExec => b
-      case f @ FilterExec(_: In, _) => f
+      case f @ FilterExec(_: AttributeReference, InputAdapter(_: BatchEvalPythonExec)) => f
+      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
     }
-    assert(qualifiedPlanNodes.size == 4)
+    assert(qualifiedPlanNodes.size == 2)
   }
 
   test("Python UDF: no push down on non-deterministic") {
     val df = Seq(("Hello", 4)).toDF("a", "b")
       .where("b > 4 and dummyPythonUDF(a) and rand() > 3")
     val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
-      case f: FilterExec => f
-      case b: BatchEvalPythonExec => b
+      case f @ FilterExec(
+          And(_: AttributeReference, _: GreaterThan),
+          InputAdapter(_: BatchEvalPythonExec)) => f
+      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
     }
-    assert(qualifiedPlanNodes.size == 3)
+    assert(qualifiedPlanNodes.size == 2)
   }
 
   test("Python UDF: no push down on predicates starting from the first non-deterministic") {
     val df = Seq(("Hello", 4)).toDF("a", "b")
       .where("dummyPythonUDF(a) and rand() > 3 and b > 4")
     val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
-      case f: FilterExec => f
-      case b: BatchEvalPythonExec => b
+      case f @ FilterExec(And(_: And, _: GreaterThan), InputAdapter(_: BatchEvalPythonExec)) => f
     }
-    assert(qualifiedPlanNodes.size == 2)
+    assert(qualifiedPlanNodes.size == 1)
   }
 
   test("Python UDF refers to the attributes from more than one child") {