
Commit b9a1cd8

wangyum authored and cloud-fan committed
[SPARK-31220][SQL] repartition obeys initialPartitionNum when adaptiveExecutionEnabled
### What changes were proposed in this pull request?

This PR makes `repartition`/`DISTRIBUTE BY` obey [initialPartitionNum](https://github.com/apache/spark/blob/af4248b2d661d04fec89b37857a47713246d9465/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L446-L455) when adaptive execution is enabled.

### Why are the changes needed?

So that `DISTRIBUTE BY` and `GROUP BY` use the same number of shuffle partitions.

How to reproduce:

```scala
spark.sql("CREATE TABLE spark_31220(id int)")
spark.sql("set spark.sql.adaptive.enabled=true")
spark.sql("set spark.sql.adaptive.coalescePartitions.initialPartitionNum=1000")
```

Before this PR:

```
scala> spark.sql("SELECT id from spark_31220 GROUP BY id").explain
== Physical Plan ==
AdaptiveSparkPlan(isFinalPlan=false)
+- HashAggregate(keys=[id#5], functions=[])
   +- Exchange hashpartitioning(id#5, 1000), true, [id=#171]
      +- HashAggregate(keys=[id#5], functions=[])
         +- FileScan parquet default.spark_31220[id#5]

scala> spark.sql("SELECT id from spark_31220 DISTRIBUTE BY id").explain
== Physical Plan ==
AdaptiveSparkPlan(isFinalPlan=false)
+- Exchange hashpartitioning(id#5, 200), false, [id=#179]
   +- FileScan parquet default.spark_31220[id#5]
```

After this PR:

```
scala> spark.sql("SELECT id from spark_31220 GROUP BY id").explain
== Physical Plan ==
AdaptiveSparkPlan(isFinalPlan=false)
+- HashAggregate(keys=[id#5], functions=[])
   +- Exchange hashpartitioning(id#5, 1000), true, [id=#171]
      +- HashAggregate(keys=[id#5], functions=[])
         +- FileScan parquet default.spark_31220[id#5]

scala> spark.sql("SELECT id from spark_31220 DISTRIBUTE BY id").explain
== Physical Plan ==
AdaptiveSparkPlan(isFinalPlan=false)
+- Exchange hashpartitioning(id#5, 1000), false, [id=#179]
   +- FileScan parquet default.spark_31220[id#5]
```

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Unit test.

Closes #27986 from wangyum/SPARK-31220.

Authored-by: Yuming Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 1d1eacd)
Signed-off-by: Wenchen Fan <[email protected]>
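The same effect can be observed from the DataFrame API. A sketch for spark-shell (where `$` is available by default), assuming the `spark_31220` table and the configs set in the repro above:

```scala
// repartition($"id") plans the same hash-partitioned Exchange as
// DISTRIBUTE BY id, so after this PR its Exchange also shows
// hashpartitioning(id, 1000) instead of the 200 default.
spark.table("spark_31220").repartition($"id").explain()
```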
1 parent 4879822 commit b9a1cd8

File tree

3 files changed (+27, -12 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 9 additions & 4 deletions

```diff
@@ -2765,7 +2765,15 @@ class SQLConf extends Serializable with Logging {

   def cacheVectorizedReaderEnabled: Boolean = getConf(CACHE_VECTORIZED_READER_ENABLED)

-  def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS)
+  def defaultNumShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS)
+
+  def numShufflePartitions: Int = {
+    if (adaptiveExecutionEnabled && coalesceShufflePartitionsEnabled) {
+      getConf(COALESCE_PARTITIONS_INITIAL_PARTITION_NUM).getOrElse(defaultNumShufflePartitions)
+    } else {
+      defaultNumShufflePartitions
+    }
+  }

   def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED)

@@ -2778,9 +2786,6 @@ class SQLConf extends Serializable with Logging {

   def coalesceShufflePartitionsEnabled: Boolean = getConf(COALESCE_PARTITIONS_ENABLED)

-  def initialShufflePartitionNum: Int =
-    getConf(COALESCE_PARTITIONS_INITIAL_PARTITION_NUM).getOrElse(numShufflePartitions)
-
   def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN)

   def maxBatchesToRetainInMemory: Int = getConf(MAX_BATCHES_TO_RETAIN_IN_MEMORY)
```
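Pulled out of the diff, the new selection rule reads as a small pure function. A minimal sketch with illustrative parameter names (the real method reads these values from SQLConf entries):

```scala
// Mirrors SQLConf.numShufflePartitions after this commit: when AQE and
// shuffle-partition coalescing are both enabled, the configured
// initialPartitionNum takes precedence; otherwise the plain
// spark.sql.shuffle.partitions value is used, exactly as before.
def numShufflePartitions(
    adaptiveExecutionEnabled: Boolean,
    coalesceShufflePartitionsEnabled: Boolean,
    initialPartitionNum: Option[Int], // ...coalescePartitions.initialPartitionNum
    defaultNumShufflePartitions: Int  // spark.sql.shuffle.partitions
): Int = {
  if (adaptiveExecutionEnabled && coalesceShufflePartitionsEnabled) {
    initialPartitionNum.getOrElse(defaultNumShufflePartitions)
  } else {
    defaultNumShufflePartitions
  }
}

// With the values from the unit test below: AQE on picks 7, AQE off picks 6.
assert(numShufflePartitions(true, true, Some(7), 6) == 7)
assert(numShufflePartitions(false, true, Some(7), 6) == 6)
```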

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 2 additions & 8 deletions

```diff
@@ -35,12 +35,6 @@ import org.apache.spark.sql.internal.SQLConf
  * the input partition ordering requirements are met.
  */
 case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
-  private def defaultNumPreShufflePartitions: Int =
-    if (conf.adaptiveExecutionEnabled && conf.coalesceShufflePartitionsEnabled) {
-      conf.initialShufflePartitionNum
-    } else {
-      conf.numShufflePartitions
-    }

   private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
     val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
@@ -57,7 +51,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         BroadcastExchangeExec(mode, child)
       case (child, distribution) =>
         val numPartitions = distribution.requiredNumPartitions
-          .getOrElse(defaultNumPreShufflePartitions)
+          .getOrElse(conf.numShufflePartitions)
         ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
     }

@@ -95,7 +89,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       // expected number of shuffle partitions. However, if it's smaller than
       // `conf.numShufflePartitions`, we pick `conf.numShufflePartitions` as the
       // expected number of shuffle partitions.
-      math.max(nonShuffleChildrenNumPartitions.max, conf.numShufflePartitions)
+      math.max(nonShuffleChildrenNumPartitions.max, conf.defaultNumShufflePartitions)
     } else {
       childrenNumPartitions.max
     }
```
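With the AQE branch folded into SQLConf, the rule's partition-count choice collapses to a single lookup. A simplified sketch of the post-change resolution (not the rule's real signature):

```scala
import org.apache.spark.sql.internal.SQLConf

// A distribution that demands an exact partition count still wins; anything
// else falls back to conf.numShufflePartitions, which now carries the AQE
// initialPartitionNum logic for aggregates and repartition/DISTRIBUTE BY alike.
def resolveNumPartitions(requiredNumPartitions: Option[Int], conf: SQLConf): Int =
  requiredNumPartitions.getOrElse(conf.numShufflePartitions)
```

Note that the last hunk deliberately switches the comparison against non-shuffle children to `conf.defaultNumShufflePartitions`, so that sizing decision keeps comparing against the user-set `spark.sql.shuffle.partitions` rather than the typically much larger initial partition number.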

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

Lines changed: 16 additions & 0 deletions

```diff
@@ -873,4 +873,20 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-31220 repartition obeys initialPartitionNum when adaptiveExecutionEnabled") {
+    Seq(true, false).foreach { enableAQE =>
+      withSQLConf(
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
+        SQLConf.SHUFFLE_PARTITIONS.key -> "6",
+        SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "7") {
+        val partitionsNum = spark.range(10).repartition($"id").rdd.collectPartitions().length
+        if (enableAQE) {
+          assert(partitionsNum === 7)
+        } else {
+          assert(partitionsNum === 6)
+        }
+      }
+    }
+  }
 }
```
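The same check works interactively. A spark-shell sketch under the test's config values, using the public `getNumPartitions` in place of the suite's `collectPartitions` helper (which is not public API):

```scala
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.shuffle.partitions", "6")
spark.conf.set("spark.sql.adaptive.coalescePartitions.initialPartitionNum", "7")
// 7 with AQE enabled; re-run with spark.sql.adaptive.enabled=false to get 6.
spark.range(10).repartition($"id").rdd.getNumPartitions
```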
