diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index f45e3560b2cf..f01947d8f5ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -92,6 +92,15 @@ case class AdaptiveSparkPlanExec( // optimizations should be stage-independent. @transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq( ReuseAdaptiveSubquery(conf, subqueryCache), + + // When adding local shuffle readers in 'OptimizeLocalShuffleReader`, we revert all the local + // readers if additional shuffles are introduced. This may be too conservative: maybe there is + // only one local reader that introduces shuffle, and we can still keep other local readers. + // Here we re-execute this rule with the sub-plan-tree of a query stage, to make sure necessary + // local readers are added before executing the query stage. + // This rule must be executed before `ReduceNumShufflePartitions`, as local shuffle readers + // can't change number of partitions. + OptimizeLocalShuffleReader(conf), ReduceNumShufflePartitions(conf), ApplyColumnarRulesAndInsertTransitions(session.sessionState.conf, session.sessionState.columnarRules), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 43802968c469..649467a27d93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -163,8 +163,9 @@ class AdaptiveQueryExecSuite assert(smj.size == 3) val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 3) - // additional shuffle exchange introduced, only one shuffle reader to local shuffle reader. - checkNumLocalShuffleReaders(adaptivePlan, 1) + // The child of remaining one BroadcastHashJoin is not ShuffleQueryStage. + // So only two LocalShuffleReader. + checkNumLocalShuffleReaders(adaptivePlan, 2) } } @@ -188,7 +189,8 @@ class AdaptiveQueryExecSuite assert(smj.size == 3) val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 3) - // additional shuffle exchange introduced, only one shuffle reader to local shuffle reader. + // The child of remaining two BroadcastHashJoin is not ShuffleQueryStage. + // So only two LocalShuffleReader. checkNumLocalShuffleReaders(adaptivePlan, 1) } } @@ -213,7 +215,8 @@ class AdaptiveQueryExecSuite assert(smj.size == 3) val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 3) - // additional shuffle exchange introduced, only one shuffle reader to local shuffle reader. + // The child of remaining two BroadcastHashJoin is not ShuffleQueryStage. + // So only two LocalShuffleReader. checkNumLocalShuffleReaders(adaptivePlan, 1) } }