diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 459311df22d2..9351b074c659 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -56,7 +56,6 @@ private[execution] object SparkPlanInfo {
       case ReusedSubqueryExec(child) => child :: Nil
       case a: AdaptiveSparkPlanExec => a.executedPlan :: Nil
       case stage: QueryStageExec => stage.plan :: Nil
-      case localReader: LocalShuffleReaderExec => localReader.child :: Nil
       case _ => plan.children ++ plan.subqueries
     }
     val metrics = plan.metrics.toSeq.map { case (key, metric) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala
index 94e66b0c3a43..0ec8710e4db4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala
@@ -125,7 +125,6 @@ trait AdaptiveSparkPlanHelper {
   private def allChildren(p: SparkPlan): Seq[SparkPlan] = p match {
     case a: AdaptiveSparkPlanExec => Seq(a.executedPlan)
     case s: QueryStageExec => Seq(s.plan)
-    case l: LocalShuffleReaderExec => Seq(l.child)
     case _ => p.children
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
index 89e2813695a6..d8dd7224fef3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight}
 import org.apache.spark.sql.internal.SQLConf
@@ -70,7 +70,7 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
   }
 }
 
-case class LocalShuffleReaderExec(child: QueryStageExec) extends LeafExecNode {
+case class LocalShuffleReaderExec(child: QueryStageExec) extends UnaryExecNode {
 
   override def output: Seq[Attribute] = child.output
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
index 1a85d5c02075..5a505c213a26 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.execution.adaptive.{QueryStageExec, ReusedQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{LocalShuffleReaderExec, QueryStageExec, ReusedQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.ThreadUtils
 
@@ -64,10 +64,14 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
       return plan
     }
 
-    val shuffleStages = plan.collect {
-      case stage: ShuffleQueryStageExec => stage
-      case ReusedQueryStageExec(_, stage: ShuffleQueryStageExec, _) => stage
+    def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = plan match {
+      case _: LocalShuffleReaderExec => Nil
+      case stage: ShuffleQueryStageExec => Seq(stage)
+      case ReusedQueryStageExec(_, stage: ShuffleQueryStageExec, _) => Seq(stage)
+      case _ => plan.children.flatMap(collectShuffleStages)
     }
+
+    val shuffleStages = collectShuffleStages(plan)
     // ShuffleExchanges introduced by repartition do not support changing the number of partitions.
     // We change the number of partitions in the stage only if all the ShuffleExchanges support it.
     if (!shuffleStages.forall(_.plan.canChangeNumPartitions)) {
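
Context for the `ReduceNumShufflePartitions` change: once `LocalShuffleReaderExec` extends `UnaryExecNode`, generic traversals such as `plan.collect` descend into its child, so the rule would start seeing the `ShuffleQueryStageExec` feeding a local reader and try to coalesce its partitions; the explicit `collectShuffleStages` recursion prunes at the reader instead. A minimal, self-contained sketch of that pruning, using hypothetical `Node`/`ShuffleStage`/`LocalReader`/`Join` classes rather than the real Spark operators:

```scala
// Hypothetical mini plan tree (not the real Spark classes), just to show the
// traversal change. As a leaf, the local reader hid its child from plan.collect;
// as a unary node it exposes the child, so collection must prune explicitly.
sealed trait Node { def children: Seq[Node] }
case class ShuffleStage(id: Int) extends Node { def children: Seq[Node] = Nil }
case class LocalReader(child: Node) extends Node { def children: Seq[Node] = Seq(child) }
case class Join(children: Seq[Node]) extends Node

// Analogue of the old plan.collect: visits every node, including those under a reader.
def collectAll(n: Node): Seq[ShuffleStage] =
  (n match { case s: ShuffleStage => Seq(s); case _ => Nil }) ++ n.children.flatMap(collectAll)

// Analogue of the new collectShuffleStages: stops at LocalReader boundaries.
def collectPruned(n: Node): Seq[ShuffleStage] = n match {
  case _: LocalReader => Nil
  case s: ShuffleStage => Seq(s)
  case _ => n.children.flatMap(collectPruned)
}

val plan = Join(Seq(LocalReader(ShuffleStage(1)), ShuffleStage(2)))
assert(collectAll(plan).map(_.id) == Seq(1, 2))  // stage 1 leaks through the reader
assert(collectPruned(plan).map(_.id) == Seq(2))  // pruned version excludes it
```

Under the old `LeafExecNode` definition the two traversals would have returned the same result, which is why no explicit pruning was needed before this patch.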