diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 112090640040a..ab867c7556eb5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -374,12 +374,14 @@ case class AdaptiveSparkPlanExec(
   }
 
   private def newQueryStage(e: Exchange): QueryStageExec = {
-    val optimizedPlan = applyPhysicalRules(e.child, queryStageOptimizerRules)
+    // apply optimizer rules to the Exchange node and its children, allowing plugins to
+    // replace the Exchange node itself
+    val optimizedPlan = applyPhysicalRules(e, queryStageOptimizerRules)
     val queryStage = e match {
-      case s: ShuffleExchangeExec =>
-        ShuffleQueryStageExec(currentStageId, s.copy(child = optimizedPlan))
-      case b: BroadcastExchangeExec =>
-        BroadcastQueryStageExec(currentStageId, b.copy(child = optimizedPlan))
+      case _: ShuffleExchange =>
+        ShuffleQueryStageExec(currentStageId, optimizedPlan)
+      case _: BroadcastExchange =>
+        BroadcastQueryStageExec(currentStageId, optimizedPlan)
     }
     currentStageId += 1
     setLogicalLinkForNewQueryStage(queryStage, e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
index af18ee065aa86..36faf86e07666 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 
@@ -55,9 +55,9 @@ case class CustomShuffleReaderExec private(
         partitionSpecs.map(_.asInstanceOf[PartialMapperPartitionSpec].mapIndex).toSet.size ==
           partitionSpecs.length) {
       child match {
-        case ShuffleQueryStageExec(_, s: ShuffleExchangeExec) =>
+        case ShuffleQueryStageExec(_, s: ShuffleExchange) =>
           s.child.outputPartitioning
-        case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeExec)) =>
+        case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchange)) =>
           s.child.outputPartitioning match {
             case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning]
             case other => other
@@ -180,8 +180,10 @@ case class CustomShuffleReaderExec private(
     sendDriverMetrics()
 
     shuffleStage.map { stage =>
+      val shuffleExchangeExec = stage.shuffle.asInstanceOf[ShuffleExchangeExec]
       new ShuffledRowRDD(
-        stage.shuffle.shuffleDependency, stage.shuffle.readMetrics, partitionSpecs.toArray)
+        shuffleExchangeExec.shuffleDependency,
+        shuffleExchangeExec.readMetrics, partitionSpecs.toArray)
     }.getOrElse {
       throw new IllegalStateException("operating on canonicalized plan")
     }
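Because applyPhysicalRules is now handed the Exchange node itself rather than only its child, a rule that runs as part of the query stage optimizer rules can rewrite the Exchange before the stage is created, which is what lets a plugin substitute its own exchange implementation. The sketch below is illustrative only: the rule name is hypothetical, and MyShuffleExchangeExec refers to the test helper added at the end of this patch.

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

// Hypothetical rule: swaps the built-in shuffle exchange for a plugin-provided one.
// Before this change the stage root was always the original ShuffleExchangeExec
// (s.copy(child = optimizedPlan)), so a rewrite like this could never take effect
// at stage-creation time.
case class ReplaceShuffleWithPluginExchange() extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan transformUp {
    case e: ShuffleExchangeExec => MyShuffleExchangeExec(e)
  }
}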
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
index 3620f27058af2..9d50fb4cf260e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.adaptive
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, Exchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.internal.SQLConf
 
@@ -78,10 +78,9 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
   private def getPartitionSpecs(
       shuffleStage: ShuffleQueryStageExec,
       advisoryParallelism: Option[Int]): Seq[ShufflePartitionSpec] = {
-    val shuffleDep = shuffleStage.shuffle.shuffleDependency
-    val numReducers = shuffleDep.partitioner.numPartitions
+    val numMappers = shuffleStage.shuffle.getNumMappers
+    val numReducers = shuffleStage.shuffle.getNumReducers
     val expectedParallelism = advisoryParallelism.getOrElse(numReducers)
-    val numMappers = shuffleDep.rdd.getNumPartitions
     val splitPoints = if (numMappers == 0) {
       Seq.empty
     } else {
@@ -113,6 +112,9 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
     }
 
     plan match {
+      // skip the top-level exchange operator
+      case s: Exchange =>
+        s.withNewChildren(s.children.map(apply))
       case s: SparkPlan if canUseLocalShuffleReader(s) =>
         createLocalReader(s)
      case s: SparkPlan =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
index 627f0600f2383..3a7fcf752fba7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
@@ -202,7 +202,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
         val leftParts = if (isLeftSkew && !isLeftCoalesced) {
           val reducerId = leftPartSpec.startReducerIndex
           val skewSpecs = createSkewPartitionSpecs(
-            left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
+            left.shuffleStage.shuffle.shuffleId, reducerId, leftTargetSize)
           if (skewSpecs.isDefined) {
             logDebug(s"Left side partition $partitionIndex " +
               s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
@@ -218,7 +218,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
         val rightParts = if (isRightSkew && !isRightCoalesced) {
           val reducerId = rightPartSpec.startReducerIndex
           val skewSpecs = createSkewPartitionSpecs(
-            right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
+            right.shuffleStage.shuffle.shuffleId, reducerId, rightTargetSize)
           if (skewSpecs.isDefined) {
             logDebug(s"Right side partition $partitionIndex " +
               s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
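Both rules above now read the mapper count, reducer count and shuffle id through the new ShuffleExchange API instead of reaching into ShuffleExchangeExec.shuffleDependency, so they keep working when a plugin has substituted its own exchange implementation. For readers unfamiliar with the local shuffle reader, the following is a minimal sketch of the even-division idea behind splitPoints, with a worked example; it is illustrative and not Spark's exact code.

// Minimal sketch (not Spark's exact implementation): cut the reducer range [0, numReducers)
// into `numSlices` contiguous slices whose sizes differ by at most one; the returned values
// are the slice start indices, which become the split points for the local reader.
def equallyDivide(numReducers: Int, numSlices: Int): Seq[Int] = {
  val baseSize = numReducers / numSlices
  val remainder = numReducers % numSlices
  // the first `remainder` slices each take one extra reducer
  (0 until numSlices).scanLeft(0) { (start, i) =>
    start + baseSize + (if (i < remainder) 1 else 0)
  }.dropRight(1)
}

// Worked example: 10 reducers split 4 ways.
// equallyDivide(10, 4) == Seq(0, 3, 6, 8), i.e. reducer ranges [0,3), [3,6), [6,8), [8,10).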
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index 4e83b4344fbf0..c5d4ece94d421 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.ThreadUtils
 
 /**
@@ -107,9 +108,11 @@ abstract class QueryStageExec extends LeafExecNode {
   override def executeTake(n: Int): Array[InternalRow] = plan.executeTake(n)
   override def executeTail(n: Int): Array[InternalRow] = plan.executeTail(n)
   override def executeToIterator(): Iterator[InternalRow] = plan.executeToIterator()
+  override def supportsColumnar: Boolean = plan.supportsColumnar
 
   protected override def doPrepare(): Unit = plan.prepare()
   protected override def doExecute(): RDD[InternalRow] = plan.execute()
+  override def doExecuteColumnar(): RDD[ColumnarBatch] = plan.executeColumnar()
   override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast()
   override def doCanonicalize(): SparkPlan = plan.canonicalized
 
@@ -138,15 +141,16 @@ abstract class QueryStageExec extends LeafExecNode {
 }
 
 /**
- * A shuffle query stage whose child is a [[ShuffleExchangeExec]] or [[ReusedExchangeExec]].
+ * A shuffle query stage whose child is a [[ShuffleExchange]] or a [[ReusedExchangeExec]] wrapping
+ * a [[ShuffleExchange]].
  */
 case class ShuffleQueryStageExec(
     override val id: Int,
     override val plan: SparkPlan) extends QueryStageExec {
 
   @transient val shuffle = plan match {
-    case s: ShuffleExchangeExec => s
-    case ReusedExchangeExec(_, s: ShuffleExchangeExec) => s
+    case s: ShuffleExchange => s
+    case ReusedExchangeExec(_, s: ShuffleExchange) => s
     case _ =>
       throw new IllegalStateException("wrong plan for shuffle stage:\n " + plan.treeString)
   }
@@ -184,15 +188,16 @@ case class ShuffleQueryStageExec(
 }
 
 /**
- * A broadcast query stage whose child is a [[BroadcastExchangeExec]] or [[ReusedExchangeExec]].
+ * A broadcast query stage whose child is a [[BroadcastExchange]] or a [[ReusedExchangeExec]]
+ * wrapping a [[BroadcastExchange]].
  */
 case class BroadcastQueryStageExec(
     override val id: Int,
     override val plan: SparkPlan) extends QueryStageExec {
 
   @transient val broadcast = plan match {
-    case b: BroadcastExchangeExec => b
-    case ReusedExchangeExec(_, b: BroadcastExchangeExec) => b
+    case b: BroadcastExchange => b
+    case ReusedExchangeExec(_, b: BroadcastExchange) => b
     case _ =>
       throw new IllegalStateException("wrong plan for broadcast stage:\n " + plan.treeString)
   }
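With supportsColumnar and doExecuteColumnar now delegating to the wrapped plan, a query stage behaves like any other SparkPlan when a consumer chooses between the row-based and columnar code paths. A small illustrative helper (not part of the patch; the function name is hypothetical):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.adaptive.QueryStageExec
import org.apache.spark.sql.vectorized.ColumnarBatch

// Illustrative only: pick the columnar path when the stage's wrapped plan supports it,
// exactly as a caller would for any other SparkPlan.
def countStageRows(stage: QueryStageExec): Long = {
  if (stage.supportsColumnar) {
    val batches: RDD[ColumnarBatch] = stage.executeColumnar()
    batches.map(_.numRows().toLong).fold(0L)(_ + _)
  } else {
    val rows: RDD[InternalRow] = stage.execute()
    rows.count()
  }
}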
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index d35bbe9b8adc0..fdd262b2e4e2e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -34,16 +34,28 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.joins.HashedRelation
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.util.{SparkFatalException, ThreadUtils}
 
+/**
+ * Base class for implementations of broadcast exchanges. This was added to enable plugins to
+ * provide columnar implementations of broadcast exchanges when Adaptive Query Execution is
+ * enabled.
+ */
+abstract class BroadcastExchange extends Exchange {
+  private[sql] def runId: UUID
+  private[sql] def relationFuture: Future[broadcast.Broadcast[Any]]
+  def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]
+}
+
 /**
  * A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of
  * a transformed SparkPlan.
  */
 case class BroadcastExchangeExec(
     mode: BroadcastMode,
-    child: SparkPlan) extends Exchange {
+    child: SparkPlan) extends BroadcastExchange {
   import BroadcastExchangeExec._
 
   private[sql] val runId: UUID = UUID.randomUUID
@@ -156,6 +168,11 @@ case class BroadcastExchangeExec(
       "BroadcastExchange does not support the execute() code path.")
   }
 
+  override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    throw new UnsupportedOperationException(
+      "BroadcastExchange does not support the executeColumnar() code path.")
+  }
+
   override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
     try {
       relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
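The new BroadcastExchange base class exposes only the handles that adaptive execution needs: runId and relationFuture to identify and cancel the running broadcast job, and completionFuture to wait for the broadcast relation, regardless of which concrete implementation produced them. A small illustrative helper (not part of the patch; the function name and timeout are hypothetical):

import scala.concurrent.Await
import scala.concurrent.duration._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.execution.exchange.BroadcastExchange

// Illustrative only: any BroadcastExchange implementation can be waited on through the
// completionFuture it is required to expose.
def awaitBroadcast(exchange: BroadcastExchange, timeout: Duration = 5.minutes): Broadcast[Any] =
  Await.result(exchange.completionFuture, timeout)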
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index b06742e8470c7..dc83c3bbddcbd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -37,16 +37,30 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.MutablePair
 import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
 
+/**
+ * Base class for implementations of shuffle exchanges. This was added to enable plugins to
+ * provide columnar implementations of shuffle exchanges when Adaptive Query Execution is
+ * enabled.
+ */
+abstract class ShuffleExchange extends Exchange {
+  def shuffleId: Int
+  def getNumMappers: Int
+  def getNumReducers: Int
+  def canChangeNumPartitions: Boolean
+  def mapOutputStatisticsFuture: Future[MapOutputStatistics]
+}
+
 /**
  * Performs a shuffle that will result in the desired partitioning.
  */
 case class ShuffleExchangeExec(
     override val outputPartitioning: Partitioning,
     child: SparkPlan,
-    canChangeNumPartitions: Boolean = true) extends Exchange {
+    canChangeNumPartitions: Boolean = true) extends ShuffleExchange {
 
   private lazy val writeMetrics =
     SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
@@ -63,6 +77,12 @@ case class ShuffleExchangeExec(
 
   @transient lazy val inputRDD: RDD[InternalRow] = child.execute()
 
+  override def shuffleId: Int = shuffleDependency.shuffleId
+
+  override def getNumMappers: Int = shuffleDependency.rdd.getNumPartitions
+
+  override def getNumReducers: Int = shuffleDependency.partitioner.numPartitions
+
   // 'mapOutputStatisticsFuture' is only needed when enable AQE.
   @transient lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
     if (inputRDD.getNumPartitions == 0) {
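The suite changes below exercise the new base classes through SparkSessionExtensions.injectColumnar. Outside a test, a plugin would register the same kind of rule through the spark.sql.extensions configuration; the sketch below is hypothetical (the extensions class is not part of this patch) and reuses the suite's MyColumarRule, PreRuleReplaceAddWithBrokenVersion and MyPostRule helpers purely for illustration.

import org.apache.spark.sql.SparkSessionExtensions

// Hypothetical extensions entry point registering the columnar rule defined in the suite below.
class ExchangeReplacingExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectColumnar(session =>
      MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule()))
  }
}

// Enabled at session build time, e.g.:
//   spark.sql.extensions=org.apache.spark.sql.ExchangeReplacingExtensions  (hypothetical name)
//   spark.sql.adaptive.enabled=true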
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index d9c90c7dbd085..c66d62c788b3c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -18,15 +18,20 @@ package org.apache.spark.sql
 import java.util.Locale
 
-import org.apache.spark.{SparkFunSuite, TaskContext}
+import scala.concurrent.Future
+
+import org.apache.spark.{MapOutputStatistics, SparkFunSuite, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE
 
@@ -145,33 +150,56 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     }
   }
 
-  test("inject columnar") {
+  test("inject columnar AQE on") {
+    testInjectColumnar(true)
+  }
+
+  test("inject columnar AQE off") {
+    testInjectColumnar(false)
+  }
+
+  private def testInjectColumnar(adaptiveEnabled: Boolean): Unit = {
+
+    def collectPlanSteps(plan: SparkPlan): Seq[Int] = plan match {
+      case a: AdaptiveSparkPlanExec =>
+        assert(a.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true"))
+        collectPlanSteps(a.executedPlan)
+      case _ => plan.collect {
+        case _: ReplacedRowToColumnarExec => 1
+        case _: ColumnarProjectExec => 10
+        case _: ColumnarToRowExec => 100
+        case s: QueryStageExec => collectPlanSteps(s.plan).sum
+        case _: MyShuffleExchangeExec => 1000
+        case _: MyBroadcastExchangeExec => 10000
+      }
+    }
+
     val extensions = create { extensions =>
       extensions.injectColumnar(session =>
         MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule()))
     }
     withSession(extensions) { session =>
-      // The ApplyColumnarRulesAndInsertTransitions rule is not applied when enable AQE
-      session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, false)
+      session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, adaptiveEnabled)
       assert(session.sessionState.columnarRules.contains(
         MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
       import session.sqlContext.implicits._
-      // repartitioning avoids having the add operation pushed up into the LocalTableScan
-      val data = Seq((100L), (200L), (300L)).toDF("vals").repartition(1)
-      val df = data.selectExpr("vals + 1")
+      // perform a join to inject a broadcast exchange
+      val left = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("l1", "l2")
+      val right = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("r1", "r2")
+      val data = left.join(right, $"l1" === $"r1")
+        // repartitioning avoids having the add operation pushed up into the LocalTableScan
+        .repartition(1)
+      val df = data.selectExpr("l2 + r2")
+      // execute the plan so that the final adaptive plan is available when AQE is on
+      df.collect()
       // Verify that both pre and post processing of the plan worked.
-      val found = df.queryExecution.executedPlan.collect {
-        case rep: ReplacedRowToColumnarExec => 1
-        case proj: ColumnarProjectExec => 10
-        case c2r: ColumnarToRowExec => 100
-      }.sum
-      assert(found == 111)
-
+      val found = collectPlanSteps(df.queryExecution.executedPlan).sum
+      assert(found == 11121)
       // Verify that we get back the expected, wrong, result
       val result = df.collect()
-      assert(result(0).getLong(0) == 102L) // Check that broken columnar Add was used.
-      assert(result(1).getLong(0) == 202L)
-      assert(result(2).getLong(0) == 302L)
+      assert(result(0).getLong(0) == 101L) // Check that broken columnar Add was used.
+      assert(result(1).getLong(0) == 201L)
+      assert(result(2).getLong(0) == 301L)
     }
   }
 
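For reference, the expected step total asserted above decomposes uniquely (assuming each operator occurs fewer than ten times) as

  11121 = 1 x 10000  (MyBroadcastExchangeExec)
        + 1 x  1000  (MyShuffleExchangeExec)
        + 1 x   100  (ColumnarToRowExec)
        + 2 x    10  (ColumnarProjectExec)
        + 1 x     1  (ReplacedRowToColumnarExec)

and the result assertions reflect that the suite's broken columnar Add deliberately adds one extra, so l2 + r2 over the matching rows (50+50, 100+100, 150+150) evaluates to 101, 201 and 301.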
@@ -671,6 +699,15 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
   def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan =
     try {
       plan match {
+        case e: ShuffleExchangeExec =>
+          // note that this is not actually columnar but demonstrates that exchanges can
+          // be replaced, particularly when adaptive query is enabled
+          val replaced = e.withNewChildren(e.children.map(replaceWithColumnarPlan))
+          MyShuffleExchangeExec(replaced.asInstanceOf[ShuffleExchangeExec])
+        case e: BroadcastExchangeExec =>
+          // note that this is not actually columnar but demonstrates that exchanges can
+          // be replaced, particularly when adaptive query is enabled
+          new MyBroadcastExchangeExec(e.mode, e.child)
         case plan: ProjectExec =>
           new ColumnarProjectExec(plan.projectList.map((exp) =>
             replaceWithColumnarExpression(exp).asInstanceOf[NamedExpression]),
@@ -689,6 +726,37 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
   override def apply(plan: SparkPlan): SparkPlan = replaceWithColumnarPlan(plan)
 }
 
+/**
+ * Custom Exchange used in tests to demonstrate that shuffles can be replaced regardless of
+ * whether adaptive query is enabled.
+ */
+case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleExchange {
+  override def shuffleId: Int = delegate.shuffleId
+  override def getNumMappers: Int = delegate.getNumMappers
+  override def getNumReducers: Int = delegate.getNumReducers
+  override def canChangeNumPartitions: Boolean = delegate.canChangeNumPartitions
+  override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
+    delegate.mapOutputStatisticsFuture
+  override def child: SparkPlan = delegate.child
+  override protected def doExecute(): RDD[InternalRow] = delegate.execute()
+}
+
+/**
+ * Custom Exchange used in tests to demonstrate that broadcasts can be replaced regardless of
+ * whether adaptive query is enabled.
+ *
+ * Note that extending a Spark case class is not recommended, but this was the easiest way to
+ * implement these tests.
+ */
+class MyBroadcastExchangeExec(mode: BroadcastMode,
+    child: SparkPlan) extends BroadcastExchangeExec(mode, child) {
+  override def equals(o: Any): Boolean = o match {
+    case o: MyBroadcastExchangeExec => mode.equals(o.mode) && child.equals(o.child)
+    case _ => false
+  }
+  override def hashCode(): Int = mode.hashCode() + child.hashCode()
+}
+
 class ReplacedRowToColumnarExec(override val child: SparkPlan) extends RowToColumnarExec(child) {