
Commit 02708f9

cloud-fan authored and Matt Hawes committed
[SPARK-33494][SQL][AQE] Do not use local shuffle reader for repartition
This PR updates `ShuffleExchangeExec` to carry more information about how much we can change the partitioning. For `repartition(col)`, we should preserve the user-specified partitioning and not apply the AQE local shuffle reader. As with `repartition(number, col)`, we should respect the user-specified partitioning.

Does this PR introduce any user-facing change? No.

How was this patch tested? A new test.

Closes apache#30432 from cloud-fan/aqe.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent c098917 commit 02708f9
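
For context, a spark-shell style sketch (hedged; the DataFrame and column names are placeholders for illustration, not taken from the patch) of the two repartition variants this commit distinguishes:

import org.apache.spark.sql.SparkSession

// Hedged sketch: session setup and data are assumptions for illustration.
val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
import spark.implicits._

val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("key", "value")

// repartition(col): hash-partitions by `key`. After this commit the shuffle is
// tagged REPARTITION, so AQE may still coalesce the partition count but must
// not apply the local shuffle reader, which would break the partitioning.
val byCol = df.repartition($"key")

// repartition(number, col): tagged REPARTITION_WITH_NUM; both the partitioning
// and the partition count are user-specified, so AQE leaves the shuffle as-is.
val byNumAndCol = df.repartition(10, $"key")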

File tree

8 files changed: +83 −44 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 8 additions & 5 deletions
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.execution.aggregate.AggUtils
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.command._
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{REPARTITION, REPARTITION_WITH_NUM, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.execution.python._
 import org.apache.spark.sql.execution.streaming._
@@ -754,7 +754,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Repartition(numPartitions, shuffle, child) =>
         if (shuffle) {
           ShuffleExchangeExec(RoundRobinPartitioning(numPartitions),
-            planLater(child), noUserSpecifiedNumPartition = false) :: Nil
+            planLater(child), REPARTITION_WITH_NUM) :: Nil
         } else {
           execution.CoalesceExec(numPartitions, planLater(child)) :: Nil
         }
@@ -787,9 +787,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case r: logical.Range =>
        execution.RangeExec(r) :: Nil
      case r: logical.RepartitionByExpression =>
-        val canChangeNumParts = r.optNumPartitions.isEmpty
-        exchange.ShuffleExchangeExec(
-          r.partitioning, planLater(r.child), canChangeNumParts) :: Nil
+        val shuffleOrigin = if (r.optNumPartitions.isEmpty) {
+          REPARTITION
+        } else {
+          REPARTITION_WITH_NUM
+        }
+        exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child), shuffleOrigin) :: Nil
      case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
      case r: LogicalRDD =>
        RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
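
A standalone sketch (stand-in types, not Spark source) of the origin selection in the `RepartitionByExpression` branch above; it compiles on its own and shows that an explicit partition count pins the shuffle while a bare repartition(col) leaves the count open:

object OriginSelectionDemo extends App {
  // Stand-ins for the real org.apache.spark.sql.execution.exchange objects.
  sealed trait ShuffleOrigin
  case object REPARTITION extends ShuffleOrigin
  case object REPARTITION_WITH_NUM extends ShuffleOrigin

  // Mirrors the planner branch: no explicit number => REPARTITION.
  def originFor(optNumPartitions: Option[Int]): ShuffleOrigin =
    if (optNumPartitions.isEmpty) REPARTITION else REPARTITION_WITH_NUM

  assert(originFor(None) == REPARTITION)              // repartition(col)
  assert(originFor(Some(10)) == REPARTITION_WITH_NUM) // repartition(10, col)
}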

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala

Lines changed: 8 additions & 1 deletion
@@ -18,8 +18,10 @@
 package org.apache.spark.sql.execution.adaptive

 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, ShuffleExchangeLike}
 import org.apache.spark.sql.internal.SQLConf

 /**
@@ -50,7 +52,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
     val shuffleStages = collectShuffleStages(plan)
     // ShuffleExchanges introduced by repartition do not support changing the number of partitions.
     // We change the number of partitions in the stage only if all the ShuffleExchanges support it.
-    if (!shuffleStages.forall(_.shuffle.canChangeNumPartitions)) {
+    if (!shuffleStages.forall(s => supportCoalesce(s.shuffle))) {
       plan
     } else {
       // `ShuffleQueryStageExec#mapStats` returns None when the input RDD has 0 partitions,
@@ -85,6 +87,11 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
       }
     }
   }
+
+  private def supportCoalesce(s: ShuffleExchangeLike): Boolean = {
+    s.outputPartitioning != SinglePartition &&
+      (s.shuffleOrigin == ENSURE_REQUIREMENTS || s.shuffleOrigin == REPARTITION)
+  }
 }

 object CoalesceShufflePartitions {
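
To make the new predicate concrete, a self-contained truth-table sketch (stand-in types, not Spark source): coalescing only changes the partition count, so it stays legal for REPARTITION but not for REPARTITION_WITH_NUM or a SinglePartition output:

object SupportCoalesceDemo extends App {
  sealed trait ShuffleOrigin
  case object ENSURE_REQUIREMENTS extends ShuffleOrigin
  case object REPARTITION extends ShuffleOrigin
  case object REPARTITION_WITH_NUM extends ShuffleOrigin

  // `isSinglePartition` stands in for `outputPartitioning == SinglePartition`.
  def supportCoalesce(isSinglePartition: Boolean, origin: ShuffleOrigin): Boolean =
    !isSinglePartition &&
      (origin == ENSURE_REQUIREMENTS || origin == REPARTITION)

  assert(supportCoalesce(false, ENSURE_REQUIREMENTS))   // planner-added: optimizable
  assert(supportCoalesce(false, REPARTITION))           // count may still change
  assert(!supportCoalesce(false, REPARTITION_WITH_NUM)) // explicit count is pinned
  assert(!supportCoalesce(true, ENSURE_REQUIREMENTS))   // exactly one partition required
}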

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala

Lines changed: 9 additions & 4 deletions
@@ -17,9 +17,10 @@

 package org.apache.spark.sql.execution.adaptive

+import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.internal.SQLConf

@@ -142,9 +143,13 @@ object OptimizeLocalShuffleReader {

   def canUseLocalShuffleReader(plan: SparkPlan): Boolean = plan match {
     case s: ShuffleQueryStageExec =>
-      s.shuffle.canChangeNumPartitions && s.mapStats.isDefined
-    case CustomShuffleReaderExec(s: ShuffleQueryStageExec, _, _) =>
-      s.shuffle.canChangeNumPartitions && s.mapStats.isDefined
+      s.mapStats.isDefined && supportLocalReader(s.shuffle)
+    case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs, _) =>
+      s.mapStats.isDefined && partitionSpecs.nonEmpty && supportLocalReader(s.shuffle)
     case _ => false
   }
+
+  private def supportLocalReader(s: ShuffleExchangeLike): Boolean = {
+    s.outputPartitioning != SinglePartition && s.shuffleOrigin == ENSURE_REQUIREMENTS
+  }
 }
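
Reusing the stand-in ShuffleOrigin values from the previous sketch, the local-reader predicate is strictly tighter than the coalesce one, because a local reader changes the data partitioning itself rather than just the partition count:

// Stand-in predicate (not Spark source); ShuffleOrigin as in the sketch above.
def supportLocalReader(isSinglePartition: Boolean, origin: ShuffleOrigin): Boolean =
  !isSinglePartition && origin == ENSURE_REQUIREMENTS

// Only planner-introduced shuffles qualify; both user repartition flavors do not:
// supportLocalReader(false, ENSURE_REQUIREMENTS)  == true
// supportLocalReader(false, REPARTITION)          == false
// supportLocalReader(false, REPARTITION_WITH_NUM) == false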

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala

Lines changed: 20 additions & 8 deletions
@@ -57,9 +57,9 @@ trait ShuffleExchangeLike extends Exchange {
   def numPartitions: Int

   /**
-   * Returns whether the shuffle partition number can be changed.
+   * The origin of this shuffle operator.
    */
-  def canChangeNumPartitions: Boolean
+  def shuffleOrigin: ShuffleOrigin

   /**
    * The asynchronous job that materializes the shuffle.
@@ -77,18 +77,30 @@ trait ShuffleExchangeLike extends Exchange {
   def runtimeStatistics: Statistics
 }

+// Describes where the shuffle operator comes from.
+sealed trait ShuffleOrigin
+
+// Indicates that the shuffle operator was added by the internal `EnsureRequirements` rule. It
+// means that the shuffle operator is used to ensure internal data partitioning requirements and
+// Spark is free to optimize it as long as the requirements are still ensured.
+case object ENSURE_REQUIREMENTS extends ShuffleOrigin
+
+// Indicates that the shuffle operator was added by the user-specified repartition operator. Spark
+// can still optimize it via changing shuffle partition number, as data partitioning won't change.
+case object REPARTITION extends ShuffleOrigin
+
+// Indicates that the shuffle operator was added by the user-specified repartition operator with
+// a certain partition number. Spark can't optimize it.
+case object REPARTITION_WITH_NUM extends ShuffleOrigin
+
 /**
  * Performs a shuffle that will result in the desired partitioning.
  */
 case class ShuffleExchangeExec(
     override val outputPartitioning: Partitioning,
     child: SparkPlan,
-    noUserSpecifiedNumPartition: Boolean = true) extends ShuffleExchangeLike {
-
-  // If users specify the num partitions via APIs like `repartition`, we shouldn't change it.
-  // For `SinglePartition`, it requires exactly one partition and we can't change it either.
-  def canChangeNumPartitions: Boolean =
-    noUserSpecifiedNumPartition && outputPartitioning != SinglePartition
+    shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS)
+  extends ShuffleExchangeLike {

   private lazy val writeMetrics =
     SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
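
Because ShuffleOrigin is a sealed trait, rules that consume it get compile-time exhaustiveness checking. A hypothetical helper (not part of this patch) that imports the real objects added above, usable from code with Spark on the classpath:

import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, REPARTITION_WITH_NUM, ShuffleOrigin}

// Hypothetical summary helper; the compiler warns if a new origin is added
// without a corresponding case here.
def describe(origin: ShuffleOrigin): String = origin match {
  case ENSURE_REQUIREMENTS  => "planner-added; Spark may freely optimize it"
  case REPARTITION          => "user repartition(col); only the partition count may change"
  case REPARTITION_WITH_NUM => "user repartition(n, col); Spark can't optimize it"
}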

sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out

Lines changed: 10 additions & 10 deletions
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 23
+-- Number of queries: 24


 -- !query
@@ -89,7 +89,7 @@ Results [2]: [key#x, max#x]

 (5) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (6) HashAggregate
 Input [2]: [key#x, max#x]
@@ -100,7 +100,7 @@ Results [2]: [key#x, max(val#x)#x AS max(val)#x]

 (7) Exchange
 Input [2]: [key#x, max(val)#x]
-Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), true, [id=#x]
+Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [id=#x]

 (8) Sort
 Input [2]: [key#x, max(val)#x]
@@ -158,7 +158,7 @@ Results [2]: [key#x, max#x]

 (5) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (6) HashAggregate
 Input [2]: [key#x, max#x]
@@ -245,7 +245,7 @@ Results [2]: [key#x, val#x]

 (9) Exchange
 Input [2]: [key#x, val#x]
-Arguments: hashpartitioning(key#x, val#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, val#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (10) HashAggregate
 Input [2]: [key#x, val#x]
@@ -613,7 +613,7 @@ Results [2]: [key#x, max#x]

 (5) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (6) HashAggregate
 Input [2]: [key#x, max#x]
@@ -647,7 +647,7 @@ Results [2]: [key#x, max#x]

 (11) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (12) HashAggregate
 Input [2]: [key#x, max#x]
@@ -730,7 +730,7 @@ Results [3]: [count#xL, sum#xL, count#xL]

 (3) Exchange
 Input [3]: [count#xL, sum#xL, count#xL]
-Arguments: SinglePartition, true, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

 (4) HashAggregate
 Input [3]: [count#xL, sum#xL, count#xL]
@@ -776,7 +776,7 @@ Results [2]: [key#x, buf#x]

 (3) Exchange
 Input [2]: [key#x, buf#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (4) ObjectHashAggregate
 Input [2]: [key#x, buf#x]
@@ -828,7 +828,7 @@ Results [2]: [key#x, min#x]

 (4) Exchange
 Input [2]: [key#x, min#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (5) Sort
 Input [2]: [key#x, min#x]

sql/core/src/test/resources/sql-tests/results/explain.sql.out

Lines changed: 14 additions & 14 deletions
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 23
+-- Number of queries: 24


 -- !query
@@ -92,7 +92,7 @@ Results [2]: [key#x, max#x]

 (6) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (7) HashAggregate [codegen id : 2]
 Input [2]: [key#x, max#x]
@@ -103,7 +103,7 @@ Results [2]: [key#x, max(val#x)#x AS max(val)#x]

 (8) Exchange
 Input [2]: [key#x, max(val)#x]
-Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), true, [id=#x]
+Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [id=#x]

 (9) Sort [codegen id : 3]
 Input [2]: [key#x, max(val)#x]
@@ -160,7 +160,7 @@ Results [2]: [key#x, max#x]

 (6) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (7) HashAggregate [codegen id : 2]
 Input [2]: [key#x, max#x]
@@ -250,7 +250,7 @@ Results [2]: [key#x, val#x]

 (11) Exchange
 Input [2]: [key#x, val#x]
-Arguments: hashpartitioning(key#x, val#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, val#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (12) HashAggregate [codegen id : 4]
 Input [2]: [key#x, val#x]
@@ -469,7 +469,7 @@ Results [1]: [max#x]

 (10) Exchange
 Input [1]: [max#x]
-Arguments: SinglePartition, true, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

 (11) HashAggregate [codegen id : 2]
 Input [1]: [max#x]
@@ -516,7 +516,7 @@ Results [1]: [max#x]

 (17) Exchange
 Input [1]: [max#x]
-Arguments: SinglePartition, true, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

 (18) HashAggregate [codegen id : 2]
 Input [1]: [max#x]
@@ -600,7 +600,7 @@ Results [1]: [max#x]

 (9) Exchange
 Input [1]: [max#x]
-Arguments: SinglePartition, true, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

 (10) HashAggregate [codegen id : 2]
 Input [1]: [max#x]
@@ -647,7 +647,7 @@ Results [2]: [sum#x, count#xL]

 (16) Exchange
 Input [2]: [sum#x, count#xL]
-Arguments: SinglePartition, true, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

 (17) HashAggregate [codegen id : 2]
 Input [2]: [sum#x, count#xL]
@@ -713,7 +713,7 @@ Results [2]: [sum#x, count#xL]

 (7) Exchange
 Input [2]: [sum#x, count#xL]
-Arguments: SinglePartition, true, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

 (8) HashAggregate [codegen id : 2]
 Input [2]: [sum#x, count#xL]
@@ -851,7 +851,7 @@ Results [2]: [key#x, max#x]

 (6) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (7) HashAggregate [codegen id : 4]
 Input [2]: [key#x, max#x]
@@ -943,7 +943,7 @@ Results [3]: [count#xL, sum#xL, count#xL]

 (4) Exchange
 Input [3]: [count#xL, sum#xL, count#xL]
-Arguments: SinglePartition, true, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

 (5) HashAggregate [codegen id : 2]
 Input [3]: [count#xL, sum#xL, count#xL]
@@ -988,7 +988,7 @@ Results [2]: [key#x, buf#x]

 (4) Exchange
 Input [2]: [key#x, buf#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (5) ObjectHashAggregate
 Input [2]: [key#x, buf#x]
@@ -1039,7 +1039,7 @@ Results [2]: [key#x, min#x]

 (5) Exchange
 Input [2]: [key#x, min#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (6) Sort [codegen id : 2]
 Input [2]: [key#x, min#x]

sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala

Lines changed: 4 additions & 2 deletions
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE
@@ -763,7 +763,9 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
 case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleExchangeLike {
   override def numMappers: Int = delegate.numMappers
   override def numPartitions: Int = delegate.numPartitions
-  override def canChangeNumPartitions: Boolean = delegate.canChangeNumPartitions
+  override def shuffleOrigin: ShuffleOrigin = {
+    delegate.shuffleOrigin
+  }
   override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
     delegate.mapOutputStatisticsFuture
   override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] =

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

Lines changed: 10 additions & 0 deletions
@@ -1013,4 +1013,14 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-33494: Do not use local shuffle reader for repartition") {
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      val df = spark.table("testData").repartition('key)
+      df.collect()
+      // local shuffle reader breaks partitioning and shouldn't be used for repartition
+      // operation which is specified by users.
+      checkNumLocalShuffleReaders(df.queryExecution.executedPlan, numShufflesWithoutLocalReader = 1)
+    }
+  }
 }
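
A hedged spark-shell style way to observe the new behavior end to end (the data, column name, and expectation about the printed plan are assumptions for illustration, not taken from the patch):

// Assumes an active `spark` session on Spark 3.x with AQE available.
import spark.implicits._

spark.conf.set("spark.sql.adaptive.enabled", "true")

val df = spark.range(100).withColumn("key", $"id" % 10).repartition($"key")
df.collect()

// After this patch the final adaptive plan should keep the repartition
// exchange intact; we would not expect a local shuffle reader above it.
println(df.queryExecution.executedPlan)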
