diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 3112b306c365..64f49e2d0d4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -89,7 +89,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
 
   /** A sequence of rules that will be applied in order to the physical plan before execution. */
   protected def preparations: Seq[Rule[SparkPlan]] = Seq(
-    python.ExtractPythonUDFs,
     PlanSubqueries(sparkSession),
     EnsureRequirements(sparkSession.sessionState.conf),
     CollapseCodegenStages(sparkSession.sessionState.conf),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 969def762405..ab01b4f7e61b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -22,7 +22,7 @@
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
 import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaPruning
-import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
+import org.apache.spark.sql.execution.python.{ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
 
 class SparkOptimizer(
@@ -30,8 +30,9 @@ class SparkOptimizer(
   extends Optimizer(catalog) {
 
   override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
+    Batch("Extract Python UDF", Once,
+      Seq(ExtractPythonUDFFromAggregate, ExtractPythonUDFs): _*) :+
     Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
-    Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
     Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
     Batch("Parquet Schema Pruning", Once, ParquetSchemaPruning)) ++
     postHocOptimizationBatches :+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 75f5ec0e253d..f1f858e81105 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy
+import org.apache.spark.sql.execution.python.PythonEvals
 import org.apache.spark.sql.internal.SQLConf
 
 class SparkPlanner(
@@ -36,6 +37,7 @@ class SparkPlanner(
   override def strategies: Seq[Strategy] =
     experimentalMethods.extraStrategies ++
       extraPlanningStrategies ++ (
+      PythonEvals ::
       DataSourceV2Strategy ::
       FileSourceStrategy ::
       DataSourceStrategy(conf) ::
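The three changes above move Python UDF extraction out of the physical-plan preparations and into an optimizer batch, and register a new PythonEvals strategy with the planner so the resulting logical nodes can still be turned into physical operators. The sketch below shows the same planner hook through the public extraStrategies API, assuming only a plain local SparkSession; PrintingStrategy and StrategyDemo are illustrative names and are not part of this patch.

import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan

// Illustration only: a strategy that plans nothing itself. Returning Nil makes the
// planner fall through to the built-in strategies (PythonEvals, FileSourceStrategy, ...).
object PrintingStrategy extends Strategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
    println(s"planner visited: ${plan.nodeName}")
    Nil
  }
}

object StrategyDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("strategy-hook-demo").getOrCreate()
    // User strategies are consulted before the built-in list in SparkPlanner.strategies.
    spark.experimental.extraStrategies = Seq(PrintingStrategy)
    spark.range(5).selectExpr("id + 1 AS x").collect()
    spark.stop()
  }
}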
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 0bc21c0986e6..6a03f860f8f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -23,6 +23,7 @@ import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.arrow.ArrowUtils
 import org.apache.spark.sql.types.StructType
@@ -57,7 +58,13 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
 }
 
 /**
- * A physical plan that evaluates a [[PythonUDF]],
+ * A logical plan that evaluates a [[PythonUDF]].
+ */
+case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
+  extends UnaryNode
+
+/**
+ * A physical plan that evaluates a [[PythonUDF]].
  */
 case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
   extends EvalPythonExec(udfs, output, child) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index f4d83e8dc7c2..2054c700957e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -25,9 +25,16 @@ import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.types.{StructField, StructType}
 
+/**
+ * A logical plan that evaluates a [[PythonUDF]]
+ */
+case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
+  extends UnaryNode
+
 /**
  * A physical plan that evaluates a [[PythonUDF]]
  */
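Both files above add the same minimal kind of node: a logical UnaryNode that only records the extracted UDFs and the widened output, and is otherwise inert until the planner replaces it with the existing physical operator. A sketch of that shape for the Spark version this patch targets; TaggedNode is a hypothetical name used only for illustration.

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}

// Hypothetical marker node: it carries a label through analysis and optimization,
// forwards its child's output unchanged, and does nothing else. ArrowEvalPython and
// BatchEvalPython play the same structural role for extracted Python UDFs.
case class TaggedNode(label: String, child: LogicalPlan) extends UnaryNode {
  override def output: Seq[Attribute] = child.output
}

Such a node only becomes executable once a planning strategy maps it to a SparkPlan, which is what the PythonEvals strategy added at the end of this patch does for the two nodes above.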
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index cb75874be32e..07341769898b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -21,12 +21,12 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.api.python.PythonEvalType
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.{AnalysisException, Strategy}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
+import org.apache.spark.sql.execution.SparkPlan
 
 
 /**
@@ -93,7 +93,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
  * This has the limitation that the input to the Python UDF is not allowed include attributes from
  * multiple child operators.
  */
-object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
+object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
   private type EvalType = Int
   private type EvalTypeChecker = EvalType => Boolean
 
@@ -132,14 +132,14 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
     expressions.flatMap(collectEvaluableUDFs)
   }
 
-  def apply(plan: SparkPlan): SparkPlan = plan transformUp {
-    case plan: SparkPlan => extract(plan)
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case plan: LogicalPlan => extract(plan)
   }
 
   /**
    * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
    */
-  private def extract(plan: SparkPlan): SparkPlan = {
+  private def extract(plan: LogicalPlan): LogicalPlan = {
     val udfs = collectEvaluableUDFsFromExpressions(plan.expressions)
       // ignore the PythonUDF that come from second/third aggregate, which is not used
       .filter(udf => udf.references.subsetOf(plan.inputSet))
@@ -151,7 +151,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
     val prunedChildren = plan.children.map { child =>
       val allNeededOutput = inputsForPlan.intersect(child.outputSet).toSeq
       if (allNeededOutput.length != child.output.length) {
-        ProjectExec(allNeededOutput, child)
+        Project(allNeededOutput, child)
       } else {
         child
       }
@@ -180,9 +180,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
           _.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
         ) match {
           case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
-            ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
+            ArrowEvalPython(vectorizedUdfs, child.output ++ resultAttrs, child)
           case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
-            BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
+            BatchEvalPython(plainUdfs, child.output ++ resultAttrs, child)
           case _ =>
             throw new AnalysisException(
               "Expected either Scalar Pandas UDFs or Batched UDFs but got both")
@@ -209,7 +209,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
         val newPlan = extract(rewritten)
         if (newPlan.output != plan.output) {
           // Trim away the new UDF value if it was only used for filtering or something.
-          ProjectExec(plan.output, newPlan)
+          Project(plan.output, newPlan)
         } else {
           newPlan
         }
@@ -218,15 +218,15 @@
 
   // Split the original FilterExec to two FilterExecs. Only push down the first few predicates
   // that are all deterministic.
-  private def trySplitFilter(plan: SparkPlan): SparkPlan = {
+  private def trySplitFilter(plan: LogicalPlan): LogicalPlan = {
     plan match {
-      case filter: FilterExec =>
+      case filter: Filter =>
        val (candidates, nonDeterministic) =
          splitConjunctivePredicates(filter.condition).partition(_.deterministic)
        val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
        if (pushDown.nonEmpty) {
-          val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
-          FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild)
+          val newChild = Filter(pushDown.reduceLeft(And), filter.child)
+          Filter((rest ++ nonDeterministic).reduceLeft(And), newChild)
        } else {
          filter
        }
@@ -234,3 +234,16 @@
     }
   }
 }
+
+object PythonEvals extends Strategy {
+  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+    case ArrowEvalPython(udfs, output, child) =>
+      ArrowEvalPythonExec(udfs, output, planLater(child)) :: Nil
+
+    case BatchEvalPython(udfs, output, child) =>
+      BatchEvalPythonExec(udfs, output, planLater(child)) :: Nil
+
+    case _ =>
+      Nil
+  }
+}
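Since ExtractPythonUDFs is now a Rule[LogicalPlan], the same rule shape is available to user code through spark.experimental.extraOptimizations. The sketch below mirrors only the Filter-splitting idea from trySplitFilter above, splitting purely by determinism and without the Python-UDF check; SplitDeterministicFilter is a hypothetical name and is not part of this patch.

import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule

// Illustration only: keep the deterministic conjuncts in a child Filter so later rules
// can push them further down, and leave the remaining conjuncts on top. A Filter that is
// already all-deterministic or all-non-deterministic is left untouched, so the rewrite
// converges under a fixed-point batch.
object SplitDeterministicFilter extends Rule[LogicalPlan] with PredicateHelper {
  override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case filter @ Filter(condition, child) =>
      val (deterministic, rest) =
        splitConjunctivePredicates(condition).partition(_.deterministic)
      if (deterministic.nonEmpty && rest.nonEmpty) {
        Filter(rest.reduceLeft(And), Filter(deterministic.reduceLeft(And), child))
      } else {
        filter
      }
  }
}

// Usage sketch: spark.experimental.extraOptimizations ++= Seq(SplitDeterministicFilter)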