diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fee71723ed75..1eb39f8b660b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -292,6 +292,12 @@ object SQLConf { .bytesConf(ByteUnit.BYTE) .createWithDefault(64 * 1024 * 1024) + val RUNTIME_REOPTIMIZATION_ENABLED = + buildConf("spark.sql.runtime.reoptimization.enabled") + .doc("When true, enable runtime query re-optimization.") + .booleanConf + .createWithDefault(false) + val ADAPTIVE_EXECUTION_ENABLED = buildConf("spark.sql.adaptive.enabled") .doc("When true, enable adaptive query execution.") .booleanConf @@ -1889,7 +1895,10 @@ class SQLConf extends Serializable with Logging { def targetPostShuffleInputSize: Long = getConf(SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) - def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) + def runtimeReoptimizationEnabled: Boolean = getConf(RUNTIME_REOPTIMIZATION_ENABLED) + + def adaptiveExecutionEnabled: Boolean = + getConf(ADAPTIVE_EXECUTION_ENABLED) && !getConf(RUNTIME_REOPTIMIZATION_ENABLED) def minNumPostShufflePartitions: Int = getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 5d2710bdc4e4..6f0b489af278 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.StringUtils.{PlanStringConcat, StringConcat} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.execution.adaptive.InsertAdaptiveSparkPlan import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -74,9 +75,15 @@ class QueryExecution( lazy val sparkPlan: SparkPlan = tracker.measurePhase(QueryPlanningTracker.PLANNING) { SparkSession.setActiveSession(sparkSession) + // Runtime re-optimization requires a unique instance of every node in the logical plan. + val logicalPlan = if (sparkSession.sessionState.conf.runtimeReoptimizationEnabled) { + optimizedPlan.clone() + } else { + optimizedPlan + } // TODO: We use next(), i.e. take the first plan returned by the planner, here for now, // but we will implement to choose the best plan. - planner.plan(ReturnAnswer(optimizedPlan)).next() + planner.plan(ReturnAnswer(logicalPlan)).next() } // executedPlan should not be used to initialize any SparkPlan. It should be @@ -107,6 +114,9 @@ class QueryExecution( /** A sequence of rules that will be applied in order to the physical plan before execution. */ protected def preparations: Seq[Rule[SparkPlan]] = Seq( + // `AdaptiveSparkPlanExec` is a leaf node. If inserted, all the following rules will be no-op + // as the original plan is hidden behind `AdaptiveSparkPlanExec`. 
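    // Illustrative sketch (assumption, not part of this patch): with
    // spark.sql.runtime.reoptimization.enabled=true the prepared plan collapses to a single
    // AdaptiveSparkPlanExec leaf, roughly:
    //   val prepared = preparations.foldLeft(sparkPlan) { case (p, rule) => rule.apply(p) }
    //   prepared.isInstanceOf[adaptive.AdaptiveSparkPlanExec]  // true when the plan is supported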
+ InsertAdaptiveSparkPlan(sparkSession), PlanSubqueries(sparkSession), EnsureRequirements(sparkSession.sessionState.conf), CollapseCodegenStages(sparkSession.sessionState.conf), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index ddcf61b882d3..fbe8e5055a25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -40,9 +40,11 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.DataType object SparkPlan { - // a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag - // when converting a logical plan to a physical plan. + /** The original [[LogicalPlan]] from which this [[SparkPlan]] is converted. */ val LOGICAL_PLAN_TAG = TreeNodeTag[LogicalPlan]("logical_plan") + + /** The [[LogicalPlan]] inherited from its ancestor. */ + val LOGICAL_PLAN_INHERITED_TAG = TreeNodeTag[LogicalPlan]("logical_plan_inherited") } /** @@ -79,6 +81,35 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ super.makeCopy(newArgs) } + /** + * @return The logical plan this plan is linked to. + */ + def logicalLink: Option[LogicalPlan] = + getTagValue(SparkPlan.LOGICAL_PLAN_TAG) + .orElse(getTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)) + + /** + * Set logical plan link recursively if unset. + */ + def setLogicalLink(logicalPlan: LogicalPlan): Unit = { + setLogicalLink(logicalPlan, false) + } + + private def setLogicalLink(logicalPlan: LogicalPlan, inherited: Boolean = false): Unit = { + // Stop at a descendant which is the root of a sub-tree transformed from another logical node. + if (inherited && getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isDefined) { + return + } + + val tag = if (inherited) { + SparkPlan.LOGICAL_PLAN_INHERITED_TAG + } else { + SparkPlan.LOGICAL_PLAN_TAG + } + setTagValue(tag, logicalPlan) + children.foreach(_.setLogicalLink(logicalPlan, true)) + } + /** * @return All metrics containing metrics of this SparkPlan. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 3cd02b984d33..8c7752c4bb74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.SQLMetricInfo import org.apache.spark.sql.internal.SQLConf @@ -53,6 +54,8 @@ private[execution] object SparkPlanInfo { val children = plan match { case ReusedExchangeExec(_, child) => child :: Nil case ReusedSubqueryExec(child) => child :: Nil + case a: AdaptiveSparkPlanExec => a.executedPlan :: Nil + case stage: QueryStageExec => stage.plan :: Nil case _ => plan.children ++ plan.subqueries } val metrics = plan.metrics.toSeq.map { case (key, metric) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 2a4a1c8ef343..dc7fb7741e7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.adaptive.LogicalQueryStageStrategy import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy import org.apache.spark.sql.internal.SQLConf @@ -36,6 +37,7 @@ class SparkPlanner( override def strategies: Seq[Strategy] = experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ ( + LogicalQueryStageStrategy :: PythonEvals :: DataSourceV2Strategy :: FileSourceStrategy :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c4031496f610..faf2fdd7dbcc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.execution.adaptive.LogicalQueryStage import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -58,6 +59,8 @@ case class PlanLater(plan: LogicalPlan) extends LeafExecNode { protected override def doExecute(): RDD[InternalRow] = { throw new UnsupportedOperationException() } + + override def setLogicalLink(logicalPlan: LogicalPlan): Unit = {} } abstract class SparkStrategies extends QueryPlanner[SparkPlan] { @@ -69,7 +72,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ReturnAnswer(rootPlan) => rootPlan case _ => plan 
} - p.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, logicalPlan) + p.setLogicalLink(logicalPlan) p } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 99dcca8b3310..92e80dcf90e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import java.io.Writer import java.util.Locale +import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable import scala.util.control.NonFatal @@ -551,56 +552,6 @@ object WholeStageCodegenExec { } } -object WholeStageCodegenId { - // codegenStageId: ID for codegen stages within a query plan. - // It does not affect equality, nor does it participate in destructuring pattern matching - // of WholeStageCodegenExec. - // - // This ID is used to help differentiate between codegen stages. It is included as a part - // of the explain output for physical plans, e.g. - // - // == Physical Plan == - // *(5) SortMergeJoin [x#3L], [y#9L], Inner - // :- *(2) Sort [x#3L ASC NULLS FIRST], false, 0 - // : +- Exchange hashpartitioning(x#3L, 200) - // : +- *(1) Project [(id#0L % 2) AS x#3L] - // : +- *(1) Filter isnotnull((id#0L % 2)) - // : +- *(1) Range (0, 5, step=1, splits=8) - // +- *(4) Sort [y#9L ASC NULLS FIRST], false, 0 - // +- Exchange hashpartitioning(y#9L, 200) - // +- *(3) Project [(id#6L % 2) AS y#9L] - // +- *(3) Filter isnotnull((id#6L % 2)) - // +- *(3) Range (0, 5, step=1, splits=8) - // - // where the ID makes it obvious that not all adjacent codegen'd plan operators are of the - // same codegen stage. - // - // The codegen stage ID is also optionally included in the name of the generated classes as - // a suffix, so that it's easier to associate a generated class back to the physical operator. - // This is controlled by SQLConf: spark.sql.codegen.useIdInClassName - // - // The ID is also included in various log messages. - // - // Within a query, a codegen stage in a plan starts counting from 1, in "insertion order". - // WholeStageCodegenExec operators are inserted into a plan in depth-first post-order. - // See CollapseCodegenStages.insertWholeStageCodegen for the definition of insertion order. - // - // 0 is reserved as a special ID value to indicate a temporary WholeStageCodegenExec object - // is created, e.g. for special fallback handling when an existing WholeStageCodegenExec - // failed to generate/compile code. - - private val codegenStageCounter: ThreadLocal[Integer] = ThreadLocal.withInitial(() => 1) - - def resetPerQuery(): Unit = codegenStageCounter.set(1) - - def getNextStageId(): Int = { - val counter = codegenStageCounter - val id = counter.get() - counter.set(id + 1) - id - } -} - /** * WholeStageCodegen compiles a subtree of plans that support codegen together into single Java * function. @@ -824,8 +775,48 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) /** * Find the chained plans that support codegen, collapse them together as WholeStageCodegen. + * + * The `codegenStageCounter` generates ID for codegen stages within a query plan. + * It does not affect equality, nor does it participate in destructuring pattern matching + * of WholeStageCodegenExec. + * + * This ID is used to help differentiate between codegen stages. 
It is included as a part + * of the explain output for physical plans, e.g. + * + * == Physical Plan == + * *(5) SortMergeJoin [x#3L], [y#9L], Inner + * :- *(2) Sort [x#3L ASC NULLS FIRST], false, 0 + * : +- Exchange hashpartitioning(x#3L, 200) + * : +- *(1) Project [(id#0L % 2) AS x#3L] + * : +- *(1) Filter isnotnull((id#0L % 2)) + * : +- *(1) Range (0, 5, step=1, splits=8) + * +- *(4) Sort [y#9L ASC NULLS FIRST], false, 0 + * +- Exchange hashpartitioning(y#9L, 200) + * +- *(3) Project [(id#6L % 2) AS y#9L] + * +- *(3) Filter isnotnull((id#6L % 2)) + * +- *(3) Range (0, 5, step=1, splits=8) + * + * where the ID makes it obvious that not all adjacent codegen'd plan operators are of the + * same codegen stage. + * + * The codegen stage ID is also optionally included in the name of the generated classes as + * a suffix, so that it's easier to associate a generated class back to the physical operator. + * This is controlled by SQLConf: spark.sql.codegen.useIdInClassName + * + * The ID is also included in various log messages. + * + * Within a query, a codegen stage in a plan starts counting from 1, in "insertion order". + * WholeStageCodegenExec operators are inserted into a plan in depth-first post-order. + * See CollapseCodegenStages.insertWholeStageCodegen for the definition of insertion order. + * + * 0 is reserved as a special ID value to indicate a temporary WholeStageCodegenExec object + * is created, e.g. for special fallback handling when an existing WholeStageCodegenExec + * failed to generate/compile code. */ -case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { +case class CollapseCodegenStages( + conf: SQLConf, + codegenStageCounter: AtomicInteger = new AtomicInteger(0)) + extends Rule[SparkPlan] { private def supportCodegen(e: Expression): Boolean = e match { case e: LeafExpression => true @@ -869,14 +860,13 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] => plan.withNewChildren(plan.children.map(insertWholeStageCodegen)) case plan: CodegenSupport if supportCodegen(plan) => - WholeStageCodegenExec(insertInputAdapter(plan))(WholeStageCodegenId.getNextStageId()) + WholeStageCodegenExec(insertInputAdapter(plan))(codegenStageCounter.incrementAndGet()) case other => other.withNewChildren(other.children.map(insertWholeStageCodegen)) } def apply(plan: SparkPlan): SparkPlan = { if (conf.wholeStageEnabled) { - WholeStageCodegenId.resetPerQuery() insertWholeStageCodegen(plan) } else { plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala new file mode 100644 index 000000000000..606fbd8c6385 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -0,0 +1,425 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import java.util +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.JavaConverters._ +import scala.collection.concurrent.TrieMap +import scala.collection.mutable +import scala.concurrent.ExecutionContext +import scala.util.control.NonFatal + +import org.apache.spark.SparkException +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._ +import org.apache.spark.sql.execution.exchange._ +import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.ThreadUtils + +/** + * A root node to execute the query plan adaptively. It splits the query plan into independent + * stages and executes them in order according to their dependencies. The query stage + * materializes its output at the end. When one stage completes, the data statistics of the + * materialized output will be used to optimize the remainder of the query. + * + * To create query stages, we traverse the query tree bottom up. When we hit an exchange node, + * and if all the child query stages of this exchange node are materialized, we create a new + * query stage for this exchange node. The new stage is then materialized asynchronously once it + * is created. + * + * When one query stage finishes materialization, the rest query is re-optimized and planned based + * on the latest statistics provided by all materialized stages. Then we traverse the query plan + * again and create more stages if possible. After all stages have been materialized, we execute + * the rest of the plan. + */ +case class AdaptiveSparkPlanExec( + initialPlan: SparkPlan, + @transient session: SparkSession, + @transient subqueryMap: Map[Long, ExecSubqueryExpression], + @transient stageCache: TrieMap[SparkPlan, QueryStageExec]) + extends LeafExecNode { + + @transient private val lock = new Object() + + // The logical plan optimizer for re-optimizing the current logical plan. + @transient private val optimizer = new RuleExecutor[LogicalPlan] { + // TODO add more optimization rules + override protected def batches: Seq[Batch] = Seq() + } + + // A list of physical plan rules to be applied before creation of query stages. The physical + // plan should reach a final status of query stages (i.e., no more addition or removal of + // Exchange nodes) after running these rules. + @transient private val queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq( + PlanAdaptiveSubqueries(subqueryMap), + EnsureRequirements(conf) + ) + + // A list of physical optimizer rules to be applied to a new stage before its execution. These + // optimizations should be stage-independent. 
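  // For example, the CollapseCodegenStages instance below is created once per
  // AdaptiveSparkPlanExec, so its shared AtomicInteger counter keeps codegen stage IDs unique
  // across all query stages of this adaptive plan instead of resetting for every stage.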
+ @transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq( + CollapseCodegenStages(conf) + ) + + @volatile private var currentPhysicalPlan = initialPlan + + private var isFinalPlan = false + + private var currentStageId = 0 + + /** + * Return type for `createQueryStages` + * @param newPlan the new plan with created query stages. + * @param allChildStagesMaterialized whether all child stages have been materialized. + * @param newStages the newly created query stages, including new reused query stages. + */ + private case class CreateStageResult( + newPlan: SparkPlan, + allChildStagesMaterialized: Boolean, + newStages: Seq[(Exchange, QueryStageExec)]) + + def executedPlan: SparkPlan = currentPhysicalPlan + + override def conf: SQLConf = session.sessionState.conf + + override def output: Seq[Attribute] = initialPlan.output + + override def doCanonicalize(): SparkPlan = initialPlan.canonicalized + + override def doExecute(): RDD[InternalRow] = lock.synchronized { + if (isFinalPlan) { + currentPhysicalPlan.execute() + } else { + val executionId = Option( + session.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)).map(_.toLong) + var currentLogicalPlan = currentPhysicalPlan.logicalLink.get + var result = createQueryStages(currentPhysicalPlan) + val events = new LinkedBlockingQueue[StageMaterializationEvent]() + val errors = new mutable.ArrayBuffer[SparkException]() + while (!result.allChildStagesMaterialized) { + currentPhysicalPlan = result.newPlan + currentLogicalPlan = updateLogicalPlan(currentLogicalPlan, result.newStages) + currentPhysicalPlan.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, currentLogicalPlan) + executionId.foreach(onUpdatePlan) + + // Start materialization of all new stages. + result.newStages.map(_._2).foreach { stage => + stage.materialize().onComplete { res => + if (res.isSuccess) { + events.offer(StageSuccess(stage, res.get)) + } else { + events.offer(StageFailure(stage, res.failed.get)) + } + }(AdaptiveSparkPlanExec.executionContext) + } + + // Wait on the next completed stage, which indicates new stats are available and probably + // new stages can be created. There might be other stages that finish at around the same + // time, so we process those stages too in order to reduce re-planning. + val nextMsg = events.take() + val rem = new util.ArrayList[StageMaterializationEvent]() + events.drainTo(rem) + (Seq(nextMsg) ++ rem.asScala).foreach { + case StageSuccess(stage, res) => + stage.resultOption = Some(res) + case StageFailure(stage, ex) => + errors.append( + new SparkException(s"Failed to materialize query stage: ${stage.treeString}", ex)) + } + + // In case of errors, we cancel all running stages and throw exception. + if (errors.nonEmpty) { + cleanUpAndThrowException(errors) + } + + // Do re-planning and try creating new stages on the new physical plan. + val (newPhysicalPlan, newLogicalPlan) = reOptimize(currentLogicalPlan) + currentPhysicalPlan = newPhysicalPlan + currentLogicalPlan = newLogicalPlan + result = createQueryStages(currentPhysicalPlan) + } + + // Run the final plan when there's no more unfinished stages. 
+ currentPhysicalPlan = applyPhysicalRules(result.newPlan, queryStageOptimizerRules) + currentPhysicalPlan.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, currentLogicalPlan) + isFinalPlan = true + logDebug(s"Final plan: $currentPhysicalPlan") + executionId.foreach(onUpdatePlan) + + currentPhysicalPlan.execute() + } + } + + override def verboseString(maxFields: Int): String = simpleString(maxFields) + + override def simpleString(maxFields: Int): String = + s"AdaptiveSparkPlan(isFinalPlan=$isFinalPlan)" + + override def generateTreeString( + depth: Int, + lastChildren: Seq[Boolean], + append: String => Unit, + verbose: Boolean, + prefix: String = "", + addSuffix: Boolean = false, + maxFields: Int): Unit = { + super.generateTreeString(depth, lastChildren, append, verbose, prefix, addSuffix, maxFields) + currentPhysicalPlan.generateTreeString( + depth + 1, lastChildren :+ true, append, verbose, "", addSuffix = false, maxFields) + } + + /** + * This method is called recursively to traverse the plan tree bottom-up and create a new query + * stage or try reusing an existing stage if the current node is an [[Exchange]] node and all of + * its child stages have been materialized. + * + * With each call, it returns: + * 1) The new plan replaced with [[QueryStageExec]] nodes where new stages are created. + * 2) Whether the child query stages (if any) of the current node have all been materialized. + * 3) A list of the new query stages that have been created. + */ + private def createQueryStages(plan: SparkPlan): CreateStageResult = plan match { + case e: Exchange => + // First have a quick check in the `stageCache` without having to traverse down the node. + stageCache.get(e.canonicalized) match { + case Some(existingStage) if conf.exchangeReuseEnabled => + val reusedStage = reuseQueryStage(existingStage, e.output) + // When reusing a stage, we treat it a new stage regardless of whether the existing stage + // has been materialized or not. Thus we won't skip re-optimization for a reused stage. + CreateStageResult(newPlan = reusedStage, + allChildStagesMaterialized = false, newStages = Seq((e, reusedStage))) + + case _ => + val result = createQueryStages(e.child) + val newPlan = e.withNewChildren(Seq(result.newPlan)).asInstanceOf[Exchange] + // Create a query stage only when all the child query stages are ready. + if (result.allChildStagesMaterialized) { + var newStage = newQueryStage(newPlan) + if (conf.exchangeReuseEnabled) { + // Check the `stageCache` again for reuse. If a match is found, ditch the new stage + // and reuse the existing stage found in the `stageCache`, otherwise update the + // `stageCache` with the new stage. + val queryStage = stageCache.getOrElseUpdate(e.canonicalized, newStage) + if (queryStage.ne(newStage)) { + newStage = reuseQueryStage(queryStage, e.output) + } + } + + // We've created a new stage, which is obviously not ready yet. 
+ CreateStageResult(newPlan = newStage, + allChildStagesMaterialized = false, newStages = Seq((e, newStage))) + } else { + CreateStageResult(newPlan = newPlan, + allChildStagesMaterialized = false, newStages = result.newStages) + } + } + + case q: QueryStageExec => + CreateStageResult(newPlan = q, + allChildStagesMaterialized = q.resultOption.isDefined, newStages = Seq.empty) + + case _ => + if (plan.children.isEmpty) { + CreateStageResult(newPlan = plan, allChildStagesMaterialized = true, newStages = Seq.empty) + } else { + val results = plan.children.map(createQueryStages) + CreateStageResult( + newPlan = plan.withNewChildren(results.map(_.newPlan)), + allChildStagesMaterialized = results.forall(_.allChildStagesMaterialized), + newStages = results.flatMap(_.newStages)) + } + } + + private def newQueryStage(e: Exchange): QueryStageExec = { + val optimizedPlan = applyPhysicalRules(e.child, queryStageOptimizerRules) + val queryStage = e match { + case s: ShuffleExchangeExec => + ShuffleQueryStageExec(currentStageId, s.copy(child = optimizedPlan)) + case b: BroadcastExchangeExec => + BroadcastQueryStageExec(currentStageId, b.copy(child = optimizedPlan)) + } + currentStageId += 1 + queryStage + } + + private def reuseQueryStage(s: QueryStageExec, output: Seq[Attribute]): QueryStageExec = { + val queryStage = ReusedQueryStageExec(currentStageId, s, output) + currentStageId += 1 + queryStage + } + + /** + * Returns the updated logical plan after new query stages have been created and the physical + * plan has been updated with the newly created stages. + * 1. If the new query stage can be mapped to an integral logical sub-tree, replace the + * corresponding logical sub-tree with a leaf node [[LogicalQueryStage]] referencing the new + * query stage. For example: + * Join SMJ SMJ + * / \ / \ / \ + * r1 r2 => Xchg1 Xchg2 => Stage1 Stage2 + * | | + * r1 r2 + * The updated plan node will be: + * Join + * / \ + * LogicalQueryStage1(Stage1) LogicalQueryStage2(Stage2) + * + * 2. Otherwise (which means the new query stage can only be mapped to part of a logical + * sub-tree), replace the corresponding logical sub-tree with a leaf node + * [[LogicalQueryStage]] referencing to the top physical node into which this logical node is + * transformed during physical planning. For example: + * Agg HashAgg HashAgg + * | | | + * child => Xchg => Stage1 + * | + * HashAgg + * | + * child + * The updated plan node will be: + * LogicalQueryStage(HashAgg - Stage1) + */ + private def updateLogicalPlan( + logicalPlan: LogicalPlan, + newStages: Seq[(Exchange, QueryStageExec)]): LogicalPlan = { + var currentLogicalPlan = logicalPlan + newStages.foreach { + case (exchange, stage) => + // Get the corresponding logical node for `exchange`. If `exchange` has been transformed + // from a `Repartition`, it should have `logicalLink` available by itself; otherwise + // traverse down to find the first node that is not generated by `EnsureRequirements`. 
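        // (Illustrative example: an Exchange inserted by EnsureRequirements after planning has
        // no logical-plan tag of its own, so `exchange.logicalLink` is None and we fall back to
        // the first tagged descendant, e.g. a partial aggregate whose link points at the
        // logical Aggregate node.)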
+ val logicalNodeOpt = exchange.logicalLink.orElse(exchange.collectFirst { + case p if p.logicalLink.isDefined => p.logicalLink.get + }) + assert(logicalNodeOpt.isDefined) + val logicalNode = logicalNodeOpt.get + val physicalNode = currentPhysicalPlan.collectFirst { + case p if p.eq(stage) || p.logicalLink.exists(logicalNode.eq) => p + } + assert(physicalNode.isDefined) + // Replace the corresponding logical node with LogicalQueryStage + val newLogicalNode = LogicalQueryStage(logicalNode, physicalNode.get) + val newLogicalPlan = currentLogicalPlan.transformDown { + case p if p.eq(logicalNode) => newLogicalNode + } + assert(newLogicalPlan != currentLogicalPlan, + s"logicalNode: $logicalNode; " + + s"logicalPlan: $currentLogicalPlan " + + s"physicalPlan: $currentPhysicalPlan" + + s"stage: $stage") + currentLogicalPlan = newLogicalPlan + } + currentLogicalPlan + } + + /** + * Re-optimize and run physical planning on the current logical plan based on the latest stats. + */ + private def reOptimize(logicalPlan: LogicalPlan): (SparkPlan, LogicalPlan) = { + logicalPlan.invalidateStatsCache() + val optimized = optimizer.execute(logicalPlan) + SparkSession.setActiveSession(session) + val sparkPlan = session.sessionState.planner.plan(ReturnAnswer(optimized)).next() + val newPlan = applyPhysicalRules(sparkPlan, queryStagePreparationRules) + (newPlan, optimized) + } + + /** + * Notify the listeners of the physical plan change. + */ + private def onUpdatePlan(executionId: Long): Unit = { + session.sparkContext.listenerBus.post(SparkListenerSQLAdaptiveExecutionUpdate( + executionId, + SQLExecution.getQueryExecution(executionId).toString, + SparkPlanInfo.fromSparkPlan(this))) + } + + /** + * Cancel all running stages with best effort and throw an Exception containing all stage + * materialization errors and stage cancellation errors. + */ + private def cleanUpAndThrowException(errors: Seq[SparkException]): Unit = { + val runningStages = currentPhysicalPlan.collect { + case s: QueryStageExec => s + } + val cancelErrors = new mutable.ArrayBuffer[SparkException]() + try { + runningStages.foreach { s => + try { + s.cancel() + } catch { + case NonFatal(t) => + cancelErrors.append( + new SparkException(s"Failed to cancel query stage: ${s.treeString}", t)) + } + } + } finally { + val ex = new SparkException( + "Adaptive execution failed due to stage materialization failures.", errors.head) + errors.tail.foreach(ex.addSuppressed) + cancelErrors.foreach(ex.addSuppressed) + throw ex + } + } +} + +object AdaptiveSparkPlanExec { + private val executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("QueryStageCreator", 16)) + + /** + * Creates the list of physical plan rules to be applied before creation of query stages. + */ + def createQueryStagePreparationRules( + conf: SQLConf, + subqueryMap: Map[Long, ExecSubqueryExpression]): Seq[Rule[SparkPlan]] = { + Seq( + PlanAdaptiveSubqueries(subqueryMap), + EnsureRequirements(conf)) + } + + /** + * Apply a list of physical operator rules on a [[SparkPlan]]. + */ + def applyPhysicalRules(plan: SparkPlan, rules: Seq[Rule[SparkPlan]]): SparkPlan = { + rules.foldLeft(plan) { case (sp, rule) => rule.apply(sp) } + } +} + +/** + * The event type for stage materialization. + */ +sealed trait StageMaterializationEvent + +/** + * The materialization of a query stage completed with success. 
+ */ +case class StageSuccess(stage: QueryStageExec, result: Any) extends StageMaterializationEvent + +/** + * The materialization of a query stage hit an error and failed. + */ +case class StageFailure(stage: QueryStageExec, error: Throwable) extends StageMaterializationEvent diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala new file mode 100644 index 000000000000..a1b0e291c1b6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import scala.collection.concurrent.TrieMap +import scala.collection.mutable + +import org.apache.spark.sql.{execution, SparkSession} +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.command.ExecutedCommandExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + +/** + * This rule wraps the query plan with an [[AdaptiveSparkPlanExec]], which executes the query plan + * and re-optimize the plan during execution based on runtime data statistics. + * + * Note that this rule is stateful and thus should not be reused across query executions. + */ +case class InsertAdaptiveSparkPlan(session: SparkSession) extends Rule[SparkPlan] { + + private val conf = session.sessionState.conf + + // Exchange-reuse is shared across the entire query, including sub-queries. + private val stageCache = new TrieMap[SparkPlan, QueryStageExec]() + + override def apply(plan: SparkPlan): SparkPlan = plan match { + case _: ExecutedCommandExec => plan + case _ if conf.runtimeReoptimizationEnabled && supportAdaptive(plan) => + try { + // Plan sub-queries recursively and pass in the shared stage cache for exchange reuse. Fall + // back to non-adaptive mode if adaptive execution is supported in any of the sub-queries. + val subqueryMap = buildSubqueryMap(plan) + // Run preparation rules. 
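        // Illustrative usage of the new flag (hedged sketch, not part of this patch):
        //   spark.conf.set("spark.sql.runtime.reoptimization.enabled", "true")
        //   spark.range(1000).join(spark.range(1000), "id").collect()
        // The planned query below is then wrapped into an AdaptiveSparkPlanExec root.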
+ val preparations = AdaptiveSparkPlanExec.createQueryStagePreparationRules( + session.sessionState.conf, subqueryMap) + val newPlan = AdaptiveSparkPlanExec.applyPhysicalRules(plan, preparations) + logDebug(s"Adaptive execution enabled for plan: $plan") + AdaptiveSparkPlanExec(newPlan, session, subqueryMap, stageCache) + } catch { + case SubqueryAdaptiveNotSupportedException(subquery) => + logWarning(s"${SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key} is enabled " + + s"but is not supported for sub-query: $subquery.") + plan + } + case _ => + if (conf.runtimeReoptimizationEnabled) { + logWarning(s"${SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key} is enabled " + + s"but is not supported for query: $plan.") + } + plan + } + + private def supportAdaptive(plan: SparkPlan): Boolean = { + sanityCheck(plan) && + !plan.logicalLink.exists(_.isStreaming) && + plan.children.forall(supportAdaptive) + } + + private def sanityCheck(plan: SparkPlan): Boolean = + plan.logicalLink.isDefined + + /** + * Returns an expression-id-to-execution-plan map for all the sub-queries. + * For each sub-query, generate the adaptive execution plan for each sub-query by applying this + * rule, or reuse the execution plan from another sub-query of the same semantics if possible. + */ + private def buildSubqueryMap(plan: SparkPlan): Map[Long, ExecSubqueryExpression] = { + val subqueryMapBuilder = mutable.HashMap.empty[Long, ExecSubqueryExpression] + plan.foreach(_.expressions.foreach(_.foreach { + case expressions.ScalarSubquery(p, _, exprId) + if !subqueryMapBuilder.contains(exprId.id) => + val executedPlan = getExecutedPlan(p) + val scalarSubquery = execution.ScalarSubquery( + SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId) + subqueryMapBuilder.put(exprId.id, scalarSubquery) + case _ => + })) + + // Reuse subqueries + if (session.sessionState.conf.subqueryReuseEnabled) { + // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls. + val reuseMap = mutable.HashMap[StructType, mutable.ArrayBuffer[BaseSubqueryExec]]() + subqueryMapBuilder.keySet.foreach { exprId => + val sub = subqueryMapBuilder(exprId) + val sameSchema = + reuseMap.getOrElseUpdate(sub.plan.schema, mutable.ArrayBuffer.empty) + val sameResult = sameSchema.find(_.sameResult(sub.plan)) + if (sameResult.isDefined) { + val newExpr = sub.withNewPlan(ReusedSubqueryExec(sameResult.get)) + subqueryMapBuilder.update(exprId, newExpr) + } else { + sameSchema += sub.plan + } + } + } + + subqueryMapBuilder.toMap + } + + private def getExecutedPlan(plan: LogicalPlan): SparkPlan = { + val queryExec = new QueryExecution(session, plan) + // Apply the same instance of this rule to sub-queries so that sub-queries all share the + // same `stageCache` for Exchange reuse. + val adaptivePlan = this.apply(queryExec.sparkPlan) + if (!adaptivePlan.isInstanceOf[AdaptiveSparkPlanExec]) { + throw SubqueryAdaptiveNotSupportedException(plan) + } + adaptivePlan + } +} + +private case class SubqueryAdaptiveNotSupportedException(plan: LogicalPlan) extends Exception {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala new file mode 100644 index 000000000000..9914eddd53a3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.SparkPlan + +/** + * The LogicalPlan wrapper for a [[QueryStageExec]], or a snippet of physical plan containing + * a [[QueryStageExec]], in which all ancestor nodes of the [[QueryStageExec]] are linked to + * the same logical node. + * + * For example, a logical Aggregate can be transformed into FinalAgg - Shuffle - PartialAgg, in + * which the Shuffle will be wrapped into a [[QueryStageExec]], thus the [[LogicalQueryStage]] + * will have FinalAgg - QueryStageExec as its physical plan. + */ +// TODO we can potentially include only [[QueryStageExec]] in this class if we make the aggregation +// planning aware of partitioning. +case class LogicalQueryStage( + logicalPlan: LogicalPlan, + physicalPlan: SparkPlan) extends LeafNode { + + override def output: Seq[Attribute] = logicalPlan.output + override val isStreaming: Boolean = logicalPlan.isStreaming + override val outputOrdering: Seq[SortOrder] = physicalPlan.outputOrdering + + override def computeStats(): Statistics = { + // TODO this is not accurate when there is other physical nodes above QueryStageExec. + val physicalStats = physicalPlan.collectFirst { + case s: QueryStageExec => s + }.flatMap(_.computeStats()) + if (physicalStats.isDefined) { + logDebug(s"Physical stats available as ${physicalStats.get} for plan: $physicalPlan") + } else { + logDebug(s"Physical stats not available for plan: $physicalPlan") + } + physicalStats.getOrElse(logicalPlan.stats) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala new file mode 100644 index 000000000000..a0d07a68ab0f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.sql.Strategy +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, BuildLeft, BuildRight} + +/** + * Strategy for plans containing [[LogicalQueryStage]] nodes: + * 1. Transforms [[LogicalQueryStage]] to its corresponding physical plan that is either being + * executed or has already completed execution. + * 2. Transforms [[Join]] which has one child relation already planned and executed as a + * [[BroadcastQueryStageExec]]. This is to prevent reversing a broadcast stage into a shuffle + * stage in case of the larger join child relation finishes before the smaller relation. Note + * that this rule needs to applied before regular join strategies. + */ +object LogicalQueryStageStrategy extends Strategy with PredicateHelper { + + private def isBroadcastStage(plan: LogicalPlan): Boolean = plan match { + case LogicalQueryStage(_, physicalPlan) + if BroadcastQueryStageExec.isBroadcastQueryStageExec(physicalPlan) => + true + case _ => false + } + + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) + if isBroadcastStage(left) || isBroadcastStage(right) => + val buildSide = if (isBroadcastStage(left)) BuildLeft else BuildRight + Seq(BroadcastHashJoinExec( + leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) + + case j @ Join(left, right, joinType, condition, _) + if isBroadcastStage(left) || isBroadcastStage(right) => + val buildSide = if (isBroadcastStage(left)) BuildLeft else BuildRight + BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + + case q: LogicalQueryStage => + q.physicalPlan :: Nil + + case _ => Nil + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala new file mode 100644 index 000000000000..4af7432d7bed --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.ListQuery +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{ExecSubqueryExpression, SparkPlan} + +case class PlanAdaptiveSubqueries( + subqueryMap: Map[Long, ExecSubqueryExpression]) extends Rule[SparkPlan] { + + def apply(plan: SparkPlan): SparkPlan = { + plan.transformAllExpressions { + case expressions.ScalarSubquery(_, _, exprId) => + subqueryMap(exprId.id) + case expressions.InSubquery(_, ListQuery(_, _, exprId, _)) => + subqueryMap(exprId.id) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala new file mode 100644 index 000000000000..98cb7d0ca943 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import scala.concurrent.Future + +import org.apache.spark.{FutureAction, MapOutputStatistics} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.exchange._ + + +/** + * A query stage is an independent subgraph of the query plan. Query stage materializes its output + * before proceeding with further operators of the query plan. The data statistics of the + * materialized output can be used to optimize subsequent query stages. + * + * There are 2 kinds of query stages: + * 1. Shuffle query stage. This stage materializes its output to shuffle files, and Spark launches + * another job to execute the further operators. + * 2. Broadcast query stage. This stage materializes its output to an array in driver JVM. Spark + * broadcasts the array before executing the further operators. + */ +abstract class QueryStageExec extends LeafExecNode { + + /** + * An id of this query stage which is unique in the entire query plan. + */ + val id: Int + + /** + * The sub-tree of the query plan that belongs to this query stage. + */ + val plan: SparkPlan + + /** + * Materialize this query stage, to prepare for the execution, like submitting map stages, + * broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this + * stage is ready. 
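   * For example, the driver loop in [[AdaptiveSparkPlanExec]] registers a completion callback
   * roughly like:
   * {{{
   *   stage.materialize().onComplete { res =>
   *     if (res.isSuccess) events.offer(StageSuccess(stage, res.get))
   *     else events.offer(StageFailure(stage, res.failed.get))
   *   }(executionContext)
   * }}}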
+ */ + def doMaterialize(): Future[Any] + + /** + * Cancel the stage materialization if in progress; otherwise do nothing. + */ + def cancel(): Unit + + /** + * Materialize this query stage, to prepare for the execution, like submitting map stages, + * broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this + * stage is ready. + */ + final def materialize(): Future[Any] = executeQuery { + doMaterialize() + } + + /** + * Compute the statistics of the query stage if executed, otherwise None. + */ + def computeStats(): Option[Statistics] = resultOption.map { _ => + // Metrics `dataSize` are available in both `ShuffleExchangeExec` and `BroadcastExchangeExec`. + Statistics(sizeInBytes = plan.metrics("dataSize").value) + } + + @transient + @volatile + private[adaptive] var resultOption: Option[Any] = None + + override def output: Seq[Attribute] = plan.output + override def outputPartitioning: Partitioning = plan.outputPartitioning + override def outputOrdering: Seq[SortOrder] = plan.outputOrdering + override def executeCollect(): Array[InternalRow] = plan.executeCollect() + override def executeTake(n: Int): Array[InternalRow] = plan.executeTake(n) + override def executeToIterator(): Iterator[InternalRow] = plan.executeToIterator() + + override def doPrepare(): Unit = plan.prepare() + override def doExecute(): RDD[InternalRow] = plan.execute() + override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast() + override def doCanonicalize(): SparkPlan = plan.canonicalized + + protected override def stringArgs: Iterator[Any] = Iterator.single(id) + + override def generateTreeString( + depth: Int, + lastChildren: Seq[Boolean], + append: String => Unit, + verbose: Boolean, + prefix: String = "", + addSuffix: Boolean = false, + maxFields: Int): Unit = { + super.generateTreeString(depth, lastChildren, append, verbose, prefix, addSuffix, maxFields) + plan.generateTreeString( + depth + 1, lastChildren :+ true, append, verbose, "", false, maxFields) + } +} + +/** + * A shuffle query stage whose child is a [[ShuffleExchangeExec]]. + */ +case class ShuffleQueryStageExec( + override val id: Int, + override val plan: ShuffleExchangeExec) extends QueryStageExec { + + @transient lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = { + if (plan.inputRDD.getNumPartitions == 0) { + Future.successful(null) + } else { + sparkContext.submitMapStage(plan.shuffleDependency) + } + } + + override def doMaterialize(): Future[Any] = { + mapOutputStatisticsFuture + } + + override def cancel(): Unit = { + mapOutputStatisticsFuture match { + case action: FutureAction[MapOutputStatistics] if !mapOutputStatisticsFuture.isCompleted => + action.cancel() + case _ => + } + } +} + +/** + * A broadcast query stage whose child is a [[BroadcastExchangeExec]]. + */ +case class BroadcastQueryStageExec( + override val id: Int, + override val plan: BroadcastExchangeExec) extends QueryStageExec { + + override def doMaterialize(): Future[Any] = { + plan.completionFuture + } + + override def cancel(): Unit = { + if (!plan.relationFuture.isDone) { + sparkContext.cancelJobGroup(plan.runId.toString) + plan.relationFuture.cancel(true) + } + } +} + +object BroadcastQueryStageExec { + /** + * Returns if the plan is a [[BroadcastQueryStageExec]] or a reused [[BroadcastQueryStageExec]]. 
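   * A [[ReusedQueryStageExec]] that wraps a broadcast stage also counts, since reuse keeps the
   * broadcast nature of the underlying stage.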
+ */ + def isBroadcastQueryStageExec(plan: SparkPlan): Boolean = plan match { + case r: ReusedQueryStageExec => isBroadcastQueryStageExec(r.plan) + case _: BroadcastQueryStageExec => true + case _ => false + } +} + +/** + * A wrapper for reused query stage to have different output. + */ +case class ReusedQueryStageExec( + override val id: Int, + override val plan: QueryStageExec, + override val output: Seq[Attribute]) extends QueryStageExec { + + override def doMaterialize(): Future[Any] = { + plan.materialize() + } + + override def cancel(): Unit = { + plan.cancel() + } + + // `ReusedQueryStageExec` can have distinct set of output attribute ids from its child, we need + // to update the attribute ids in `outputPartitioning` and `outputOrdering`. + private lazy val updateAttr: Expression => Expression = { + val originalAttrToNewAttr = AttributeMap(plan.output.zip(output)) + e => e.transform { + case attr: Attribute => originalAttrToNewAttr.getOrElse(attr, attr) + } + } + + override def outputPartitioning: Partitioning = plan.outputPartitioning match { + case e: Expression => updateAttr(e).asInstanceOf[Partitioning] + case other => other + } + + override def outputOrdering: Seq[SortOrder] = { + plan.outputOrdering.map(updateAttr(_).asInstanceOf[SortOrder]) + } + + override def computeStats(): Option[Statistics] = plan.computeStats() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 8017188eb165..36f0d173cd0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.exchange import java.util.UUID import java.util.concurrent._ -import scala.concurrent.ExecutionContext +import scala.concurrent.{ExecutionContext, Promise} import scala.concurrent.duration.NANOSECONDS import scala.util.control.NonFatal @@ -44,7 +44,7 @@ case class BroadcastExchangeExec( mode: BroadcastMode, child: SparkPlan) extends Exchange { - private val runId: UUID = UUID.randomUUID + private[sql] val runId: UUID = UUID.randomUUID override lazy val metrics = Map( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), @@ -58,11 +58,21 @@ case class BroadcastExchangeExec( BroadcastExchangeExec(mode.canonicalized, child.canonicalized) } + @transient + private lazy val promise = Promise[broadcast.Broadcast[Any]]() + + /** + * For registering callbacks on `relationFuture`. + * Note that calling this field will not start the execution of broadcast job. + */ + @transient + lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] = promise.future + @transient private val timeout: Long = SQLConf.get.broadcastTimeout @transient - private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { + private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here. 
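    // Note: the `promise` above is completed from this task via `promise.success` /
    // `promise.failure`, so `completionFuture` resolves together with `relationFuture`; this is
    // what BroadcastQueryStageExec.doMaterialize() returns to the adaptive execution loop.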
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) val task = new Callable[broadcast.Broadcast[Any]]() { @@ -113,20 +123,28 @@ case class BroadcastExchangeExec( System.nanoTime() - beforeBroadcast) SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) + promise.success(broadcasted) broadcasted } catch { // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult // will catch this exception and re-throw the wrapped fatal throwable. case oe: OutOfMemoryError => - throw new SparkFatalException( + val ex = new SparkFatalException( new OutOfMemoryError("Not enough memory to build and broadcast the table to all " + "worker nodes. As a workaround, you can either disable broadcast by setting " + s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark " + s"driver memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value.") .initCause(oe.getCause)) + promise.failure(ex) + throw ex case e if !NonFatal(e) => - throw new SparkFatalException(e) + val ex = new SparkFatalException(e) + promise.failure(ex) + throw ex + case e: Throwable => + promise.failure(e) + throw e } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index e4ec76f0b9a1..c99bf458230e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -209,7 +209,7 @@ class ExchangeCoordinator( var i = 0 while (i < numExchanges) { val exchange = exchanges(i) - val shuffleDependency = exchange.prepareShuffleDependency() + val shuffleDependency = exchange.shuffleDependency shuffleDependencies += shuffleDependency if (shuffleDependency.rdd.partitions.length != 0) { // submitMapStage does not accept RDD with 0 partition. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 16398e34bdeb..31f75e3fb937 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -87,15 +87,17 @@ case class ShuffleExchangeExec( } } + @transient lazy val inputRDD: RDD[InternalRow] = child.execute() + /** - * Returns a [[ShuffleDependency]] that will partition rows of its child based on + * A [[ShuffleDependency]] that will partition rows of its child based on * the partitioning scheme defined in `newPartitioning`. Those partitions of * the returned ShuffleDependency will be the input of shuffle. 
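   * Note: being a `lazy val`, the dependency is created at most once per plan instance, so the
   * regular `doExecute()` path, `ExchangeCoordinator`, and `ShuffleQueryStageExec`'s
   * `submitMapStage` call all share the same dependency.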
*/ - private[exchange] def prepareShuffleDependency() - : ShuffleDependency[Int, InternalRow, InternalRow] = { + @transient + lazy val shuffleDependency : ShuffleDependency[Int, InternalRow, InternalRow] = { ShuffleExchangeExec.prepareShuffleDependency( - child.execute(), + inputRDD, child.output, newPartitioning, serializer, @@ -135,7 +137,6 @@ case class ShuffleExchangeExec( assert(shuffleRDD.partitions.length == newPartitioning.numPartitions) shuffleRDD case _ => - val shuffleDependency = prepareShuffleDependency() preparePostShuffleRDD(shuffleDependency) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 960d47b3ac87..064e0a098955 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -37,6 +37,9 @@ abstract class ExecSubqueryExpression extends PlanExpression[BaseSubqueryExec] { */ def updateResult(): Unit + /** Updates the expression with a new plan. */ + override def withNewPlan(plan: BaseSubqueryExec): ExecSubqueryExpression + override def canonicalize(attrs: AttributeSeq): ExecSubqueryExpression = { withNewPlan(plan.canonicalized.asInstanceOf[BaseSubqueryExec]) .asInstanceOf[ExecSubqueryExpression] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index e496de1b05e4..2c4a7eacdf10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -248,26 +248,26 @@ class SQLAppStatusListener( } } + private def toStoredNodes(nodes: Seq[SparkPlanGraphNode]): Seq[SparkPlanGraphNodeWrapper] = { + nodes.map { + case cluster: SparkPlanGraphCluster => + val storedCluster = new SparkPlanGraphClusterWrapper( + cluster.id, + cluster.name, + cluster.desc, + toStoredNodes(cluster.nodes), + cluster.metrics) + new SparkPlanGraphNodeWrapper(null, storedCluster) + + case node => + new SparkPlanGraphNodeWrapper(node, null) + } + } + private def onExecutionStart(event: SparkListenerSQLExecutionStart): Unit = { val SparkListenerSQLExecutionStart(executionId, description, details, physicalPlanDescription, sparkPlanInfo, time) = event - def toStoredNodes(nodes: Seq[SparkPlanGraphNode]): Seq[SparkPlanGraphNodeWrapper] = { - nodes.map { - case cluster: SparkPlanGraphCluster => - val storedCluster = new SparkPlanGraphClusterWrapper( - cluster.id, - cluster.name, - cluster.desc, - toStoredNodes(cluster.nodes), - cluster.metrics) - new SparkPlanGraphNodeWrapper(null, storedCluster) - - case node => - new SparkPlanGraphNodeWrapper(node, null) - } - } - val planGraph = SparkPlanGraph(sparkPlanInfo) val sqlPlanMetrics = planGraph.allNodes.flatMap { node => node.metrics.map { metric => (metric.accumulatorId, metric) } @@ -288,6 +288,27 @@ class SQLAppStatusListener( update(exec) } + private def onAdaptiveExecutionUpdate(event: SparkListenerSQLAdaptiveExecutionUpdate): Unit = { + val SparkListenerSQLAdaptiveExecutionUpdate( + executionId, physicalPlanDescription, sparkPlanInfo) = event + + val planGraph = SparkPlanGraph(sparkPlanInfo) + val sqlPlanMetrics = planGraph.allNodes.flatMap { node => + node.metrics.map { metric => (metric.accumulatorId, metric) } + }.toMap.values.toList + + val graphToStore = new SparkPlanGraphWrapper( + 
executionId, + toStoredNodes(planGraph.nodes), + planGraph.edges) + kvstore.write(graphToStore) + + val exec = getOrCreateExecution(executionId) + exec.physicalPlanDescription = physicalPlanDescription + exec.metrics = sqlPlanMetrics + update(exec) + } + private def onExecutionEnd(event: SparkListenerSQLExecutionEnd): Unit = { val SparkListenerSQLExecutionEnd(executionId, time) = event Option(liveExecutions.get(executionId)).foreach { exec => @@ -320,6 +341,7 @@ class SQLAppStatusListener( override def onOtherEvent(event: SparkListenerEvent): Unit = event match { case e: SparkListenerSQLExecutionStart => onExecutionStart(e) + case e: SparkListenerSQLAdaptiveExecutionUpdate => onAdaptiveExecutionUpdate(e) case e: SparkListenerSQLExecutionEnd => onExecutionEnd(e) case e: SparkListenerDriverAccumUpdates => onDriverAccumUpdates(e) case _ => // Ignore diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 03d75c4c1b82..67d1f27271b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -27,6 +27,13 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.sql.execution.{QueryExecution, SparkPlanInfo} +@DeveloperApi +case class SparkListenerSQLAdaptiveExecutionUpdate( + executionId: Long, + physicalPlanDescription: String, + sparkPlanInfo: SparkPlanInfo) + extends SparkListenerEvent + @DeveloperApi case class SparkListenerSQLExecutionStart( executionId: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index b864ad1c7108..2b7597ee66d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -96,6 +96,15 @@ object SparkPlanGraph { case "InputAdapter" => buildSparkPlanGraphNode( planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges) + case "BroadcastQueryStage" | "ShuffleQueryStage" => + if (exchanges.contains(planInfo.children.head)) { + // Point to the re-used exchange + val node = exchanges(planInfo.children.head) + edges += SparkPlanGraphEdge(node.id, parent.id) + } else { + buildSparkPlanGraphNode( + planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges) + } case "Subquery" if subgraph != null => // Subquery should not be included in WholeStageCodegen buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala new file mode 100644 index 000000000000..2cddf7cd0f65 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.execution.{ReusedSubqueryExec, SparkPlan} +import org.apache.spark.sql.execution.exchange.Exchange +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class AdaptiveQueryExecSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + setupTestData() + + private def runAdaptiveAndVerifyResult(query: String): (SparkPlan, SparkPlan) = { + val dfAdaptive = sql(query) + val planBefore = dfAdaptive.queryExecution.executedPlan + assert(planBefore.toString.startsWith("AdaptiveSparkPlan(isFinalPlan=false)")) + val result = dfAdaptive.collect() + withSQLConf(SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "false") { + val df = sql(query) + QueryTest.sameRows(result.toSeq, df.collect().toSeq) + } + val planAfter = dfAdaptive.queryExecution.executedPlan + assert(planAfter.toString.startsWith("AdaptiveSparkPlan(isFinalPlan=true)")) + val adaptivePlan = planAfter.asInstanceOf[AdaptiveSparkPlanExec].executedPlan + val exchanges = adaptivePlan.collect { + case e: Exchange => e + } + assert(exchanges.isEmpty, "The final plan should not contain any Exchange node.") + (dfAdaptive.queryExecution.sparkPlan, adaptivePlan) + } + + private def findTopLevelBroadcastHashJoin(plan: SparkPlan): Seq[BroadcastHashJoinExec] = { + plan.collect { + case j: BroadcastHashJoinExec => Seq(j) + case s: QueryStageExec => findTopLevelBroadcastHashJoin(s.plan) + }.flatten + } + + private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = { + plan.collect { + case j: SortMergeJoinExec => Seq(j) + case s: QueryStageExec => findTopLevelSortMergeJoin(s.plan) + }.flatten + } + + private def findReusedExchange(plan: SparkPlan): Seq[ReusedQueryStageExec] = { + plan.collect { + case e: ReusedQueryStageExec => Seq(e) + case a: AdaptiveSparkPlanExec => findReusedExchange(a.executedPlan) + case s: QueryStageExec => findReusedExchange(s.plan) + case p: SparkPlan => p.subqueries.flatMap(findReusedExchange) + }.flatten + } + + private def findReusedSubquery(plan: SparkPlan): Seq[ReusedSubqueryExec] = { + plan.collect { + case e: ReusedSubqueryExec => Seq(e) + case s: QueryStageExec => findReusedSubquery(s.plan) + case p: SparkPlan => p.subqueries.flatMap(findReusedSubquery) + }.flatten + } + + test("Change merge join to broadcast join") { + withSQLConf( + SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM testData join testData2 ON key = a where value = '1'") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + } + } + + test("Scalar subquery") { + withSQLConf( + SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = 
runAdaptiveAndVerifyResult( + "SELECT * FROM testData join testData2 ON key = a " + + "where value = (SELECT max(a) from testData3)") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + } + } + + test("Scalar subquery in later stages") { + withSQLConf( + SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM testData join testData2 ON key = a " + + "where (value + a) = (SELECT max(a) from testData3)") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + } + } + + test("multiple joins") { + withSQLConf( + SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |WITH t4 AS ( + | SELECT * FROM lowercaseData t2 JOIN testData3 t3 ON t2.n = t3.a + |) + |SELECT * FROM testData + |JOIN testData2 t2 ON key = t2.a + |JOIN t4 ON key = t4.a + |WHERE value = 1 + """.stripMargin) + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 3) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 2) + } + } + + test("multiple joins with aggregate") { + withSQLConf( + SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |WITH t4 AS ( + | SELECT * FROM lowercaseData t2 JOIN ( + | select a, sum(b) from testData3 group by a + | ) t3 ON t2.n = t3.a + |) + |SELECT * FROM testData + |JOIN testData2 t2 ON key = t2.a + |JOIN t4 ON key = t4.a + |WHERE value = 1 + """.stripMargin) + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 3) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 2) + } + } + + test("multiple joins with aggregate 2") { + withSQLConf( + SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |WITH t4 AS ( + | SELECT * FROM lowercaseData t2 JOIN ( + | select a, max(b) b from testData2 group by a + | ) t3 ON t2.n = t3.b + |) + |SELECT * FROM testData + |JOIN testData2 t2 ON key = t2.a + |JOIN t4 ON value = t4.a + |WHERE value = 1 + """.stripMargin) + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 3) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 3) + } + } + + test("Exchange reuse") { + withSQLConf( + SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT value FROM testData join testData2 ON key = a " + + "join (SELECT value v from testData join testData3 ON key = a) on value = v") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 3) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 2) + val ex = findReusedExchange(adaptivePlan) + assert(ex.size == 1) + } + } + + test("Exchange reuse with subqueries") { + withSQLConf( + SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT a FROM testData join testData2 ON key = a " + + "where value = (SELECT 
max(a) from testData join testData2 ON key = a)") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + val ex = findReusedExchange(adaptivePlan) + assert(ex.size == 1) + } + } + + test("Exchange reuse across subqueries") { + withSQLConf( + SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", + SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT a FROM testData join testData2 ON key = a " + + "where value >= (SELECT max(a) from testData join testData2 ON key = a) " + + "and a <= (SELECT max(a) from testData join testData2 ON key = a)") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + val ex = findReusedExchange(adaptivePlan) + assert(ex.nonEmpty) + val sub = findReusedSubquery(adaptivePlan) + assert(sub.isEmpty) + } + } + + test("Subquery reuse") { + withSQLConf( + SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT a FROM testData join testData2 ON key = a " + + "where value >= (SELECT max(a) from testData join testData2 ON key = a) " + + "and a <= (SELECT max(a) from testData join testData2 ON key = a)") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + val ex = findReusedExchange(adaptivePlan) + assert(ex.isEmpty) + val sub = findReusedSubquery(adaptivePlan) + assert(sub.nonEmpty) + } + } + + test("Broadcast exchange reuse across subqueries") { + withSQLConf( + SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000", + SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT a FROM testData join testData2 ON key = a " + + "where value >= (" + + "SELECT /*+ broadcast(testData2) */ max(key) from testData join testData2 ON key = a) " + + "and a <= (" + + "SELECT /*+ broadcast(testData2) */ max(value) from testData join testData2 ON key = a)") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + val ex = findReusedExchange(adaptivePlan) + assert(ex.nonEmpty) + assert(ex.head.plan.isInstanceOf[BroadcastQueryStageExec]) + val sub = findReusedSubquery(adaptivePlan) + assert(sub.isEmpty) + } + } + + test("Union/Except/Intersect queries") { + withSQLConf(SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true") { + runAdaptiveAndVerifyResult( + """ + |SELECT * FROM testData + |EXCEPT + |SELECT * FROM testData2 + |UNION ALL + |SELECT * FROM testData + |INTERSECT ALL + |SELECT * FROM testData2 + """.stripMargin) + } + } + + test("Subquery de-correlation in Union queries") { + withSQLConf(SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true") { + withTempView("a", "b") { + Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("a") + Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("b") + + runAdaptiveAndVerifyResult( + """ + |SELECT id,num,source FROM ( + | SELECT id, num, 'a' as source FROM a + | UNION ALL + | SELECT id, num, 'b' as source FROM b + |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2) + """.stripMargin) + } + } + } +}
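
For reference, a minimal usage sketch of the behavior exercised by the suite above (assumes a build that includes this patch; the object name, local session setup, and the table names t1/t2 are illustrative and not part of the change). With spark.sql.runtime.reoptimization.enabled set, the executed plan is wrapped in AdaptiveSparkPlanExec, which the tests check by asserting that the plan string starts with "AdaptiveSparkPlan(isFinalPlan=...)".

import org.apache.spark.sql.SparkSession

object RuntimeReoptimizationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("runtime-reoptimization-sketch")
      // Turn on the runtime re-optimization path added by this patch.
      .config("spark.sql.runtime.reoptimization.enabled", "true")
      // A small broadcast threshold makes the initial plan a sort-merge join,
      // leaving room for the runtime optimizer to switch to a broadcast join,
      // as in the "Change merge join to broadcast join" test above.
      .config("spark.sql.autoBroadcastJoinThreshold", "80")
      .getOrCreate()
    import spark.implicits._

    // Two tiny tables, analogous to testData/testData2 in the suite.
    Seq((1, "1"), (2, "2"), (3, "3")).toDF("key", "value").createOrReplaceTempView("t1")
    Seq((1, 10), (2, 20), (3, 30)).toDF("a", "b").createOrReplaceTempView("t2")

    val df = spark.sql("SELECT * FROM t1 JOIN t2 ON key = a WHERE value = '1'")
    df.collect()
    // After execution, the root node should print as
    // AdaptiveSparkPlan(isFinalPlan=true), mirroring runAdaptiveAndVerifyResult.
    println(df.queryExecution.executedPlan)

    spark.stop()
  }
}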