@@ -374,12 +374,14 @@ case class AdaptiveSparkPlanExec(
}

private def newQueryStage(e: Exchange): QueryStageExec = {
val optimizedPlan = applyPhysicalRules(e.child, queryStageOptimizerRules)
// apply optimizer rules to the Exchange node and its children, allowing plugins to be
// able to replace the Exchange node itself
val optimizedPlan = applyPhysicalRules(e, queryStageOptimizerRules)
Contributor: How can we guarantee the top node is still an Exchange after applying physical rules?

Contributor: I think it's safer to use s.withNewChildren(optimizedChildPlan) than adding special handling in those physical rules: https://github.com/apache/spark/pull/29134/files#diff-a30c7a6fcdcdd13e57135fd04d05f3b7R115-R117

That saves you the trouble of worrying about certain assumptions being broken in an arbitrary rule.
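A rough sketch of that suggestion (hypothetical, reusing the names from newQueryStage above; withNewChildren works for any Exchange implementation, unlike copy(child = ...)):

    // optimize only the child plan and re-attach it, so the top node stays an Exchange
    // by construction and no stage-optimizer rule can replace the Exchange itself
    val optimizedChildPlan = applyPhysicalRules(e.child, queryStageOptimizerRules)
    val newExchange = e.withNewChildren(Seq(optimizedChildPlan))
    val queryStage = newExchange match {
      case s: ShuffleExchange => ShuffleQueryStageExec(currentStageId, s)
      case b: BroadcastExchange => BroadcastQueryStageExec(currentStageId, b)
    }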

val queryStage = e match {
case s: ShuffleExchangeExec =>
ShuffleQueryStageExec(currentStageId, s.copy(child = optimizedPlan))
case b: BroadcastExchangeExec =>
BroadcastQueryStageExec(currentStageId, b.copy(child = optimizedPlan))
case _: ShuffleExchange =>
ShuffleQueryStageExec(currentStageId, optimizedPlan)
case _: BroadcastExchange =>
BroadcastQueryStageExec(currentStageId, optimizedPlan)
}
currentStageId += 1
setLogicalLinkForNewQueryStage(queryStage, e)
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchange, ShuffleExchangeExec}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}


@@ -55,9 +55,9 @@ case class CustomShuffleReaderExec private(
partitionSpecs.map(_.asInstanceOf[PartialMapperPartitionSpec].mapIndex).toSet.size ==
partitionSpecs.length) {
child match {
case ShuffleQueryStageExec(_, s: ShuffleExchangeExec) =>
case ShuffleQueryStageExec(_, s: ShuffleExchange) =>
s.child.outputPartitioning
case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeExec)) =>
case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchange)) =>
s.child.outputPartitioning match {
case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning]
case other => other
@@ -180,8 +180,10 @@ case class CustomShuffleReaderExec private(
sendDriverMetrics()

shuffleStage.map { stage =>
val shuffleExchangeExec = stage.shuffle.asInstanceOf[ShuffleExchangeExec]
new ShuffledRowRDD(
stage.shuffle.shuffleDependency, stage.shuffle.readMetrics, partitionSpecs.toArray)
shuffleExchangeExec.shuffleDependency,
shuffleExchangeExec.readMetrics, partitionSpecs.toArray)
}.getOrElse {
throw new IllegalStateException("operating on canonicalized plan")
}
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, Exchange, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.internal.SQLConf

@@ -78,10 +78,9 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
private def getPartitionSpecs(
shuffleStage: ShuffleQueryStageExec,
advisoryParallelism: Option[Int]): Seq[ShufflePartitionSpec] = {
val shuffleDep = shuffleStage.shuffle.shuffleDependency
val numReducers = shuffleDep.partitioner.numPartitions
val numMappers = shuffleStage.shuffle.getNumMappers
val numReducers = shuffleStage.shuffle.getNumReducers
val expectedParallelism = advisoryParallelism.getOrElse(numReducers)
val numMappers = shuffleDep.rdd.getNumPartitions
val splitPoints = if (numMappers == 0) {
Seq.empty
} else {
@@ -113,6 +112,9 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
}

plan match {
// skip the top-level exchange operator
case s: Exchange =>
s.withNewChildren(s.children.map(apply))
case s: SparkPlan if canUseLocalShuffleReader(s) =>
createLocalReader(s)
case s: SparkPlan =>
@@ -202,7 +202,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
val leftParts = if (isLeftSkew && !isLeftCoalesced) {
val reducerId = leftPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
Contributor: Get shuffleId from mapStats.
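A sketch of that suggestion (assuming the left/right shuffle-stage wrappers this rule already uses for size estimates expose their MapOutputStatistics as a mapStats field — a hypothetical name here): the shuffle id can be read from the statistics instead of adding a shuffleId method to ShuffleExchange:

    val skewSpecs = createSkewPartitionSpecs(
      left.mapStats.shuffleId, reducerId, leftTargetSize)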

left.shuffleStage.shuffle.shuffleId, reducerId, leftTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Left side partition $partitionIndex " +
s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
@@ -218,7 +218,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
val rightParts = if (isRightSkew && !isRightCoalesced) {
val reducerId = rightPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
right.shuffleStage.shuffle.shuffleId, reducerId, rightTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Right side partition $partitionIndex " +
s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.ThreadUtils

/**
@@ -107,9 +108,11 @@ abstract class QueryStageExec extends LeafExecNode {
override def executeTake(n: Int): Array[InternalRow] = plan.executeTake(n)
override def executeTail(n: Int): Array[InternalRow] = plan.executeTail(n)
override def executeToIterator(): Iterator[InternalRow] = plan.executeToIterator()
override def supportsColumnar: Boolean = plan.supportsColumnar

protected override def doPrepare(): Unit = plan.prepare()
protected override def doExecute(): RDD[InternalRow] = plan.execute()
override def doExecuteColumnar(): RDD[ColumnarBatch] = plan.executeColumnar()
override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast()
override def doCanonicalize(): SparkPlan = plan.canonicalized

@@ -138,15 +141,16 @@ abstract class QueryStageExec extends LeafExecNode {
}

/**
* A shuffle query stage whose child is a [[ShuffleExchangeExec]] or [[ReusedExchangeExec]].
* A shuffle query stage whose child is a [[ShuffleExchange]] or a [[ReusedExchangeExec]] wrapping
* a [[ShuffleExchange]].
*/
case class ShuffleQueryStageExec(
override val id: Int,
override val plan: SparkPlan) extends QueryStageExec {

@transient val shuffle = plan match {
case s: ShuffleExchangeExec => s
case ReusedExchangeExec(_, s: ShuffleExchangeExec) => s
case s: ShuffleExchange => s
case ReusedExchangeExec(_, s: ShuffleExchange) => s
case _ =>
throw new IllegalStateException("wrong plan for shuffle stage:\n " + plan.treeString)
}
@@ -184,15 +188,16 @@ case class ShuffleQueryStageExec(
}

/**
* A broadcast query stage whose child is a [[BroadcastExchangeExec]] or [[ReusedExchangeExec]].
* A broadcast query stage whose child is a [[BroadcastExchange]] or a [[ReusedExchangeExec]]
* wrapping a [[BroadcastExchange]].
*/
case class BroadcastQueryStageExec(
override val id: Int,
override val plan: SparkPlan) extends QueryStageExec {

@transient val broadcast = plan match {
case b: BroadcastExchangeExec => b
case ReusedExchangeExec(_, b: BroadcastExchangeExec) => b
case b: BroadcastExchange => b
case ReusedExchangeExec(_, b: BroadcastExchange) => b
case _ =>
throw new IllegalStateException("wrong plan for broadcast stage:\n " + plan.treeString)
}
@@ -34,16 +34,28 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.HashedRelation
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.util.{SparkFatalException, ThreadUtils}

/**
* Base class for implementations of broadcast exchanges. This was added to enable plugins to
* provide columnar implementations of broadcast exchanges when Adaptive Query Execution is
* enabled.
*/
abstract class BroadcastExchange extends Exchange {
private[sql] def runId: UUID
private[sql] def relationFuture: Future[broadcast.Broadcast[Any]]
def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]
}

/**
* A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of
* a transformed SparkPlan.
*/
case class BroadcastExchangeExec(
mode: BroadcastMode,
child: SparkPlan) extends Exchange {
child: SparkPlan) extends BroadcastExchange {
import BroadcastExchangeExec._

private[sql] val runId: UUID = UUID.randomUUID
@@ -156,6 +168,11 @@ case class BroadcastExchangeExec(
"BroadcastExchange does not support the execute() code path.")
}

override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
throw new UnsupportedOperationException(
"BroadcastExchange does not support the executeColumnar() code path.")
}

override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
try {
relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
@@ -37,16 +37,30 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.MutablePair
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}

/**
* Base class for implementations of shuffle exchanges. This was added to enable plugins to
* provide columnar implementations of shuffle exchanges when Adaptive Query Execution is
* enabled.
*/
abstract class ShuffleExchange extends Exchange {
def shuffleId: Int
Contributor: This is available in mapOutputStats

def getNumMappers: Int
Contributor: We can get this through MapOutputTracker.

def getNumReducers: Int
def canChangeNumPartitions: Boolean
def mapOutputStatisticsFuture: Future[MapOutputStatistics]
}
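A sketch of the alternatives raised in the two comments above (assuming the caller holds the MapOutputStatistics that mapOutputStatisticsFuture eventually returns): the shuffle id and the reducer count are already carried by the statistics, so only the mapper count would need a separate lookup:

    // `stats` is the MapOutputStatistics obtained from mapOutputStatisticsFuture
    val shuffleId: Int = stats.shuffleId
    val numReducers: Int = stats.bytesByPartitionId.length // one entry per reduce partition
    // the number of map tasks could be looked up from the driver-side MapOutputTracker
    // for `shuffleId` rather than exposed as a new abstract method (accessor not shown)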

/**
* Performs a shuffle that will result in the desired partitioning.
*/
case class ShuffleExchangeExec(
override val outputPartitioning: Partitioning,
child: SparkPlan,
canChangeNumPartitions: Boolean = true) extends Exchange {
canChangeNumPartitions: Boolean = true) extends ShuffleExchange {

private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
@@ -63,6 +77,12 @@ case class ShuffleExchangeExec(

@transient lazy val inputRDD: RDD[InternalRow] = child.execute()

override def shuffleId: Int = shuffleDependency.shuffleId

override def getNumMappers: Int = shuffleDependency.rdd.getNumPartitions

override def getNumReducers: Int = shuffleDependency.partitioner.numPartitions

// 'mapOutputStatisticsFuture' is only needed when enable AQE.
@transient lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
if (inputRDD.getNumPartitions == 0) {
@@ -18,15 +18,20 @@ package org.apache.spark.sql

import java.util.Locale

import org.apache.spark.{SparkFunSuite, TaskContext}
import scala.concurrent.Future

import org.apache.spark.{MapOutputStatistics, SparkFunSuite, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, UnresolvedHint}
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange, ShuffleExchangeExec}
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE
@@ -145,33 +150,56 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
}
}

test("inject columnar") {
test("inject columnar AQE on") {
testInjectColumnar(true)
}

test("inject columnar AQE off") {
testInjectColumnar(false)
}

private def testInjectColumnar(adaptiveEnabled: Boolean) {

def collectPlanSteps(plan: SparkPlan): Seq[Int] = plan match {
case a: AdaptiveSparkPlanExec =>
assert(a.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true"))
collectPlanSteps(a.executedPlan)
case _ => plan.collect {
case _: ReplacedRowToColumnarExec => 1
case _: ColumnarProjectExec => 10
case _: ColumnarToRowExec => 100
case s: QueryStageExec => collectPlanSteps(s.plan).sum
case _: MyShuffleExchangeExec => 1000
case _: MyBroadcastExchangeExec => 10000
}
}

val extensions = create { extensions =>
extensions.injectColumnar(session =>
MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule()))
}
withSession(extensions) { session =>
// The ApplyColumnarRulesAndInsertTransitions rule is not applied when enable AQE
session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, false)
session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, adaptiveEnabled)
assert(session.sessionState.columnarRules.contains(
MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
import session.sqlContext.implicits._
// repartitioning avoids having the add operation pushed up into the LocalTableScan
val data = Seq((100L), (200L), (300L)).toDF("vals").repartition(1)
val df = data.selectExpr("vals + 1")
// perform a join to inject a broadcast exchange
val left = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("l1", "l2")
val right = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("r1", "r2")
val data = left.join(right, $"l1" === $"r1")
// repartitioning avoids having the add operation pushed up into the LocalTableScan
.repartition(1)
val df = data.selectExpr("l2 + r2")
// execute the plan so that the final adaptive plan is available when AQE is on
df.collect()
// Verify that both pre and post processing of the plan worked.
val found = df.queryExecution.executedPlan.collect {
case rep: ReplacedRowToColumnarExec => 1
case proj: ColumnarProjectExec => 10
case c2r: ColumnarToRowExec => 100
}.sum
assert(found == 111)

val found = collectPlanSteps(df.queryExecution.executedPlan).sum
assert(found == 11121)
Contributor: might be nice to comment what 11121 equals in terms of the execs - MyBroadcastExchangeExec, etc..
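For reference, decoding 11121 with the weights defined in collectPlanSteps (and assuming each operator kind occurs fewer than ten times) gives one MyBroadcastExchangeExec (10000), one MyShuffleExchangeExec (1000), one ColumnarToRowExec (100), two ColumnarProjectExec (2 × 10) and one ReplacedRowToColumnarExec (1).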

// Verify that we get back the expected, wrong, result
val result = df.collect()
assert(result(0).getLong(0) == 102L) // Check that broken columnar Add was used.
assert(result(1).getLong(0) == 202L)
assert(result(2).getLong(0) == 302L)
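// the expected values are the true join sums (50 + 50 = 100, 100 + 100 = 200, 150 + 150 = 300)
// plus the extra 1 introduced by the deliberately broken columnar Add, just as the old
// expectations of 102/202/302 were one above the true single-column results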
assert(result(0).getLong(0) == 101L) // Check that broken columnar Add was used.
assert(result(1).getLong(0) == 201L)
assert(result(2).getLong(0) == 301L)
}
}

@@ -671,6 +699,15 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan =
try {
plan match {
case e: ShuffleExchangeExec =>
// note that this is not actually columnar but demonstrates that exchanges can
// be replaced, particularly when adaptive query is enabled
val replaced = e.withNewChildren(e.children.map(replaceWithColumnarPlan))
MyShuffleExchangeExec(replaced.asInstanceOf[ShuffleExchangeExec])
case e: BroadcastExchangeExec =>
// note that this is not actually columnar but demonstrates that exchanges can
// be replaced, particularly when adaptive query is enabled
new MyBroadcastExchangeExec(e.mode, e.child)
case plan: ProjectExec =>
new ColumnarProjectExec(plan.projectList.map((exp) =>
replaceWithColumnarExpression(exp).asInstanceOf[NamedExpression]),
@@ -689,6 +726,37 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = replaceWithColumnarPlan(plan)
}

/**
* Custom Exchange used in tests to demonstrate that shuffles can be replaced regardless of
* whether adaptive query is enabled.
*/
case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleExchange {
override def shuffleId: Int = delegate.shuffleId
override def getNumMappers: Int = delegate.getNumMappers
override def getNumReducers: Int = delegate.getNumReducers
override def canChangeNumPartitions: Boolean = delegate.canChangeNumPartitions
override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
delegate.mapOutputStatisticsFuture
override def child: SparkPlan = delegate.child
override protected def doExecute(): RDD[InternalRow] = delegate.execute()
}

/**
* Custom Exchange used in tests to demonstrate that broadcasts can be replaced regardless of
* whether adaptive query is enabled.
*
* Note that extending a Spark case class is not recommended, but this was the easiest way to
* implement these tests.
*/
class MyBroadcastExchangeExec(mode: BroadcastMode,
child: SparkPlan) extends BroadcastExchangeExec(mode, child) {
override def equals(o: Any): Boolean = o match {
case o: MyBroadcastExchangeExec => mode.equals(o.mode) && child.equals(o.child)
case _ => false
}
override def hashCode(): Int = mode.hashCode() + child.hashCode()
}

class ReplacedRowToColumnarExec(override val child: SparkPlan)
extends RowToColumnarExec(child) {
