diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
index 78923433eaab9..1a85d5c02075b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
@@ -82,7 +82,12 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
     // `ShuffleQueryStageExec` gives null mapOutputStatistics when the input RDD has 0 partitions,
     // we should skip it when calculating the `partitionStartIndices`.
     val validMetrics = shuffleMetrics.filter(_ != null)
-    if (validMetrics.nonEmpty) {
+    // We may have different pre-shuffle partition numbers; don't reduce the shuffle partition
+    // number in that case. For example, we may union fully aggregated data (arranged into a
+    // single partition) with the result of a SortMergeJoin (multiple partitions).
+    val distinctNumPreShufflePartitions =
+      validMetrics.map(stats => stats.bytesByPartitionId.length).distinct
+    if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1) {
       val partitionStartIndices = estimatePartitionStartIndices(validMetrics.toArray)
       // This transformation adds new nodes, so we must use `transformUp` here.
       plan.transformUp {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
index 35c33a7157d38..b5dbdd0b18b49 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
@@ -587,4 +587,22 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
     }
     withSparkSession(test, 200, None)
   }
+
+  test("Union two datasets with different pre-shuffle partition number") {
+    val test: SparkSession => Unit = { spark: SparkSession =>
+      val df1 = spark.range(3).join(spark.range(3), "id").toDF()
+      val df2 = spark.range(3).groupBy().sum()
+
+      val resultDf = df1.union(df2)
+
+      checkAnswer(resultDf, Seq(0, 1, 2, 3).map(i => Row(i)))
+
+      val finalPlan = resultDf.queryExecution.executedPlan
+        .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+      // As the pre-shuffle partition numbers are different, we skip reducing
+      // the shuffle partition numbers.
+      assert(finalPlan.collect { case p: CoalescedShuffleReaderExec => p }.isEmpty)
+    }
+    withSparkSession(test, 100, None)
+  }
 }
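
Not part of the patch: a minimal standalone sketch of the scenario the fix targets, for local reproduction. It assumes a Spark build with adaptive query execution available; the object name, master, and app name are placeholders. Broadcast joins are disabled so that df1 goes through a SortMergeJoin (multiple pre-shuffle partitions), while df2 is a global aggregate (a single pre-shuffle partition), so the union sees differing pre-shuffle partition numbers and the rule now skips coalescing.

// Illustrative sketch only; assumes local mode and a Spark version with adaptive execution.
import org.apache.spark.sql.SparkSession

object UnionPreShufflePartitionsRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("union-pre-shuffle-partitions-repro")
      .config("spark.sql.adaptive.enabled", "true")
      // Disable broadcast joins so the tiny join below becomes a SortMergeJoin.
      .config("spark.sql.autoBroadcastJoinThreshold", "-1")
      .getOrCreate()

    // Joined data: SortMergeJoin output with multiple pre-shuffle partitions.
    val df1 = spark.range(3).join(spark.range(3), "id").toDF()
    // Fully aggregated data: a global aggregate collapses to a single partition.
    val df2 = spark.range(3).groupBy().sum()

    val resultDf = df1.union(df2)
    resultDf.collect()

    // The executed plan shows whether shuffle partitions were coalesced.
    println(resultDf.queryExecution.executedPlan)

    spark.stop()
  }
}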