Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {

/** A sequence of rules that will be applied in order to the physical plan before execution. */
protected def preparations: Seq[Rule[SparkPlan]] = Seq(
python.ExtractPythonUDFs,
PlanSubqueries(sparkSession),
EnsureRequirements(sparkSession.sessionState.conf),
CollapseCodegenStages(sparkSession.sessionState.conf),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,17 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaPruning
import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
import org.apache.spark.sql.execution.python.{ExtractPythonUDFFromAggregate, ExtractPythonUDFs}

class SparkOptimizer(
catalog: SessionCatalog,
experimentalMethods: ExperimentalMethods)
extends Optimizer(catalog) {

override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
Batch("Extract Python UDF", Once,
Seq(ExtractPythonUDFFromAggregate, ExtractPythonUDFs)) :+
Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
Batch("Parquet Schema Pruning", Once, ParquetSchemaPruning)) ++
postHocOptimizationBatches :+
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy
import org.apache.spark.sql.execution.python.PythonEvals
import org.apache.spark.sql.internal.SQLConf

class SparkPlanner(
Expand All @@ -36,6 +37,7 @@ class SparkPlanner(
override def strategies: Seq[Strategy] =
experimentalMethods.extraStrategies ++
extraPlanningStrategies ++ (
PythonEvals ::
DataSourceV2Strategy ::
FileSourceStrategy ::
DataSourceStrategy(conf) ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.types.StructType
Expand Down Expand Up @@ -57,7 +58,13 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
}

/**
* A physical plan that evaluates a [[PythonUDF]],
* A logical plan that evaluates a [[PythonUDF]].
*/
case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
extends UnaryNode

/**
* A physical plan that evaluates a [[PythonUDF]].
*/
case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
extends EvalPythonExec(udfs, output, child) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,16 @@ import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructField, StructType}

/**
* A logical plan that evaluates a [[PythonUDF]]
*/
case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
extends UnaryNode

/**
* A physical plan that evaluates a [[PythonUDF]]
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.{AnalysisException, Strategy}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.SparkPlan


/**
Expand Down Expand Up @@ -93,7 +93,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
* This has the limitation that the input to the Python UDF is not allowed include attributes from
* multiple child operators.
*/
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {

private type EvalType = Int
private type EvalTypeChecker = EvalType => Boolean
Expand Down Expand Up @@ -132,14 +132,14 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
expressions.flatMap(collectEvaluableUDFs)
}

def apply(plan: SparkPlan): SparkPlan = plan transformUp {
case plan: SparkPlan => extract(plan)
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case plan: LogicalPlan => extract(plan)
}

/**
* Extract all the PythonUDFs from the current operator and evaluate them before the operator.
*/
private def extract(plan: SparkPlan): SparkPlan = {
private def extract(plan: LogicalPlan): LogicalPlan = {
val udfs = collectEvaluableUDFsFromExpressions(plan.expressions)
// ignore the PythonUDF that come from second/third aggregate, which is not used
.filter(udf => udf.references.subsetOf(plan.inputSet))
Expand All @@ -151,7 +151,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
val prunedChildren = plan.children.map { child =>
val allNeededOutput = inputsForPlan.intersect(child.outputSet).toSeq
if (allNeededOutput.length != child.output.length) {
ProjectExec(allNeededOutput, child)
Project(allNeededOutput, child)
} else {
child
}
Expand Down Expand Up @@ -180,9 +180,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
_.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
) match {
case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
ArrowEvalPython(vectorizedUdfs, child.output ++ resultAttrs, child)
case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
BatchEvalPython(plainUdfs, child.output ++ resultAttrs, child)
case _ =>
throw new AnalysisException(
"Expected either Scalar Pandas UDFs or Batched UDFs but got both")
Expand All @@ -209,7 +209,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
val newPlan = extract(rewritten)
if (newPlan.output != plan.output) {
// Trim away the new UDF value if it was only used for filtering or something.
ProjectExec(plan.output, newPlan)
Project(plan.output, newPlan)
} else {
newPlan
}
Expand All @@ -218,19 +218,29 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {

// Split the original FilterExec to two FilterExecs. Only push down the first few predicates
// that are all deterministic.
private def trySplitFilter(plan: SparkPlan): SparkPlan = {
private def trySplitFilter(plan: LogicalPlan): LogicalPlan = {
plan match {
case filter: FilterExec =>
case filter: Filter =>
val (candidates, nonDeterministic) =
splitConjunctivePredicates(filter.condition).partition(_.deterministic)
val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
if (pushDown.nonEmpty) {
val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild)
val newChild = Filter(pushDown.reduceLeft(And), filter.child)
Filter((rest ++ nonDeterministic).reduceLeft(And), newChild)
} else {
filter
}
case o => o
}
}
}

object PythonEvals extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ArrowEvalPython(udfs, output, child) =>
ArrowEvalPythonExec(udfs, output, planLater(child)) :: Nil

case BatchEvalPython(udfs, output, child) =>
BatchEvalPythonExec(udfs, output, planLater(child)) :: Nil
}
}