Commit 7b1c006

viirya authored and Lorenzo Martini committed
[SPARK-32056][SQL][FOLLOW-UP] Coalesce partitions for repartition hint and SQL when AQE is enabled
As a follow-up to apache#28900, this patch extends partition coalescing to repartitioning done through hints and SQL syntax without a specified number of partitions, when AQE is enabled. Repartitioning through hints or SQL syntax should follow the shuffling behavior of repartition by expression/range, so its partitions can also be coalesced under AQE.

This is a user-facing change: if users don't specify the number of partitions when repartitioning with the `REPARTITION`/`REPARTITION_BY_RANGE` hint or with `DISTRIBUTE BY`/`CLUSTER BY`, AQE will coalesce partitions. An explicitly specified partition count is still honored as-is.

Tested with unit tests.

Closes apache#28952 from viirya/SPARK-32056-sql.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
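To make the user-facing behavior concrete, here is a minimal sketch of the new behavior (the view name `t` and the runtime-set configs are illustrative; the exact coalesced partition count depends on data size and AQE's advisory partition size):

// Minimal sketch; assumes an existing SparkSession `spark`.
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.range(10).createOrReplaceTempView("t")

// No partition count in the hint: after this change AQE may coalesce
// the shuffle output below spark.sql.shuffle.partitions.
val coalesced = spark.sql("SELECT /*+ REPARTITION(id) */ * FROM t")

// Explicit partition count: still honored exactly, never coalesced.
val fixed = spark.sql("SELECT /*+ REPARTITION(10, id) */ * FROM t")

println(coalesced.rdd.getNumPartitions) // likely small for this tiny dataset
println(fixed.rdd.getNumPartitions)     // exactly 10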
1 parent 8b3be59 commit 7b1c006

File tree

5 files changed: +78 -31 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala

Lines changed: 8 additions & 8 deletions
@@ -183,7 +183,7 @@ object ResolveHints {
     val hintName = hint.name.toUpperCase(Locale.ROOT)

     def createRepartitionByExpression(
-        numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
+        numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
       val sortOrders = partitionExprs.filter(_.isInstanceOf[SortOrder])
       if (sortOrders.nonEmpty) throw new IllegalArgumentException(
         s"""Invalid partitionExprs specified: $sortOrders
@@ -208,11 +208,11 @@
         throw new AnalysisException(s"$hintName Hint expects a partition number as a parameter")

       case param @ Seq(IntegerLiteral(numPartitions), _*) if shuffle =>
-        createRepartitionByExpression(numPartitions, param.tail)
+        createRepartitionByExpression(Some(numPartitions), param.tail)
       case param @ Seq(numPartitions: Int, _*) if shuffle =>
-        createRepartitionByExpression(numPartitions, param.tail)
+        createRepartitionByExpression(Some(numPartitions), param.tail)
       case param @ Seq(_*) if shuffle =>
-        createRepartitionByExpression(conf.numShufflePartitions, param)
+        createRepartitionByExpression(None, param)
     }
   }

@@ -224,7 +224,7 @@
     val hintName = hint.name.toUpperCase(Locale.ROOT)

     def createRepartitionByExpression(
-        numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
+        numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
       val invalidParams = partitionExprs.filter(!_.isInstanceOf[UnresolvedAttribute])
       if (invalidParams.nonEmpty) {
         throw new AnalysisException(s"$hintName Hint parameter should include columns, but " +
@@ -239,11 +239,11 @@

     hint.parameters match {
       case param @ Seq(IntegerLiteral(numPartitions), _*) =>
-        createRepartitionByExpression(numPartitions, param.tail)
+        createRepartitionByExpression(Some(numPartitions), param.tail)
       case param @ Seq(numPartitions: Int, _*) =>
-        createRepartitionByExpression(numPartitions, param.tail)
+        createRepartitionByExpression(Some(numPartitions), param.tail)
       case param @ Seq(_*) =>
-        createRepartitionByExpression(conf.numShufflePartitions, param)
+        createRepartitionByExpression(None, param)
     }
   }
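The essence of this change is the numPartitions parameter becoming Option[Int]: Some(n) records an explicitly requested count, while None means "not specified, let AQE decide" instead of eagerly defaulting to conf.numShufflePartitions. A standalone toy model of that dispatch (plain Scala for illustration; ToyRepartition and resolveToyHint are hypothetical names, not the actual Catalyst classes):

// Toy model only: keeping the count optional preserves the "unspecified"
// case, so a later rule (AQE coalescing) can still tell it apart from an
// explicitly requested partition number.
case class ToyRepartition(exprs: Seq[String], optNumPartitions: Option[Int])

def resolveToyHint(params: Seq[Any]): ToyRepartition = params match {
  case Seq(n: Int, rest @ _*) => ToyRepartition(rest.map(_.toString), Some(n)) // explicit count: pin it
  case exprs                  => ToyRepartition(exprs.map(_.toString), None)   // unspecified: AQE may coalesce
}

// resolveToyHint(Seq(10, "id")) == ToyRepartition(Seq("id"), Some(10))
// resolveToyHint(Seq("id"))     == ToyRepartition(Seq("id"), None)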

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -163,7 +163,7 @@ class ResolveHintsSuite extends AnalysisTest {
     checkAnalysis(
       UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("a")), table("TaBlE")),
       RepartitionByExpression(
-        Seq(AttributeReference("a", IntegerType)()), testRelation, conf.numShufflePartitions))
+        Seq(AttributeReference("a", IntegerType)()), testRelation, None))

     val e = intercept[IllegalArgumentException] {
       checkAnalysis(
@@ -187,7 +187,7 @@
         "REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("a")), table("TaBlE")),
       RepartitionByExpression(
         Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)),
-        testRelation, conf.numShufflePartitions))
+        testRelation, None))

     val errMsg2 = "REPARTITION Hint parameter should include columns, but"

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala

Lines changed: 1 addition & 1 deletion
@@ -740,7 +740,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
       ctx: QueryOrganizationContext,
       expressions: Seq[Expression],
       query: LogicalPlan): LogicalPlan = {
-    RepartitionByExpression(expressions, query, conf.numShufflePartitions)
+    RepartitionByExpression(expressions, query, None)
   }

   /**
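On the SQL-syntax side, this one-line parser change means DISTRIBUTE BY and CLUSTER BY no longer bake conf.numShufflePartitions into the plan either. A hedged usage sketch (view name illustrative; same AQE configs as the example near the top):

// Illustrative: with AQE and coalescing enabled, these shuffles carry no
// explicit partition count, so AQE may reduce their partition number.
spark.range(100).createOrReplaceTempView("t")
val distributed = spark.sql("SELECT * FROM t DISTRIBUTE BY id")
val clustered = spark.sql("SELECT * FROM t CLUSTER BY id") // DISTRIBUTE BY + SORT BY
println(distributed.rdd.getNumPartitions)
println(clustered.rdd.getNumPartitions)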

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala

Lines changed: 3 additions & 3 deletions
@@ -209,20 +209,20 @@ class SparkSqlParserSuite extends AnalysisTest {
     assertEqual(s"$baseSql distribute by a, b",
       RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
         basePlan,
-        numPartitions = newConf.numShufflePartitions))
+        None))
     assertEqual(s"$baseSql distribute by a sort by b",
       Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
         global = false,
         RepartitionByExpression(UnresolvedAttribute("a") :: Nil,
           basePlan,
-          numPartitions = newConf.numShufflePartitions)))
+          None)))
     assertEqual(s"$baseSql cluster by a, b",
       Sort(SortOrder(UnresolvedAttribute("a"), Ascending) ::
           SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
         global = false,
         RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
           basePlan,
-          numPartitions = newConf.numShufflePartitions)))
+          None)))
   }

   test("pipeline concatenation") {

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

Lines changed: 64 additions & 17 deletions
@@ -23,7 +23,7 @@ import java.net.URI
 import org.apache.log4j.Level

 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
-import org.apache.spark.sql.{QueryTest, Row, SparkSession, Strategy}
+import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.OptimizeLocalShuffleReader.LOCAL_SHUFFLE_READER_DESCRIPTION
@@ -130,6 +130,17 @@ class AdaptiveQueryExecSuite
     assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader))
   }

+  private def checkInitialPartitionNum(df: Dataset[_], numPartition: Int): Unit = {
+    // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
+    val plan = df.queryExecution.executedPlan
+    assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
+    val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
+      case s: ShuffleExchangeExec => s
+    }
+    assert(shuffle.size == 1)
+    assert(shuffle(0).outputPartitioning.numPartitions == numPartition)
+  }
+
   test("Change merge join to broadcast join") {
     withSQLConf(
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
@@ -892,14 +903,8 @@ class AdaptiveQueryExecSuite
         assert(partitionsNum1 < 10)
         assert(partitionsNum2 < 10)

-        // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
-        val plan = df1.queryExecution.executedPlan
-        assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
-        val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
-          case s: ShuffleExchangeExec => s
-        }
-        assert(shuffle.size == 1)
-        assert(shuffle(0).outputPartitioning.numPartitions == 10)
+        checkInitialPartitionNum(df1, 10)
+        checkInitialPartitionNum(df2, 10)
       } else {
         assert(partitionsNum1 === 10)
         assert(partitionsNum2 === 10)
@@ -933,14 +938,8 @@ class AdaptiveQueryExecSuite
         assert(partitionsNum1 < 10)
         assert(partitionsNum2 < 10)

-        // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
-        val plan = df1.queryExecution.executedPlan
-        assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
-        val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
-          case s: ShuffleExchangeExec => s
-        }
-        assert(shuffle.size == 1)
-        assert(shuffle(0).outputPartitioning.numPartitions == 10)
+        checkInitialPartitionNum(df1, 10)
+        checkInitialPartitionNum(df2, 10)
       } else {
         assert(partitionsNum1 === 10)
         assert(partitionsNum2 === 10)
@@ -966,4 +965,52 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-31220, SPARK-32056: repartition using sql and hint with AQE") {
+    Seq(true, false).foreach { enableAQE =>
+      withTempView("test") {
+        withSQLConf(
+          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
+          SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+          SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10",
+          SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+
+          spark.range(10).toDF.createTempView("test")
+
+          val df1 = spark.sql("SELECT /*+ REPARTITION(id) */ * from test")
+          val df2 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(id) */ * from test")
+          val df3 = spark.sql("SELECT * from test DISTRIBUTE BY id")
+          val df4 = spark.sql("SELECT * from test CLUSTER BY id")
+
+          val partitionsNum1 = df1.rdd.collectPartitions().length
+          val partitionsNum2 = df2.rdd.collectPartitions().length
+          val partitionsNum3 = df3.rdd.collectPartitions().length
+          val partitionsNum4 = df4.rdd.collectPartitions().length
+
+          if (enableAQE) {
+            assert(partitionsNum1 < 10)
+            assert(partitionsNum2 < 10)
+            assert(partitionsNum3 < 10)
+            assert(partitionsNum4 < 10)
+
+            checkInitialPartitionNum(df1, 10)
+            checkInitialPartitionNum(df2, 10)
+            checkInitialPartitionNum(df3, 10)
+            checkInitialPartitionNum(df4, 10)
+          } else {
+            assert(partitionsNum1 === 10)
+            assert(partitionsNum2 === 10)
+            assert(partitionsNum3 === 10)
+            assert(partitionsNum4 === 10)
+          }
+
+          // Don't coalesce partitions if the number of partitions is specified.
+          val df5 = spark.sql("SELECT /*+ REPARTITION(10, id) */ * from test")
+          val df6 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(10, id) */ * from test")
+          assert(df5.rdd.collectPartitions().length == 10)
+          assert(df6.rdd.collectPartitions().length == 10)
+        }
+      }
+    }
+  }
 }
