Commit 0a9223f (1 parent: 338efee)

Coalesce partitions for repartition by key when AQE is enabled.

4 files changed (+86 / -28 lines)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 15 additions & 3 deletions
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning}
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.util.random.RandomSampler

@@ -953,16 +954,18 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
 }

 /**
- * This method repartitions data using [[Expression]]s into `numPartitions`, and receives
+ * This method repartitions data using [[Expression]]s into `optNumPartitions`, and receives
  * information about the number of partitions during execution. Used when a specific ordering or
  * distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like
- * `coalesce` and `repartition`.
+ * `coalesce` and `repartition`. If no `optNumPartitions` is given, by default it partitions data
+ * into `numShufflePartitions` defined in `SQLConf`, and could be coalesced by AQE.
  */
 case class RepartitionByExpression(
     partitionExpressions: Seq[Expression],
     child: LogicalPlan,
-    numPartitions: Int) extends RepartitionOperation {
+    optNumPartitions: Option[Int]) extends RepartitionOperation {

+  val numPartitions = optNumPartitions.getOrElse(SQLConf.get.numShufflePartitions)
   require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")

   val partitioning: Partitioning = {

@@ -990,6 +993,15 @@
   override def shuffle: Boolean = true
 }

+object RepartitionByExpression {
+  def apply(
+      partitionExpressions: Seq[Expression],
+      child: LogicalPlan,
+      numPartitions: Int): RepartitionByExpression = {
+    RepartitionByExpression(partitionExpressions, child, Some(numPartitions))
+  }
+}
+
 /**
  * A relation with one row. This is used in "SELECT ..." without a from clause.
  */
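
The heart of this file's change is that the partition count becomes optional on the logical operator, while the companion object's apply preserves the old Int-taking constructor for existing callers. A minimal REPL-style sketch of the resolution rule the new field encodes (here `defaultShufflePartitions` is a hypothetical stand-in for `SQLConf.get.numShufflePartitions`):

  // Sketch: how an optional partition count resolves to a concrete one.
  def resolveNumPartitions(optNumPartitions: Option[Int], defaultShufflePartitions: Int): Int =
    // An explicit count wins; otherwise fall back to the session default,
    // which AQE is then free to coalesce at runtime.
    optNumPartitions.getOrElse(defaultShufflePartitions)

  resolveNumPartitions(Some(10), 200)  // 10: user-specified, stays fixed
  resolveNumPartitions(None, 200)      // 200: session default, coalescible by AQE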

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 33 additions & 21 deletions
@@ -2991,17 +2991,9 @@ class Dataset[T] private[sql](
     Repartition(numPartitions, shuffle = true, logicalPlan)
   }

-  /**
-   * Returns a new Dataset partitioned by the given partitioning expressions into
-   * `numPartitions`. The resulting Dataset is hash partitioned.
-   *
-   * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
-   *
-   * @group typedrel
-   * @since 2.0.0
-   */
-  @scala.annotation.varargs
-  def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
+  private def repartitionByExpression(
+      numPartitions: Option[Int],
+      partitionExprs: Column*): Dataset[T] = {
     // The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments.
     // However, we don't want to complicate the semantics of this API method.
     // Instead, let's give users a friendly error message, pointing them to the new method.

@@ -3015,6 +3007,20 @@
     }
   }

+  /**
+   * Returns a new Dataset partitioned by the given partitioning expressions into
+   * `numPartitions`. The resulting Dataset is hash partitioned.
+   *
+   * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
+   *
+   * @group typedrel
+   * @since 2.0.0
+   */
+  @scala.annotation.varargs
+  def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
+    repartitionByExpression(Some(numPartitions), partitionExprs: _*)
+  }
+
   /**
    * Returns a new Dataset partitioned by the given partitioning expressions, using
    * `spark.sql.shuffle.partitions` as number of partitions.

@@ -3027,7 +3033,20 @@
    */
   @scala.annotation.varargs
   def repartition(partitionExprs: Column*): Dataset[T] = {
-    repartition(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*)
+    repartitionByExpression(None, partitionExprs: _*)
+  }
+
+  private def repartitionByRange(
+      numPartitions: Option[Int],
+      partitionExprs: Column*): Dataset[T] = {
+    require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.")
+    val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match {
+      case expr: SortOrder => expr
+      case expr: Expression => SortOrder(expr, Ascending)
+    })
+    withTypedPlan {
+      RepartitionByExpression(sortOrder, logicalPlan, numPartitions)
+    }
   }

   /**

@@ -3049,14 +3068,7 @@
    */
   @scala.annotation.varargs
   def repartitionByRange(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
-    require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.")
-    val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match {
-      case expr: SortOrder => expr
-      case expr: Expression => SortOrder(expr, Ascending)
-    })
-    withTypedPlan {
-      RepartitionByExpression(sortOrder, logicalPlan, numPartitions)
-    }
+    repartitionByRange(Some(numPartitions), partitionExprs: _*)
   }

   /**

@@ -3078,7 +3090,7 @@
    */
   @scala.annotation.varargs
   def repartitionByRange(partitionExprs: Column*): Dataset[T] = {
-    repartitionByRange(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*)
+    repartitionByRange(None, partitionExprs: _*)
   }

   /**
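
From the caller's perspective the API surface is unchanged; only the overloads without an explicit count pick up the new Option-based plumbing. A hedged usage sketch (assumes Spark 3.0+; the session setup below is illustrative, not part of the commit):

  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder()
    .master("local[*]")
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
  import spark.implicits._

  val df = spark.range(100).toDF("id")

  // No explicit count: plans RepartitionByExpression with optNumPartitions = None,
  // so AQE may coalesce the shuffle below spark.sql.shuffle.partitions.
  val adaptive = df.repartition($"id")

  // Explicit count: optNumPartitions = Some(10); the shuffle keeps exactly 10 partitions.
  val fixed = df.repartition(10, $"id")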

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 1 deletion
@@ -685,8 +685,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case r: logical.Range =>
         execution.RangeExec(r) :: Nil
       case r: logical.RepartitionByExpression =>
+        val canChangeNumParts = r.optNumPartitions.isEmpty
         exchange.ShuffleExchangeExec(
-          r.partitioning, planLater(r.child), canChangeNumPartitions = false) :: Nil
+          r.partitioning, planLater(r.child), canChangeNumPartitions = canChangeNumParts) :: Nil
       case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
       case r: LogicalRDD =>
         RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
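
The planner change reduces to a single predicate: AQE may only rewrite a shuffle's partition count when the user did not pin one. A self-contained sketch of that decision (names mirror the diff; the exchange node itself is elided):

  // Sketch: coalescing is allowed only for an unspecified partition count.
  def canChangeNumPartitions(optNumPartitions: Option[Int]): Boolean =
    optNumPartitions.isEmpty

  canChangeNumPartitions(None)     // true: repartition($"key") may be coalesced
  canChangeNumPartitions(Some(8))  // false: repartition(8, $"key") stays at 8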

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

Lines changed: 36 additions & 3 deletions
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
 import org.apache.spark.sql.functions._

@@ -1026,15 +1026,48 @@ class AdaptiveQueryExecSuite
     Seq(true, false).foreach { enableAQE =>
       withSQLConf(
         SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
+        SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
         SQLConf.SHUFFLE_PARTITIONS.key -> "6",
         SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "7") {
-        val partitionsNum = spark.range(10).repartition($"id").rdd.collectPartitions().length
+        val df = spark.range(10).repartition($"id")
+        val partitionsNum = df.rdd.collectPartitions().length
         if (enableAQE) {
-          assert(partitionsNum === 7)
+          assert(partitionsNum < 6)
+
+          val plan = df.queryExecution.executedPlan
+          assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
+          val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
+            case s: ShuffleExchangeExec => s
+          }
+          assert(shuffle.size == 1)
+          assert(shuffle(0).outputPartitioning.numPartitions == 7)
         } else {
           assert(partitionsNum === 6)
         }
       }
     }
   }
+
+  test("SPARK-32056 coalesce partitions for repartition by expressions when AQE is enabled") {
+    Seq(true, false).foreach { enableAQE =>
+      withSQLConf(
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
+        SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+        SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "50",
+        SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+        val partitionsNum1 = (1 to 10).toDF.repartition($"value")
+          .rdd.collectPartitions().length
+
+        val partitionsNum2 = (1 to 10).toDF.repartitionByRange($"value".asc)
+          .rdd.collectPartitions().length
+        if (enableAQE) {
+          assert(partitionsNum1 < 10)
+          assert(partitionsNum2 < 10)
+        } else {
+          assert(partitionsNum1 === 10)
+          assert(partitionsNum2 === 10)
+        }
+      }
+    }
+  }
 }
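
To observe the new behavior outside the test suite, the second test's setup can be reproduced in a spark-shell session (a sketch; the exact coalesced partition count depends on data size and the advisory partition size, so only the "fewer than 10" bound is suggested here):

  // spark-shell, Spark 3.0+ (spark and implicits are pre-imported in the shell)
  spark.conf.set("spark.sql.adaptive.enabled", "true")
  spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
  spark.conf.set("spark.sql.adaptive.coalescePartitions.initialPartitionNum", "50")
  spark.conf.set("spark.sql.shuffle.partitions", "10")

  // Tiny input, so AQE should coalesce well below 10 partitions.
  val n = (1 to 10).toDF.repartition($"value").rdd.getNumPartitions
  println(n)  // expected: fewer than 10; exact value is size-dependent

With an explicit count, e.g. (1 to 10).toDF.repartition(10, $"value"), the shuffle keeps all 10 partitions regardless of AQE.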
