diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
index 84c65df31a7c5..afa96e9ec4d56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.adaptive
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -34,7 +34,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
       return plan
     }
     if (!plan.collectLeaves().forall(_.isInstanceOf[QueryStageExec])
-        || plan.find(_.isInstanceOf[CustomShuffleReaderExec]).isDefined) {
+        || plan.find(_.isInstanceOf[CustomShuffleReaderExec]).isDefined) {
      // If not all leaf nodes are query stages, it's not safe to reduce the number of
      // shuffle partitions, because we may break the assumption that all children of a spark plan
      // have same number of output partitions.
@@ -56,20 +56,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
     // we should skip it when calculating the `partitionStartIndices`.
     val validMetrics = shuffleStages.flatMap(_.mapStats)
 
-    // We may have different pre-shuffle partition numbers, don't reduce shuffle partition number
-    // in that case. For example when we union fully aggregated data (data is arranged to a single
-    // partition) and a result of a SortMergeJoin (multiple partitions).
-    val distinctNumPreShufflePartitions =
-      validMetrics.map(stats => stats.bytesByPartitionId.length).distinct
-    if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1) {
-      // We fall back to Spark default parallelism if the minimum number of coalesced partitions
-      // is not set, so to avoid perf regressions compared to no coalescing.
-      val minPartitionNum = conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM)
-        .getOrElse(session.sparkContext.defaultParallelism)
-      val partitionSpecs = ShufflePartitionsUtil.coalescePartitions(
-        validMetrics.toArray,
-        advisoryTargetSize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES),
-        minNumPartitions = minPartitionNum)
+    def updatePlan(partitionSpecs: Seq[ShufflePartitionSpec]): SparkPlan = {
       // This transformation adds new nodes, so we must use `transformUp` here.
       val stageIds = shuffleStages.map(_.id).toSet
       plan.transformUp {
@@ -79,8 +66,29 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
         case stage: ShuffleQueryStageExec if stageIds.contains(stage.id) =>
           CustomShuffleReaderExec(stage, partitionSpecs)
       }
+    }
+
+    if (validMetrics.isEmpty) {
+      updatePlan(Nil)
     } else {
-      plan
+      // We may have different pre-shuffle partition numbers, don't reduce shuffle partition
+      // number in that case. For example when we union fully aggregated data (data is arranged
+      // to a single partition) and a result of a SortMergeJoin (multiple partitions).
+      val distinctNumPreShufflePartitions =
+        validMetrics.map(stats => stats.bytesByPartitionId.length).distinct
+      if (distinctNumPreShufflePartitions.length == 1) {
+        // We fall back to Spark default parallelism if the minimum number of coalesced partitions
+        // is not set, so to avoid perf regressions compared to no coalescing.
+        val minPartitionNum = conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM)
+          .getOrElse(session.sparkContext.defaultParallelism)
+        val partitionSpecs = ShufflePartitionsUtil.coalescePartitions(
+          validMetrics.toArray,
+          advisoryTargetSize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES),
+          minNumPartitions = minPartitionNum)
+        updatePlan(partitionSpecs)
+      } else {
+        plan
+      }
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 27d9748476c98..a274b931ca8c9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -211,24 +211,26 @@ class AdaptiveQueryExecSuite
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       val testDf = df1.where('a > 10).join(df2.where('b > 10), "id").groupBy('a).count()
       checkAnswer(testDf, Seq())
+      assert(testDf.rdd.collectPartitions().length == 0)
       val plan = testDf.queryExecution.executedPlan
       assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined)
       val coalescedReaders = collect(plan) {
         case r: CustomShuffleReaderExec => r
       }
-      assert(coalescedReaders.length == 2)
+      assert(coalescedReaders.length == 3, s"$plan")
       coalescedReaders.foreach(r => assert(r.partitionSpecs.isEmpty))
     }
 
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
       val testDf = df1.where('a > 10).join(df2.where('b > 10), "id").groupBy('a).count()
       checkAnswer(testDf, Seq())
+      assert(testDf.rdd.collectPartitions().length == 0)
       val plan = testDf.queryExecution.executedPlan
       assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
       val coalescedReaders = collect(plan) {
         case r: CustomShuffleReaderExec => r
       }
-      assert(coalescedReaders.length == 2, s"$plan")
+      assert(coalescedReaders.length == 3, s"$plan")
       coalescedReaders.foreach(r => assert(r.partitionSpecs.isEmpty))
     }
   }
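Reviewer note, not part of the patch: previously, when no shuffle stage in the plan fragment had map output statistics (`mapStats` is `None`, as for a shuffle whose input RDD has zero partitions), the `validMetrics.nonEmpty && ...` guard failed and the plan was returned unchanged, leaving such a stage at the full default partition count. With this change, `updatePlan(Nil)` wraps it in a `CustomShuffleReaderExec` with no partition specs, so the resulting RDD has zero partitions and launches no tasks; this is also why the suite now expects three coalesced readers instead of two. Below is a minimal standalone sketch of the behavior the new assertions pin down, assuming a local Spark build with this patch applied; `df1`/`df2` are hypothetical stand-ins for the suite's test data, and `getNumPartitions` stands in for the test's `collectPartitions()` call.

```scala
import org.apache.spark.sql.SparkSession

object EmptyShuffleCoalesceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("EmptyShuffleCoalesceSketch")
      // AQE must be enabled for CoalesceShufflePartitions to run.
      .config("spark.sql.adaptive.enabled", "true")
      // -1 disables broadcast joins, forcing the SortMergeJoin path
      // exercised by the first test block.
      .config("spark.sql.autoBroadcastJoinThreshold", "-1")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical stand-ins for the suite's df1/df2. Since `a` and `b`
    // never exceed 9, both filters remove every row, and the join output
    // feeding the final aggregation shuffle has zero partitions.
    val df1 = spark.range(100).select($"id", ($"id" % 10).as("a"))
    val df2 = spark.range(100).select($"id", ($"id" % 10).as("b"))

    val testDf = df1.where($"a" > 10).join(df2.where($"b" > 10), "id")
      .groupBy($"a").count()

    // With this patch, the aggregation's shuffle stage has no map stats,
    // so `updatePlan(Nil)` installs a reader with no partition specs:
    // the result RDD has zero partitions and no reduce tasks run.
    assert(testDf.rdd.getNumPartitions == 0)

    spark.stop()
  }
}
```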