SPARK-32932: Do not use local shuffle reader at final stage on write command (closed pull request)

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

@@ -36,6 +36,8 @@ import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
+import org.apache.spark.sql.execution.command.DataWritingCommandExec
+import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
 import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric}
 import org.apache.spark.sql.internal.SQLConf
@@ -102,6 +104,16 @@ case class AdaptiveSparkPlanExec(
     OptimizeLocalShuffleReader(conf)
   )
 
+  private def finalStageOptimizerRules: Seq[Rule[SparkPlan]] =
+    context.qe.sparkPlan match {
+      case _: DataWritingCommandExec | _: V2TableWriteExec =>
+        // SPARK-32932: Local shuffle reader could break partitioning that works best
+        // for the following writing command
+        queryStageOptimizerRules.filterNot(_.isInstanceOf[OptimizeLocalShuffleReader])
+      case _ =>
+        queryStageOptimizerRules
+    }
+
   // A list of physical optimizer rules to be applied right after a new stage is created. The input
   // plan to these rules has exchange as its root node.
   @transient private val postStageCreationRules = Seq(
@@ -235,7 +247,7 @@ case class AdaptiveSparkPlanExec(
       // Run the final plan when there's no more unfinished stages.
       currentPhysicalPlan = applyPhysicalRules(
         result.newPlan,
-        queryStageOptimizerRules ++ postStageCreationRules,
+        finalStageOptimizerRules ++ postStageCreationRules,
         Some((planChangeLogger, "AQE Final Query Stage Optimization")))
       isFinalPlan = true
       executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
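
For context, the scenario this guard protects is easiest to see end to end. The sketch below is not part of the patch: it assumes a local Spark 3.x session with this change applied, and the column and table names simply mirror the regression test further down. `repartition($"j")` inserts a shuffle that clusters rows by `j`, so a dynamic partition overwrite can write each table partition from a single reducer task; if `OptimizeLocalShuffleReader` rewrote that final shuffle into local (map-side) reads, the clustering would be lost and the write could emit one small file per map task per table partition.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[4]")
  .appName("spark-32932-sketch")
  .config("spark.sql.adaptive.enabled", "true")
  .config("spark.sql.shuffle.partitions", "5")
  .config("spark.sql.sources.partitionOverwriteMode", "dynamic")
  .getOrCreate()
import spark.implicits._

val df = (1L to 10L)
  .flatMap(i => (1L to 3L).map(j => (i, j)))
  .toDF("i", "j")

// The shuffle introduced by repartition clusters rows by "j". With this patch,
// the final stage keeps the hash-partitioned shuffle read, so each value of
// "j" is written by one task rather than scattered across local readers.
df.repartition($"j")
  .write
  .partitionBy("j")
  .mode("overwrite")
  .saveAsTable("t")

spark.stop()
```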

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

@@ -26,15 +26,19 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListe
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
+import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.command.DataWritingCommandExec
+import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
+import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.util.Utils
 
 class AdaptiveQueryExecSuite
@@ -1258,4 +1262,49 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-32932: Do not use local shuffle reader at final stage on write command") {
+    withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      val data = for (
+        i <- 1L to 10L;
+        j <- 1L to 3L
+      ) yield (i, j)
+
+      val df = data.toDF("i", "j").repartition($"j")
+      var noLocalReader: Boolean = false
+      val listener = new QueryExecutionListener {
+        override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+          qe.executedPlan match {
+            case plan @ (_: DataWritingCommandExec | _: V2TableWriteExec) =>
+              assert(plan.asInstanceOf[UnaryExecNode].child.isInstanceOf[AdaptiveSparkPlanExec])
+              noLocalReader = collect(plan) {
+                case exec: CustomShuffleReaderExec if exec.isLocalReader => exec
+              }.isEmpty
+            case _ => // ignore other events
+          }
+        }
+        override def onFailure(funcName: String, qe: QueryExecution,
+            exception: Exception): Unit = {}
+      }
+      spark.listenerManager.register(listener)
+
+      withTable("t") {
+        df.write.partitionBy("j").saveAsTable("t")
+        sparkContext.listenerBus.waitUntilEmpty()
+        assert(noLocalReader)
+        noLocalReader = false
+      }
+
+      // Test DataSource v2
+      val format = classOf[NoopDataSource].getName
+      df.write.format(format).mode("overwrite").save()
+      sparkContext.listenerBus.waitUntilEmpty()
+      assert(noLocalReader)
+      noLocalReader = false
+
+      spark.listenerManager.unregister(listener)
+    }
+  }
 }
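
As a usage note, the `QueryExecutionListener` in the test above is a convenient way to get hold of the executed plan of a write, since `DataFrameWriter` actions return `Unit` and offer no plan handle of their own. Below is a minimal standalone sketch of the same pattern, assuming an active `spark` session as in spark-shell; the printing body and the `planPrinter` name are illustrative, not part of the patch.

```scala
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// Receives the full QueryExecution after every successful action, including
// writes such as save() and saveAsTable(), which return no plan themselves.
val planPrinter = new QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    // funcName names the triggering action, e.g. "save" or "saveAsTable".
    println(s"$funcName:\n${qe.executedPlan}")
  }
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
}

spark.listenerManager.register(planPrinter)
// ... run writes and inspect the printed plans ...
spark.listenerManager.unregister(planPrinter)
```

Asserting on the captured plan, as the test does with `collect` and `CustomShuffleReaderExec.isLocalReader`, then verifies that the final stage kept the regular shuffle read. The v2 leg of the test uses `NoopDataSource` (format "noop"), a built-in sink that discards all rows, as a cheap way to exercise `V2TableWriteExec`.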