diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 34db0a334f67f..b160b8ac2ed68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -100,7 +100,12 @@ case class AdaptiveSparkPlanExec( // The following two rules need to make use of 'CustomShuffleReaderExec.partitionSpecs' // added by `CoalesceShufflePartitions`. So they must be executed after it. OptimizeSkewedJoin(conf), - OptimizeLocalShuffleReader(conf), + OptimizeLocalShuffleReader(conf) + ) + + // A list of physical optimizer rules to be applied right after a new stage is created. The input + // plan to these rules has exchange as its root node. + @transient private val postStageCreationRules = Seq( ApplyColumnarRulesAndInsertTransitions(conf, context.session.sessionState.columnarRules), CollapseCodegenStages(conf) ) @@ -227,7 +232,8 @@ case class AdaptiveSparkPlanExec( } // Run the final plan when there's no more unfinished stages. - currentPhysicalPlan = applyPhysicalRules(result.newPlan, queryStageOptimizerRules) + currentPhysicalPlan = applyPhysicalRules( + result.newPlan, queryStageOptimizerRules ++ postStageCreationRules) isFinalPlan = true executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan))) currentPhysicalPlan @@ -376,10 +382,22 @@ case class AdaptiveSparkPlanExec( private def newQueryStage(e: Exchange): QueryStageExec = { val optimizedPlan = applyPhysicalRules(e.child, queryStageOptimizerRules) val queryStage = e match { - case s: ShuffleExchangeExec => - ShuffleQueryStageExec(currentStageId, s.copy(child = optimizedPlan)) - case b: BroadcastExchangeExec => - BroadcastQueryStageExec(currentStageId, b.copy(child = optimizedPlan)) + case s: ShuffleExchangeLike => + val newShuffle = applyPhysicalRules( + s.withNewChildren(Seq(optimizedPlan)), postStageCreationRules) + if (!newShuffle.isInstanceOf[ShuffleExchangeLike]) { + throw new IllegalStateException( + "Custom columnar rules cannot transform shuffle node to something else.") + } + ShuffleQueryStageExec(currentStageId, newShuffle) + case b: BroadcastExchangeLike => + val newBroadcast = applyPhysicalRules( + b.withNewChildren(Seq(optimizedPlan)), postStageCreationRules) + if (!newBroadcast.isInstanceOf[BroadcastExchangeLike]) { + throw new IllegalStateException( + "Custom columnar rules cannot transform broadcast node to something else.") + } + BroadcastQueryStageExec(currentStageId, newBroadcast) } currentStageId += 1 setLogicalLinkForNewQueryStage(queryStage, e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala index af18ee065aa86..49a4c25fa637f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.adaptive -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD @@ -25,8 +24,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import 
org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.vectorized.ColumnarBatch /** @@ -45,6 +45,8 @@ case class CustomShuffleReaderExec private( assert(partitionSpecs.forall(_.isInstanceOf[PartialMapperPartitionSpec])) } + override def supportsColumnar: Boolean = child.supportsColumnar + override def output: Seq[Attribute] = child.output override lazy val outputPartitioning: Partitioning = { // If it is a local shuffle reader with one mapper per task, then the output partitioning is @@ -55,9 +57,9 @@ case class CustomShuffleReaderExec private( partitionSpecs.map(_.asInstanceOf[PartialMapperPartitionSpec].mapIndex).toSet.size == partitionSpecs.length) { child match { - case ShuffleQueryStageExec(_, s: ShuffleExchangeExec) => + case ShuffleQueryStageExec(_, s: ShuffleExchangeLike) => s.child.outputPartitioning - case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeExec)) => + case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeLike)) => s.child.outputPartitioning match { case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning] case other => other @@ -176,18 +178,21 @@ case class CustomShuffleReaderExec private( } } - private lazy val cachedShuffleRDD: RDD[InternalRow] = { + private lazy val shuffleRDD: RDD[_] = { sendDriverMetrics() shuffleStage.map { stage => - new ShuffledRowRDD( - stage.shuffle.shuffleDependency, stage.shuffle.readMetrics, partitionSpecs.toArray) + stage.shuffle.getShuffleRDD(partitionSpecs.toArray) }.getOrElse { throw new IllegalStateException("operating on canonicalized plan") } } override protected def doExecute(): RDD[InternalRow] = { - cachedShuffleRDD + shuffleRDD.asInstanceOf[RDD[InternalRow]] + } + + override protected def doExecuteColumnar(): RDD[ColumnarBatch] = { + shuffleRDD.asInstanceOf[RDD[ColumnarBatch]] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala index 3620f27058af2..45fb36420e770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala @@ -78,10 +78,9 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] { private def getPartitionSpecs( shuffleStage: ShuffleQueryStageExec, advisoryParallelism: Option[Int]): Seq[ShufflePartitionSpec] = { - val shuffleDep = shuffleStage.shuffle.shuffleDependency - val numReducers = shuffleDep.partitioner.numPartitions + val numMappers = shuffleStage.shuffle.numMappers + val numReducers = shuffleStage.shuffle.numPartitions val expectedParallelism = advisoryParallelism.getOrElse(numReducers) - val numMappers = shuffleDep.rdd.getNumPartitions val splitPoints = if (numMappers == 0) { Seq.empty } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index 627f0600f2383..a85b188727ba4 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -202,7 +202,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { val leftParts = if (isLeftSkew && !isLeftCoalesced) { val reducerId = leftPartSpec.startReducerIndex val skewSpecs = createSkewPartitionSpecs( - left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize) + left.mapStats.shuffleId, reducerId, leftTargetSize) if (skewSpecs.isDefined) { logDebug(s"Left side partition $partitionIndex " + s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " + @@ -218,7 +218,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { val rightParts = if (isRightSkew && !isRightCoalesced) { val reducerId = rightPartSpec.startReducerIndex val skewSpecs = createSkewPartitionSpecs( - right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize) + right.mapStats.shuffleId, reducerId, rightTargetSize) if (skewSpecs.isDefined) { logDebug(s"Right side partition $partitionIndex " + s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala index 4e83b4344fbf0..0927ef5b0b3c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.ThreadUtils /** @@ -81,17 +82,19 @@ abstract class QueryStageExec extends LeafExecNode { def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec + /** + * Returns the runtime statistics after stage materialization. + */ + def getRuntimeStatistics: Statistics + /** * Compute the statistics of the query stage if executed, otherwise None. */ def computeStats(): Option[Statistics] = resultOption.get().map { _ => - // Metrics `dataSize` are available in both `ShuffleExchangeExec` and `BroadcastExchangeExec`. 
- val exchange = plan match { - case r: ReusedExchangeExec => r.child - case e: Exchange => e - case _ => throw new IllegalStateException("wrong plan for query stage:\n " + plan.treeString) - } - Statistics(sizeInBytes = exchange.metrics("dataSize").value) + val runtimeStats = getRuntimeStatistics + val dataSize = runtimeStats.sizeInBytes.max(0) + val numOutputRows = runtimeStats.rowCount.map(_.max(0)) + Statistics(dataSize, numOutputRows) } @transient @@ -110,6 +113,8 @@ abstract class QueryStageExec extends LeafExecNode { protected override def doPrepare(): Unit = plan.prepare() protected override def doExecute(): RDD[InternalRow] = plan.execute() + override def supportsColumnar: Boolean = plan.supportsColumnar + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = plan.executeColumnar() override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast() override def doCanonicalize(): SparkPlan = plan.canonicalized @@ -138,15 +143,15 @@ abstract class QueryStageExec extends LeafExecNode { } /** - * A shuffle query stage whose child is a [[ShuffleExchangeExec]] or [[ReusedExchangeExec]]. + * A shuffle query stage whose child is a [[ShuffleExchangeLike]] or [[ReusedExchangeExec]]. */ case class ShuffleQueryStageExec( override val id: Int, override val plan: SparkPlan) extends QueryStageExec { @transient val shuffle = plan match { - case s: ShuffleExchangeExec => s - case ReusedExchangeExec(_, s: ShuffleExchangeExec) => s + case s: ShuffleExchangeLike => s + case ReusedExchangeExec(_, s: ShuffleExchangeLike) => s case _ => throw new IllegalStateException("wrong plan for shuffle stage:\n " + plan.treeString) } @@ -177,22 +182,24 @@ case class ShuffleQueryStageExec( * this method returns None, as there is no map statistics. */ def mapStats: Option[MapOutputStatistics] = { - assert(resultOption.get().isDefined, "ShuffleQueryStageExec should already be ready") + assert(resultOption.get().isDefined, s"${getClass.getSimpleName} should already be ready") val stats = resultOption.get().get.asInstanceOf[MapOutputStatistics] Option(stats) } + + override def getRuntimeStatistics: Statistics = shuffle.runtimeStatistics } /** - * A broadcast query stage whose child is a [[BroadcastExchangeExec]] or [[ReusedExchangeExec]]. + * A broadcast query stage whose child is a [[BroadcastExchangeLike]] or [[ReusedExchangeExec]]. 
*/ case class BroadcastQueryStageExec( override val id: Int, override val plan: SparkPlan) extends QueryStageExec { @transient val broadcast = plan match { - case b: BroadcastExchangeExec => b - case ReusedExchangeExec(_, b: BroadcastExchangeExec) => b + case b: BroadcastExchangeLike => b + case ReusedExchangeExec(_, b: BroadcastExchangeLike) => b case _ => throw new IllegalStateException("wrong plan for broadcast stage:\n " + plan.treeString) } @@ -231,6 +238,8 @@ case class BroadcastQueryStageExec( broadcast.relationFuture.cancel(true) } } + + override def getRuntimeStatistics: Statistics = broadcast.runtimeStatistics } object BroadcastQueryStageExec { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala index 67cd720bb5b33..cdc57dbc7dcc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} /** * A simple implementation of [[Cost]], which takes a number of [[Long]] as the cost value. @@ -35,13 +35,13 @@ case class SimpleCost(value: Long) extends Cost { /** * A simple implementation of [[CostEvaluator]], which counts the number of - * [[ShuffleExchangeExec]] nodes in the plan. + * [[ShuffleExchangeLike]] nodes in the plan. */ object SimpleCostEvaluator extends CostEvaluator { override def evaluateCost(plan: SparkPlan): Cost = { val cost = plan.collect { - case s: ShuffleExchangeExec => s + case s: ShuffleExchangeLike => s }.size SimpleCost(cost) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index d35bbe9b8adc0..6d8d37022ea42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -29,6 +29,7 @@ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.joins.HashedRelation @@ -37,16 +38,43 @@ import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.{SparkFatalException, ThreadUtils} +/** + * Common trait for all broadcast exchange implementations to facilitate pattern matching. + */ +trait BroadcastExchangeLike extends Exchange { + + /** + * The broadcast job group ID + */ + def runId: UUID = UUID.randomUUID + + /** + * The asynchronous job that prepares the broadcast relation. + */ + def relationFuture: Future[broadcast.Broadcast[Any]] + + /** + * For registering callbacks on `relationFuture`. + * Note that calling this method may not start the execution of broadcast job. 
+ */ + def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] + + /** + * Returns the runtime statistics after broadcast materialization. + */ + def runtimeStatistics: Statistics +} + /** * A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of * a transformed SparkPlan. */ case class BroadcastExchangeExec( mode: BroadcastMode, - child: SparkPlan) extends Exchange { + child: SparkPlan) extends BroadcastExchangeLike { import BroadcastExchangeExec._ - private[sql] val runId: UUID = UUID.randomUUID + override val runId: UUID = UUID.randomUUID override lazy val metrics = Map( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), @@ -60,21 +88,23 @@ case class BroadcastExchangeExec( BroadcastExchangeExec(mode.canonicalized, child.canonicalized) } + override def runtimeStatistics: Statistics = { + val dataSize = metrics("dataSize").value + Statistics(dataSize) + } + @transient private lazy val promise = Promise[broadcast.Broadcast[Any]]() - /** - * For registering callbacks on `relationFuture`. - * Note that calling this field will not start the execution of broadcast job. - */ @transient - lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] = promise.future + override lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] = + promise.future @transient private val timeout: Long = SQLConf.get.broadcastTimeout @transient - private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { + override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]]( sqlContext.sparkSession, BroadcastExchangeExec.executionContext) { try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index b06742e8470c7..30c9f0ae1282d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -30,8 +30,9 @@ import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProces import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Divide, Literal, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering +import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} @@ -40,13 +41,49 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.util.MutablePair import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator} +/** + * Common trait for all shuffle exchange implementations to facilitate pattern matching. + */ +trait ShuffleExchangeLike extends Exchange { + + /** + * Returns the number of mappers of this shuffle. + */ + def numMappers: Int + + /** + * Returns the shuffle partition number. 
+ */ + def numPartitions: Int + + /** + * Returns whether the shuffle partition number can be changed. + */ + def canChangeNumPartitions: Boolean + + /** + * The asynchronous job that materializes the shuffle. + */ + def mapOutputStatisticsFuture: Future[MapOutputStatistics] + + /** + * Returns the shuffle RDD with specified partition specs. + */ + def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] + + /** + * Returns the runtime statistics after shuffle materialization. + */ + def runtimeStatistics: Statistics +} + /** * Performs a shuffle that will result in the desired partitioning. */ case class ShuffleExchangeExec( override val outputPartitioning: Partitioning, child: SparkPlan, - canChangeNumPartitions: Boolean = true) extends Exchange { + canChangeNumPartitions: Boolean = true) extends ShuffleExchangeLike { private lazy val writeMetrics = SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) @@ -64,7 +101,7 @@ case class ShuffleExchangeExec( @transient lazy val inputRDD: RDD[InternalRow] = child.execute() // 'mapOutputStatisticsFuture' is only needed when enable AQE. - @transient lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = { + @transient override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = { if (inputRDD.getNumPartitions == 0) { Future.successful(null) } else { @@ -72,6 +109,20 @@ case class ShuffleExchangeExec( } } + override def numMappers: Int = shuffleDependency.rdd.getNumPartitions + + override def numPartitions: Int = shuffleDependency.partitioner.numPartitions + + override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[InternalRow] = { + new ShuffledRowRDD(shuffleDependency, readMetrics, partitionSpecs) + } + + override def runtimeStatistics: Statistics = { + val dataSize = metrics("dataSize").value + val rowCount = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value + Statistics(dataSize, Some(rowCount)) + } + /** * A [[ShuffleDependency]] that will partition rows of its child based on * the partitioning scheme defined in `newPartitioning`. Those partitions of diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 7773ac71c4954..bfa60cf7dfd78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{LeafExecNode, LocalLimitExec, QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.util.Utils @@ -118,7 +118,7 @@ class IncrementalExecution( case s: StatefulOperator => statefulOpFound = true - case e: ShuffleExchangeExec => + case e: ShuffleExchangeLike => // Don't search recursively any further as any child stateful operator as we // are only looking for stateful subplans that this plan has narrow dependencies on. 
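Taken together, the two new traits turn the exchange operators into an extension point: AQE can now plan, reuse and measure any node that implements ShuffleExchangeLike or BroadcastExchangeLike, not only the built-in ShuffleExchangeExec and BroadcastExchangeExec. As a rough sketch of the shuffle-side contract, assuming a plugin that simply delegates to a wrapped built-in exchange (the class name here is hypothetical; it mirrors the MyShuffleExchangeExec helper added to the test suite below):

import scala.concurrent.Future

import org.apache.spark.MapOutputStatistics
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}

// Hypothetical wrapper exchange. It reports the same shuffle topology as the wrapped node, so
// CoalesceShufflePartitions, OptimizeSkewedJoin and OptimizeLocalShuffleReader can treat it like
// a built-in shuffle; a real columnar backend would return an RDD[ColumnarBatch] from getShuffleRDD.
case class WrappedShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleExchangeLike {
  override def numMappers: Int = delegate.numMappers
  override def numPartitions: Int = delegate.numPartitions
  override def canChangeNumPartitions: Boolean = delegate.canChangeNumPartitions
  override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
    delegate.mapOutputStatisticsFuture
  override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] =
    delegate.getShuffleRDD(partitionSpecs)
  override def runtimeStatistics: Statistics = delegate.runtimeStatistics
  override def child: SparkPlan = delegate.child
  override def outputPartitioning: Partitioning = delegate.outputPartitioning
  override protected def doExecute(): RDD[InternalRow] = delegate.execute()
}

Because the wrapper still answers numMappers, numPartitions and mapOutputStatisticsFuture from the underlying shuffle, the stage-level optimizer rules keep working unchanged, and QueryStageExec.computeStats() picks up its runtimeStatistics (data size plus row count) when deciding on the final plan.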
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 44e784de5164f..e5e8bc6917799 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -16,19 +16,24 @@ */ package org.apache.spark.sql -import java.util.Locale +import java.util.{Locale, UUID} -import org.apache.spark.{SparkFunSuite, TaskContext} +import scala.concurrent.Future + +import org.apache.spark.{MapOutputStatistics, SparkFunSuite, TaskContext} +import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, UnresolvedHint} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Statistics, UnresolvedHint} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE @@ -169,33 +174,61 @@ class SparkSessionExtensionSuite extends SparkFunSuite { } } - test("inject columnar") { + test("inject columnar AQE on") { + testInjectColumnar(true) + } + + test("inject columnar AQE off") { + testInjectColumnar(false) + } + + private def testInjectColumnar(enableAQE: Boolean): Unit = { + def collectPlanSteps(plan: SparkPlan): Seq[Int] = plan match { + case a: AdaptiveSparkPlanExec => + assert(a.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true")) + collectPlanSteps(a.executedPlan) + case _ => plan.collect { + case _: ReplacedRowToColumnarExec => 1 + case _: ColumnarProjectExec => 10 + case _: ColumnarToRowExec => 100 + case s: QueryStageExec => collectPlanSteps(s.plan).sum + case _: MyShuffleExchangeExec => 1000 + case _: MyBroadcastExchangeExec => 10000 + } + } + val extensions = create { extensions => extensions.injectColumnar(session => MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())) } withSession(extensions) { session => - // The ApplyColumnarRulesAndInsertTransitions rule is not applied when enable AQE - session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, false) + session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, enableAQE) assert(session.sessionState.columnarRules.contains( MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule()))) import session.sqlContext.implicits._ - // repartitioning avoids having the add operation pushed up into the LocalTableScan - val data = Seq((100L), (200L), (300L)).toDF("vals").repartition(1) - val df = data.selectExpr("vals + 1") - // Verify 
that both pre and post processing of the plan worked. - val found = df.queryExecution.executedPlan.collect { - case rep: ReplacedRowToColumnarExec => 1 - case proj: ColumnarProjectExec => 10 - case c2r: ColumnarToRowExec => 100 - }.sum - assert(found == 111) + // perform a join to inject a broadcast exchange + val left = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("l1", "l2") + val right = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("r1", "r2") + val data = left.join(right, $"l1" === $"r1") + // repartitioning avoids having the add operation pushed up into the LocalTableScan + .repartition(1) + val df = data.selectExpr("l2 + r2") + // execute the plan so that the final adaptive plan is available when AQE is on + df.collect() + val found = collectPlanSteps(df.queryExecution.executedPlan).sum + // 1 MyBroadcastExchangeExec + // 1 MyShuffleExchangeExec + // 1 ColumnarToRowExec + // 2 ColumnarProjectExec + // 1 ReplacedRowToColumnarExec + // so 11121 is expected. + assert(found == 11121) // Verify that we get back the expected, wrong, result val result = df.collect() - assert(result(0).getLong(0) == 102L) // Check that broken columnar Add was used. - assert(result(1).getLong(0) == 202L) - assert(result(2).getLong(0) == 302L) + assert(result(0).getLong(0) == 101L) // Check that broken columnar Add was used. + assert(result(1).getLong(0) == 201L) + assert(result(2).getLong(0) == 301L) } } @@ -695,6 +728,16 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] { def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = try { plan match { + case e: ShuffleExchangeExec => + // note that this is not actually columnar but demonstrates that exchanges can + // be replaced. + val replaced = e.withNewChildren(e.children.map(replaceWithColumnarPlan)) + MyShuffleExchangeExec(replaced.asInstanceOf[ShuffleExchangeExec]) + case e: BroadcastExchangeExec => + // note that this is not actually columnar but demonstrates that exchanges can + // be replaced. + val replaced = e.withNewChildren(e.children.map(replaceWithColumnarPlan)) + MyBroadcastExchangeExec(replaced.asInstanceOf[BroadcastExchangeExec]) case plan: ProjectExec => new ColumnarProjectExec(plan.projectList.map((exp) => replaceWithColumnarExpression(exp).asInstanceOf[NamedExpression]), @@ -713,6 +756,41 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] { override def apply(plan: SparkPlan): SparkPlan = replaceWithColumnarPlan(plan) } +/** + * Custom Exchange used in tests to demonstrate that shuffles can be replaced regardless of + * whether AQE is enabled. + */ +case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleExchangeLike { + override def numMappers: Int = delegate.numMappers + override def numPartitions: Int = delegate.numPartitions + override def canChangeNumPartitions: Boolean = delegate.canChangeNumPartitions + override def mapOutputStatisticsFuture: Future[MapOutputStatistics] = + delegate.mapOutputStatisticsFuture + override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] = + delegate.getShuffleRDD(partitionSpecs) + override def runtimeStatistics: Statistics = delegate.runtimeStatistics + override def child: SparkPlan = delegate.child + override protected def doExecute(): RDD[InternalRow] = delegate.execute() + override def outputPartitioning: Partitioning = delegate.outputPartitioning +} + +/** + * Custom Exchange used in tests to demonstrate that broadcasts can be replaced regardless of + * whether AQE is enabled. 
+ */ +case class MyBroadcastExchangeExec(delegate: BroadcastExchangeExec) extends BroadcastExchangeLike { + override def runId: UUID = delegate.runId + override def relationFuture: java.util.concurrent.Future[Broadcast[Any]] = + delegate.relationFuture + override def completionFuture: Future[Broadcast[Any]] = delegate.completionFuture + override def runtimeStatistics: Statistics = delegate.runtimeStatistics + override def child: SparkPlan = delegate.child + override protected def doPrepare(): Unit = delegate.prepare() + override protected def doExecute(): RDD[InternalRow] = delegate.execute() + override def doExecuteBroadcast[T](): Broadcast[T] = delegate.executeBroadcast() + override def outputPartitioning: Partitioning = delegate.outputPartitioning +} + class ReplacedRowToColumnarExec(override val child: SparkPlan) extends RowToColumnarExec(child) {
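Stepping back from the suite itself, a minimal usage sketch of how these rules are wired into a session: the builder calls below are the standard extension API, spark.sql.adaptive.enabled is the usual AQE switch, and the sample data is only illustrative.

import org.apache.spark.sql.SparkSession

// Register the same pre/post columnar rules the suite injects and enable AQE, so the shuffle
// and broadcast exchanges created by the adaptive planner flow through the injected rule.
val spark = SparkSession.builder()
  .master("local[2]")
  .config("spark.sql.adaptive.enabled", "true")
  .withExtensions { extensions =>
    extensions.injectColumnar(session =>
      MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule()))
  }
  .getOrCreate()

import spark.implicits._
val left = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("l1", "l2")
val right = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("r1", "r2")
// The join plans a broadcast exchange and the repartition a shuffle exchange; after collect()
// the final adaptive plan should contain the MyBroadcastExchangeExec and MyShuffleExchangeExec
// wrappers in their place, which is what testInjectColumnar asserts via collectPlanSteps.
left.join(right, $"l1" === $"r1").repartition(1).selectExpr("l2 + r2").collect()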