[SPARK-18766] [SQL] Push Down Filter Through BatchEvalPython (Python UDF) #16193
Changes from all commits: eaf740a, 2c3b917, 3d9ba67, 6586c90, b60f7bb, 04b0e9c, 2c8e593
ExtractPythonUDFs.scala:

```diff
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
@@ -90,7 +90,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
  * This has the limitation that the input to the Python UDF is not allowed include attributes from
  * multiple child operators.
  */
-object ExtractPythonUDFs extends Rule[SparkPlan] {
+object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {

   private def hasPythonUDF(e: Expression): Boolean = {
     e.find(_.isInstanceOf[PythonUDF]).isDefined
@@ -126,10 +126,11 @@ object ExtractPythonUDFs extends Rule[SparkPlan] {
       plan
     } else {
       val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+      val splitFilter = trySplitFilter(plan)
       // Rewrite the child that has the input required for the UDF
-      val newChildren = plan.children.map { child =>
+      val newChildren = splitFilter.children.map { child =>
         // Pick the UDF we are going to evaluate
-        val validUdfs = udfs.filter { case udf =>
+        val validUdfs = udfs.filter { udf =>
           // Check to make sure that the UDF can be evaluated with only the input of this child.
           udf.references.subsetOf(child.outputSet)
         }.toArray // Turn it into an array since iterators cannot be serialized in Scala 2.10
@@ -150,7 +151,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] {
          sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
      }

-      val rewritten = plan.withNewChildren(newChildren).transformExpressions {
+      val rewritten = splitFilter.withNewChildren(newChildren).transformExpressions {
        case p: PythonUDF if attributeMap.contains(p) =>
          attributeMap(p)
      }
@@ -165,4 +166,22 @@ object ExtractPythonUDFs extends Rule[SparkPlan] {
      }
    }
  }
+
+  // Split the original FilterExec to two FilterExecs. Only push down the first few predicates
+  // that are all deterministic.
+  private def trySplitFilter(plan: SparkPlan): SparkPlan = {
+    plan match {
+      case filter: FilterExec =>
+        val (candidates, containingNonDeterministic) =
+          splitConjunctivePredicates(filter.condition).span(_.deterministic)
+        val (pushDown, rest) = candidates.partition(!hasPythonUDF(_))
+        if (pushDown.nonEmpty) {
+          val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
+          FilterExec((rest ++ containingNonDeterministic).reduceLeft(And), newChild)
+        } else {
+          filter
+        }
+      case o => o
+    }
+  }
 }
```

Review thread on `val (pushDown, rest) = candidates.partition(!hasPythonUDF(_))`:

Contributor: nit:

Author (Member): This will change the semantics. Let me write a comment to explain.
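The span-versus-partition distinction is the semantics point in the thread above, and it is worth making concrete. Below is a minimal, dependency-free Scala sketch of the splitting logic in `trySplitFilter`: the `Predicate` case class, the `split` helper, and the sample conjuncts are illustrative stand-ins for Catalyst's `Expression` machinery, not Spark API.

```scala
// A minimal sketch of trySplitFilter's predicate splitting, with no Spark
// dependency. `Predicate` is a stand-in for Catalyst's Expression.
object SplitFilterSketch {
  final case class Predicate(sql: String, deterministic: Boolean, hasPythonUDF: Boolean)

  // Mirrors: splitConjunctivePredicates(...).span(_.deterministic),
  // then candidates.partition(!hasPythonUDF(_)).
  def split(conjuncts: Seq[Predicate]): (Seq[Predicate], Seq[Predicate]) = {
    // span (not partition!) stops at the FIRST non-deterministic predicate,
    // so no predicate is reordered across it. partition would also hoist
    // "b > 4" past "rand() > 3", changing how often rand() is evaluated.
    val (candidates, containingNonDeterministic) = conjuncts.span(_.deterministic)
    // Within the deterministic prefix, only UDF-free predicates can be
    // pushed below the Python evaluation.
    val (pushDown, rest) = candidates.partition(p => !p.hasPythonUDF)
    (pushDown, rest ++ containingNonDeterministic)
  }

  def main(args: Array[String]): Unit = {
    val conjuncts = Seq(
      Predicate("a IN (3, 4)", deterministic = true, hasPythonUDF = false),
      Predicate("dummyPythonUDF(b)", deterministic = true, hasPythonUDF = true),
      Predicate("rand() > 3", deterministic = false, hasPythonUDF = false),
      Predicate("b > 4", deterministic = true, hasPythonUDF = false))

    val (pushDown, kept) = split(conjuncts)
    println(s"pushed below the UDF: ${pushDown.map(_.sql)}") // List(a IN (3, 4))
    println(s"kept above it:        ${kept.map(_.sql)}")     // dummyPythonUDF(b), rand() > 3, b > 4
  }
}
```

Running it pushes only `a IN (3, 4)` below the Python evaluation; `dummyPythonUDF(b)` cannot move because it contains the UDF, and `b > 4` cannot move either, since hoisting it past `rand() > 3` would change how often the non-deterministic predicate runs.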
BatchEvalPythonExecSuite.scala (new file):

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.api.python.PythonFunction
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In}
import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.BooleanType

class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.newProductEncoder
  import testImplicits.localSeqToDatasetHolder

  override def beforeAll(): Unit = {
    super.beforeAll()
    spark.udf.registerPython("dummyPythonUDF", new MyDummyPythonUDF)
  }

  override def afterAll(): Unit = {
    spark.sessionState.functionRegistry.dropFunction("dummyPythonUDF")
    super.afterAll()
  }

  test("Python UDF: push down deterministic FilterExec predicates") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(b) and dummyPythonUDF(a) and a in (3, 4)")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: AttributeReference),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  test("Nested Python UDF: push down deterministic FilterExec predicates") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(a, dummyPythonUDF(a, b)) and a in (3, 4)")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(_: AttributeReference, InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  test("Python UDF: no push down on non-deterministic") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("b > 4 and dummyPythonUDF(a) and rand() > 3")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: GreaterThan),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  test("Python UDF: no push down on predicates starting from the first non-deterministic") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(a) and rand() > 3 and b > 4")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(And(_: And, _: GreaterThan), InputAdapter(_: BatchEvalPythonExec)) => f
    }
    assert(qualifiedPlanNodes.size == 1)
  }

  test("Python UDF refers to the attributes from more than one child") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
    val df2 = Seq(("Hello", 4)).toDF("c", "d")
    val joinDF = df.join(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")

    val e = intercept[RuntimeException] {
      joinDF.queryExecution.executedPlan
    }.getMessage
    assert(Seq("Invalid PythonUDF dummyUDF", "requires attributes from more than one child")
      .forall(e.contains))
  }
}

// This Python UDF is dummy and just for testing. Unable to execute.
class DummyUDF extends PythonFunction(
  command = Array[Byte](),
  envVars = Map("" -> "").asJava,
  pythonIncludes = ArrayBuffer("").asJava,
  pythonExec = "",
  pythonVer = "",
  broadcastVars = null,
  accumulator = null)

class MyDummyPythonUDF
  extends UserDefinedPythonFunction(name = "dummyUDF", func = new DummyUDF, dataType = BooleanType)
```

Author note on the test "Python UDF refers to the attributes from more than one child": This test case is not directly related to this PR. In the future, we need to add more unit test cases on the Scala side for verifying …
Review thread on the same test:

Reviewer: does this test fail before this PR?

Author: Nope. This case works well.
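For orientation, here is a hedged usage sketch (not part of the PR) showing where a pushed-down predicate ends up in the physical plan. It assumes a SparkSession named `spark` with `dummyPythonUDF` registered exactly as in the suite above.

```scala
// Assumes `spark` has "dummyPythonUDF" registered as in the suite above;
// this only illustrates where the pushed-down predicate lands.
import spark.implicits._

val df = Seq(("Hello", 4)).toDF("a", "b")
  .where("dummyPythonUDF(b) and a in (3, 4)")

// "a in (3, 4)" is deterministic and UDF-free, so after trySplitFilter it
// should appear in a FilterExec *below* BatchEvalPythonExec, i.e. it is
// evaluated before rows are shipped to the Python worker, while
// "dummyPythonUDF(b)" stays in the FilterExec above it.
df.explain()
```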