[SPARK-27803][SQL][PYTHON] Fix column pruning for Python UDF #24675
```diff
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF}
 
 /**
  * FlatMap groups using an udf: pandas.Dataframe -> pandas.DataFrame.
@@ -38,3 +38,30 @@ case class FlatMapGroupsInPandas(
   */
   override val producedAttributes = AttributeSet(output)
 }
+
+trait BaseEvalPython extends UnaryNode {
```
A review thread on `trait BaseEvalPython extends UnaryNode {`:

**Member:** Is …

**Contributor (Author):** This is a problem I want to address later. I think … More specifically, if …
```diff
+
+  def udfs: Seq[PythonUDF]
+
+  def resultAttrs: Seq[Attribute]
+
+  override def output: Seq[Attribute] = child.output ++ resultAttrs
+
+  override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references))
+}
+
+/**
+ * A logical plan that evaluates a [[PythonUDF]]
+ */
+case class BatchEvalPython(
+    udfs: Seq[PythonUDF],
+    resultAttrs: Seq[Attribute],
+    child: LogicalPlan) extends BaseEvalPython
+
+/**
+ * A logical plan that evaluates a [[PythonUDF]] with Apache Arrow.
+ */
+case class ArrowEvalPython(
+    udfs: Seq[PythonUDF],
+    resultAttrs: Seq[Attribute],
+    child: LogicalPlan) extends BaseEvalPython
```
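The heart of the fix is visible in the trait above: `output` appends the UDF result attributes to the child's output, while `references` is narrowed to only the columns the UDFs actually read. Because the node no longer claims to reference every child column, the generic `ColumnPruning` rule can insert a pruning `Project` beneath it. A minimal sketch of that interaction, using stand-in types rather than Spark's Catalyst classes (`Attr`, `Scan`, `EvalPython`, and `PruningDemo` are all illustrative):

```scala
// Toy model of the pruning contract; none of these types are Spark's own.
case class Attr(name: String)

trait Node {
  def output: Seq[Attr]
  def references: Set[Attr]
}

case class Scan(output: Seq[Attr]) extends Node {
  val references: Set[Attr] = Set.empty
}

case class EvalPython(udfInputs: Set[Attr], resultAttrs: Seq[Attr], child: Node) extends Node {
  // Mirrors BaseEvalPython.output: pass the child's columns through, append UDF results.
  def output: Seq[Attr] = child.output ++ resultAttrs
  // Mirrors the overridden references: only the columns the UDFs actually read.
  def references: Set[Attr] = udfInputs
}

object PruningDemo extends App {
  val scan = Scan(Seq(Attr("a"), Attr("b"), Attr("c")))
  val eval = EvalPython(udfInputs = Set(Attr("a")), resultAttrs = Seq(Attr("udf_out")), scan)

  // A generic pruning pass keeps a child column only if the node itself or
  // an ancestor needs it; here the ancestor only selects the UDF result.
  val ancestorNeeds = Set(Attr("udf_out"))
  val kept = eval.child.output.filter(a => eval.references(a) || ancestorNeeds(a))
  println(kept) // List(Attr(a)) -- columns b and c can be pruned from the scan
}
```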
```diff
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.ExperimentalMethods
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
-import org.apache.spark.sql.catalyst.optimizer.Optimizer
+import org.apache.spark.sql.catalyst.optimizer.{ColumnPruning, Optimizer, PushDownPredicate, RemoveNoopOperators}
 import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
 import org.apache.spark.sql.execution.datasources.SchemaPruning
 import org.apache.spark.sql.execution.python.{ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
@@ -32,14 +32,21 @@ class SparkOptimizer(
   override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
     Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
     Batch("Extract Python UDFs", Once,
-      Seq(ExtractPythonUDFFromAggregate, ExtractPythonUDFs): _*) :+
+      ExtractPythonUDFFromAggregate,
+      ExtractPythonUDFs,
+      // The eval-python node may be between Project/Filter and the scan node, which breaks
+      // column pruning and filter push-down. Here we rerun the related optimizer rules.
+      ColumnPruning,
+      PushDownPredicate,
+      RemoveNoopOperators) :+
     Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
     Batch("Schema Pruning", Once, SchemaPruning)) ++
     postHocOptimizationBatches :+
     Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
 
-  override def nonExcludableRules: Seq[String] =
-    super.nonExcludableRules :+ ExtractPythonUDFFromAggregate.ruleName
+  override def nonExcludableRules: Seq[String] = super.nonExcludableRules :+
+    ExtractPythonUDFFromAggregate.ruleName :+
+    ExtractPythonUDFs.ruleName
 
   /**
    * Optimization batches that are executed before the regular optimization batches (also before
```
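Note the ordering inside the rewritten batch: `ColumnPruning`, `PushDownPredicate`, and `RemoveNoopOperators` run after the two extraction rules, because the eval-python nodes they need to optimize around only come into existence during this batch, after the main optimizer batches have already finished. A rough sketch of the "Once" batch semantics this relies on, with toy stand-ins for Spark's `RuleExecutor` machinery:

```scala
// Toy stand-ins: Spark's real RuleExecutor also supports fixed-point batches,
// per-rule metrics, and excludable rules. "Once" just means one ordered pass.
trait Rule[P] {
  def ruleName: String = getClass.getSimpleName.stripSuffix("$")
  def apply(plan: P): P
}

case class OnceBatch[P](name: String, rules: Rule[P]*) {
  // Each rule runs a single time, in declaration order, so a later rule
  // (e.g. column pruning) can clean up after an earlier one (UDF extraction).
  def execute(plan: P): P = rules.foldLeft(plan)((p, rule) => rule(p))
}
```

Listing `ExtractPythonUDFs.ruleName` in `nonExcludableRules` then keeps users from disabling an extraction rule the planner depends on for correctness.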
```diff
@@ -25,25 +25,14 @@ import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.types.{StructField, StructType}
 
-/**
- * A logical plan that evaluates a [[PythonUDF]]
- */
-case class BatchEvalPython(
-    udfs: Seq[PythonUDF],
-    output: Seq[Attribute],
-    child: LogicalPlan) extends UnaryNode {
-  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
-}
-
 /**
  * A physical plan that evaluates a [[PythonUDF]]
  */
-case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
-  extends EvalPythonExec(udfs, output, child) {
+case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan)
+  extends EvalPythonExec(udfs, resultAttrs, child) {
 
   protected override def evaluate(
       funcs: Seq[ChainedPythonFunctions],
```
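The physical operator now receives `resultAttrs` (just the UDF output attributes) instead of the full `output`. The two encodings carry the same information: the deleted logical node recovered its produced attributes as `output.drop(child.output.length)`, which is exactly `resultAttrs` whenever `output` is built as `child.output ++ resultAttrs`, as `BaseEvalPython.output` now does. A quick plain-Scala check with toy strings standing in for Catalyst `Attribute`s:

```scala
object DropVsResultAttrs extends App {
  val childOutput = Seq("a", "b", "c")
  val resultAttrs = Seq("udf_out")

  // Old style: the node stored the concatenated output and dropped the prefix.
  val output = childOutput ++ resultAttrs
  assert(output.drop(childOutput.length) == resultAttrs)

  println("output.drop(child.output.length) == resultAttrs")
}
```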
```diff
@@ -158,21 +158,9 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
       // If there aren't any, we are done.
       plan
     } else {
-      val inputsForPlan = plan.references ++ plan.outputSet
-      val prunedChildren = plan.children.map { child =>
-        val allNeededOutput = inputsForPlan.intersect(child.outputSet).toSeq
-        if (allNeededOutput.length != child.output.length) {
-          Project(allNeededOutput, child)
-        } else {
-          child
-        }
-      }
-      val planWithNewChildren = plan.withNewChildren(prunedChildren)
-
       val attributeMap = mutable.HashMap[PythonUDF, Expression]()
-      val splitFilter = trySplitFilter(planWithNewChildren)
       // Rewrite the child that has the input required for the UDF
-      val newChildren = splitFilter.children.map { child =>
+      val newChildren = plan.children.map { child =>
         // Pick the UDF we are going to evaluate
         val validUdfs = udfs.filter { udf =>
           // Check to make sure that the UDF can be evaluated with only the input of this child.
@@ -191,9 +179,9 @@
             _.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
           ) match {
             case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
-              ArrowEvalPython(vectorizedUdfs, child.output ++ resultAttrs, child)
+              ArrowEvalPython(vectorizedUdfs, resultAttrs, child)
             case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
-              BatchEvalPython(plainUdfs, child.output ++ resultAttrs, child)
+              BatchEvalPython(plainUdfs, resultAttrs, child)
             case _ =>
               throw new AnalysisException(
                 "Expected either Scalar Pandas UDFs or Batched UDFs but got both")
@@ -211,7 +199,7 @@
           sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
       }
 
-      val rewritten = splitFilter.withNewChildren(newChildren).transformExpressions {
+      val rewritten = plan.withNewChildren(newChildren).transformExpressions {
         case p: PythonUDF if attributeMap.contains(p) =>
          attributeMap(p)
       }
@@ -226,22 +214,4 @@
       }
     }
   }
-
-  // Split the original FilterExec to two FilterExecs. Only push down the first few predicates
-  // that are all deterministic.
-  private def trySplitFilter(plan: LogicalPlan): LogicalPlan = {
```
A review thread on the removed `trySplitFilter`:

**Contributor:** Can you explain a little why this is no longer needed?

**Contributor (Author):** Quote from the PR description …
```diff
-    plan match {
-      case filter: Filter =>
-        val (candidates, nonDeterministic) =
-          splitConjunctivePredicates(filter.condition).partition(_.deterministic)
-        val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
-        if (pushDown.nonEmpty) {
-          val newChild = Filter(pushDown.reduceLeft(And), filter.child)
-          Filter((rest ++ nonDeterministic).reduceLeft(And), newChild)
-        } else {
-          filter
-        }
-      case o => o
-    }
-  }
 }
```
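For reference, the splitting idea that `trySplitFilter` hard-coded, and that the generic `PushDownPredicate` rule now covers, is: of the filter's conjuncts, only the deterministic ones that do not call a scalar Python UDF may move below the eval node; everything else stays above it. A toy sketch with stand-in predicate objects instead of Catalyst expressions:

```scala
object SplitFilterSketch extends App {
  // Stand-in predicate: just the two properties the split cares about.
  case class Pred(sql: String, deterministic: Boolean, hasScalarPythonUDF: Boolean)

  // Mirrors the removed logic: deterministic, UDF-free conjuncts are pushable;
  // the rest (including every non-deterministic conjunct) must stay put.
  def splitConjuncts(conjuncts: Seq[Pred]): (Seq[Pred], Seq[Pred]) = {
    val (candidates, nonDeterministic) = conjuncts.partition(_.deterministic)
    val (pushDown, rest) = candidates.partition(!_.hasScalarPythonUDF)
    (pushDown, rest ++ nonDeterministic)
  }

  val (push, keep) = splitConjuncts(Seq(
    Pred("a > 1", deterministic = true, hasScalarPythonUDF = false),
    Pred("py_udf(b) = 0", deterministic = true, hasScalarPythonUDF = true),
    Pred("rand() < 0.5", deterministic = false, hasScalarPythonUDF = false)))

  println(push.map(_.sql)) // List(a > 1)
  println(keep.map(_.sql)) // List(py_udf(b) = 0, rand() < 0.5)
}
```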
A review thread on the diff:

**Comment:** For my benefit, would you mind explaining what `canPushThrough` defines? Are these nodes that a projection and/or filter can be pushed through?

**Reply:** This defines the nodes that we can push filters through.
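Such a guard is typically just a whitelist over operator types: a filter may slide below a node only if that node neither drops, duplicates, nor rewrites the rows the predicate inspects. A sketch of that shape with toy node types (the real `PushDownPredicate.canPushThrough` matches Catalyst plan nodes; per this PR, the eval-python nodes qualify because they only append UDF result columns):

```scala
object CanPushThroughSketch {
  // Stand-in plan nodes, not Spark's Catalyst classes.
  sealed trait PlanNode
  case class Sort(child: PlanNode) extends PlanNode
  case class Repartition(child: PlanNode) extends PlanNode
  case class EvalPython(child: PlanNode) extends PlanNode
  case class Aggregate(child: PlanNode) extends PlanNode

  // Row-preserving nodes are safe to push a filter below; a node like
  // Aggregate changes row granularity and needs its own, more careful rule.
  def canPushThrough(node: PlanNode): Boolean = node match {
    case _: Sort        => true
    case _: Repartition => true
    case _: EvalPython  => true // only appends UDF result columns
    case _              => false
  }
}
```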