diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
index d6e44b780d772..5559e19f97ccd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
@@ -92,9 +92,10 @@ object ShufflePartitionsUtil extends Logging {
     var coalescedSize = 0L
     var i = 0
 
-    def createPartitionSpec(): Unit = {
-      // Skip empty inputs, as it is a waste to launch an empty task.
-      if (coalescedSize > 0) {
+    def createPartitionSpec(last: Boolean = false): Unit = {
+      // Skip empty inputs, as it is a waste to launch an empty task
+      // unless all inputs are empty
+      if (coalescedSize > 0 || (last && partitionSpecs.isEmpty)) {
         partitionSpecs += CoalescedPartitionSpec(latestSplitPoint, i)
       }
     }
@@ -120,7 +121,7 @@ object ShufflePartitionsUtil extends Logging {
       }
       i += 1
     }
-    createPartitionSpec()
+    createPartitionSpec(last = true)
     partitionSpecs
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala
index f5c3b7816f5ea..c8cc4924dc800 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala
@@ -200,7 +200,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite {
       val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0)
       checkEstimation(
         Array(bytesByPartitionId1, bytesByPartitionId2),
-        Seq.empty, targetSize, minNumPartitions)
+        Array(CoalescedPartitionSpec(0, 5)), targetSize, minNumPartitions)
     }
 
@@ -243,21 +243,23 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite {
     }
   }
 
-  test("do not create partition spec for 0-size partitions") {
+  test("do not create partition spec for 0-size partitions except all partitions are empty") {
     val targetSize = 100
     val minNumPartitions = 2
+    val expectedPartitionSpecs = Array(CoalescedPartitionSpec(0, 5))
     {
       // 1 shuffle: All bytes per partition are 0, no partition spec created.
       val bytesByPartitionId = Array[Long](0, 0, 0, 0, 0)
-      checkEstimation(Array(bytesByPartitionId), Seq.empty, targetSize)
+      checkEstimation(Array(bytesByPartitionId), expectedPartitionSpecs, targetSize)
     }
 
     {
       // 2 shuffles: All bytes per partition are 0, no partition spec created.
       val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0)
       val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0)
-      checkEstimation(Array(bytesByPartitionId1, bytesByPartitionId2), Seq.empty, targetSize)
+      checkEstimation(Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionSpecs, targetSize)
     }
 
     {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 9fa97bffa8910..461876413582e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -202,7 +202,7 @@ class AdaptiveQueryExecSuite
     }
   }
 
-  test("Empty stage coalesced to 0-partition RDD") {
+  test("Empty stage coalesced to 1-partition RDD") {
     withSQLConf(
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
       SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true") {
@@ -213,11 +213,12 @@ class AdaptiveQueryExecSuite
         checkAnswer(testDf, Seq())
         val plan = testDf.queryExecution.executedPlan
         assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined)
+        assert(plan.execute().getNumPartitions == 1)
         val coalescedReaders = collect(plan) {
           case r: CustomShuffleReaderExec => r
         }
-        assert(coalescedReaders.length == 2)
-        coalescedReaders.foreach(r => assert(r.partitionSpecs.isEmpty))
+        assert(coalescedReaders.length == 3, s"$plan")
+        coalescedReaders.foreach(r => assert(r.partitionSpecs.length == 1))
       }
 
       withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
@@ -225,11 +226,18 @@ class AdaptiveQueryExecSuite
         checkAnswer(testDf, Seq())
         val plan = testDf.queryExecution.executedPlan
         assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+        assert(plan.execute().getNumPartitions == 1)
         val coalescedReaders = collect(plan) {
           case r: CustomShuffleReaderExec => r
         }
-        assert(coalescedReaders.length == 2, s"$plan")
-        coalescedReaders.foreach(r => assert(r.partitionSpecs.isEmpty))
+        assert(coalescedReaders.length == 3, s"$plan")
+        coalescedReaders.foreach { r =>
+          if (r.isLocalReader) {
+            assert(r.partitionSpecs.length == 2)
+          } else {
+            assert(r.partitionSpecs.length == 1)
+          }
+        }
       }
     }
   }
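
For context, below is a minimal, self-contained sketch of the coalescing loop after this patch. It is not the actual Spark implementation: the object name CoalesceSketch, the helper coalescePartitions, and its single-shuffle signature are illustrative simplifications (the real ShufflePartitionsUtil aggregates the i-th partition size across all shuffles), and the loop body is reconstructed from the hunk context above. It shows how the createPartitionSpec(last = true) call at the end makes an all-empty input yield a single CoalescedPartitionSpec covering the whole reducer range, instead of an empty result that would coalesce the stage into a 0-partition RDD:

import scala.collection.mutable.ArrayBuffer

object CoalesceSketch {
  // Simplified stand-in for org.apache.spark.sql.execution.CoalescedPartitionSpec.
  case class CoalescedPartitionSpec(startReducerIndex: Int, endReducerIndex: Int)

  // Single-shuffle sketch of the coalescing loop; sizes(i) is the byte size
  // of reducer partition i.
  def coalescePartitions(sizes: Array[Long], targetSize: Long): Seq[CoalescedPartitionSpec] = {
    val partitionSpecs = ArrayBuffer.empty[CoalescedPartitionSpec]
    var latestSplitPoint = 0
    var coalescedSize = 0L
    var i = 0

    def createPartitionSpec(last: Boolean = false): Unit = {
      // Skip empty inputs, as it is a waste to launch an empty task,
      // unless all inputs are empty: then emit one spec so the stage
      // still produces a 1-partition RDD.
      if (coalescedSize > 0 || (last && partitionSpecs.isEmpty)) {
        partitionSpecs += CoalescedPartitionSpec(latestSplitPoint, i)
      }
    }

    while (i < sizes.length) {
      // Start a new coalesced partition once adding partition i would
      // exceed the target size.
      if (i > latestSplitPoint && coalescedSize + sizes(i) > targetSize) {
        createPartitionSpec()
        latestSplitPoint = i
        coalescedSize = 0L
      }
      coalescedSize += sizes(i)
      i += 1
    }
    createPartitionSpec(last = true)
    partitionSpecs.toSeq
  }

  def main(args: Array[String]): Unit = {
    // All partitions empty: previously Seq.empty, now one spec spanning [0, 5).
    println(coalescePartitions(Array(0L, 0L, 0L, 0L, 0L), targetSize = 100))
    // A run of empty partitions between non-empty ones is still skipped:
    // yields specs (0, 1) and (3, 4), and nothing for the empty range [1, 3).
    println(coalescePartitions(Array(200L, 0L, 0L, 200L), targetSize = 100))
  }
}

Run on the patch's test input, the first call produces exactly the single spec the updated ShufflePartitionsUtilSuite now expects, CoalescedPartitionSpec(0, 5); that guaranteed non-empty spec list is what lets the AdaptiveQueryExecSuite assert plan.execute().getNumPartitions == 1 end-to-end even when every input to the stage is empty.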