Commit 093df98

Coalesce partitions for repartition hint and SQL.
1 parent 20cd47e commit 093df98

4 files changed, 76 additions(+), 29 deletions(-)
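
In effect, a repartition hint or a DISTRIBUTE BY / CLUSTER BY clause that does not name an explicit partition count now leaves the count unset (None), so adaptive query execution is free to coalesce the shuffle output; an explicit count is still honored exactly. A minimal spark-shell sketch of the observable behavior (assumes a Spark build containing this commit and a running SparkSession named `spark`):

    spark.conf.set("spark.sql.adaptive.enabled", "true")
    spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
    spark.conf.set("spark.sql.shuffle.partitions", "10")
    spark.range(10).createOrReplaceTempView("test")

    // No explicit count: AQE may coalesce below 10 partitions.
    spark.sql("SELECT /*+ REPARTITION(id) */ * FROM test").rdd.getNumPartitions

    // Explicit count: kept as-is, never coalesced.
    spark.sql("SELECT /*+ REPARTITION(10, id) */ * FROM test").rdd.getNumPartitions // 10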

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala

Lines changed: 8 additions & 8 deletions

@@ -183,7 +183,7 @@ object ResolveHints {
       val hintName = hint.name.toUpperCase(Locale.ROOT)

       def createRepartitionByExpression(
-          numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
+          numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
         val sortOrders = partitionExprs.filter(_.isInstanceOf[SortOrder])
         if (sortOrders.nonEmpty) throw new IllegalArgumentException(
           s"""Invalid partitionExprs specified: $sortOrders
@@ -208,11 +208,11 @@ object ResolveHints {
           throw new AnalysisException(s"$hintName Hint expects a partition number as a parameter")

         case param @ Seq(IntegerLiteral(numPartitions), _*) if shuffle =>
-          createRepartitionByExpression(numPartitions, param.tail)
+          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(numPartitions: Int, _*) if shuffle =>
-          createRepartitionByExpression(numPartitions, param.tail)
+          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(_*) if shuffle =>
-          createRepartitionByExpression(conf.numShufflePartitions, param)
+          createRepartitionByExpression(None, param)
       }
     }

@@ -224,7 +224,7 @@ object ResolveHints {
       val hintName = hint.name.toUpperCase(Locale.ROOT)

       def createRepartitionByExpression(
-          numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
+          numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
         val invalidParams = partitionExprs.filter(!_.isInstanceOf[UnresolvedAttribute])
         if (invalidParams.nonEmpty) {
           throw new AnalysisException(s"$hintName Hint parameter should include columns, but " +
@@ -239,11 +239,11 @@ object ResolveHints {

       hint.parameters match {
         case param @ Seq(IntegerLiteral(numPartitions), _*) =>
-          createRepartitionByExpression(numPartitions, param.tail)
+          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(numPartitions: Int, _*) =>
-          createRepartitionByExpression(numPartitions, param.tail)
+          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(_*) =>
-          createRepartitionByExpression(conf.numShufflePartitions, param)
+          createRepartitionByExpression(None, param)
       }
     }
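
The dispatch above reduces to one rule: a leading integer parameter becomes Some(n) and a bare column list becomes None, which downstream planning reads as "let AQE pick". A standalone sketch of that rule (hypothetical object and method names, simplified from the Spark source, which also matches IntegerLiteral expressions):

    object HintParamSketch {
      // Split hint parameters into an optional explicit partition count and
      // the remaining partition expressions, mirroring the cases above.
      def split(params: Seq[Any]): (Option[Int], Seq[Any]) = params match {
        case param @ Seq(n: Int, _*) => (Some(n), param.tail) // REPARTITION(10, c)
        case param @ Seq(_*)         => (None, param)         // REPARTITION(c)
      }

      def main(args: Array[String]): Unit = {
        assert(split(Seq(3, "c")) == (Some(3) -> Seq("c")))
        assert(split(Seq("c")) == (None -> Seq("c")))
      }
    }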

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala

Lines changed: 1 addition & 1 deletion

@@ -746,7 +746,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
       ctx: QueryOrganizationContext,
       expressions: Seq[Expression],
       query: LogicalPlan): LogicalPlan = {
-    RepartitionByExpression(expressions, query, conf.numShufflePartitions)
+    RepartitionByExpression(expressions, query, None)
   }

   /**
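
With this change, DISTRIBUTE BY no longer bakes spark.sql.shuffle.partitions into the logical plan at parse time. A quick way to see it from a spark-shell (assumes a SparkSession named `spark`):

    // The logical plan contains a RepartitionByExpression with no fixed
    // partition count, so AQE may coalesce the shuffle at execution time.
    val qe = spark.sql("SELECT * FROM range(10) DISTRIBUTE BY id").queryExecution
    println(qe.logical)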

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala

Lines changed: 3 additions & 3 deletions

@@ -199,20 +199,20 @@ class SparkSqlParserSuite extends AnalysisTest {
     assertEqual(s"$baseSql distribute by a, b",
       RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
         basePlan,
-        numPartitions = newConf.numShufflePartitions))
+        None))
     assertEqual(s"$baseSql distribute by a sort by b",
       Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
         global = false,
         RepartitionByExpression(UnresolvedAttribute("a") :: Nil,
           basePlan,
-          numPartitions = newConf.numShufflePartitions)))
+          None)))
     assertEqual(s"$baseSql cluster by a, b",
       Sort(SortOrder(UnresolvedAttribute("a"), Ascending) ::
           SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
         global = false,
         RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
           basePlan,
-          numPartitions = newConf.numShufflePartitions)))
+          None)))
   }

   test("pipeline concatenation") {

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

Lines changed: 64 additions & 17 deletions

@@ -23,7 +23,7 @@ import java.net.URI
 import org.apache.log4j.Level

 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
-import org.apache.spark.sql.{QueryTest, Row, SparkSession, Strategy}
+import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
@@ -130,6 +130,17 @@ class AdaptiveQueryExecSuite
     assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader))
   }

+  private def checkInitialPartitionNum(df: Dataset[_]): Unit = {
+    // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
+    val plan = df.queryExecution.executedPlan
+    assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
+    val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
+      case s: ShuffleExchangeExec => s
+    }
+    assert(shuffle.size == 1)
+    assert(shuffle(0).outputPartitioning.numPartitions == 10)
+  }
+
   test("Change merge join to broadcast join") {
     withSQLConf(
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
@@ -1040,14 +1051,8 @@ class AdaptiveQueryExecSuite
       assert(partitionsNum1 < 10)
       assert(partitionsNum2 < 10)

-      // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
-      val plan = df1.queryExecution.executedPlan
-      assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
-      val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
-        case s: ShuffleExchangeExec => s
-      }
-      assert(shuffle.size == 1)
-      assert(shuffle(0).outputPartitioning.numPartitions == 10)
+      checkInitialPartitionNum(df1)
+      checkInitialPartitionNum(df2)
     } else {
       assert(partitionsNum1 === 10)
       assert(partitionsNum2 === 10)
@@ -1081,14 +1086,8 @@ class AdaptiveQueryExecSuite
       assert(partitionsNum1 < 10)
       assert(partitionsNum2 < 10)

-      // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
-      val plan = df1.queryExecution.executedPlan
-      assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
-      val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
-        case s: ShuffleExchangeExec => s
-      }
-      assert(shuffle.size == 1)
-      assert(shuffle(0).outputPartitioning.numPartitions == 10)
+      checkInitialPartitionNum(df1)
+      checkInitialPartitionNum(df2)
     } else {
       assert(partitionsNum1 === 10)
       assert(partitionsNum2 === 10)
@@ -1100,4 +1099,52 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-31220, SPARK-32056: repartition using sql and hint with AQE") {
+    Seq(true, false).foreach { enableAQE =>
+      withTempView("test") {
+        withSQLConf(
+          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
+          SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+          SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10",
+          SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+
+          spark.range(10).toDF.createTempView("test")
+
+          val df1 = spark.sql("SELECT /*+ REPARTITION(id) */ * from test")
+          val df2 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(id) */ * from test")
+          val df3 = spark.sql("SELECT * from test DISTRIBUTE BY id")
+          val df4 = spark.sql("SELECT * from test CLUSTER BY id")
+
+          val partitionsNum1 = df1.rdd.collectPartitions().length
+          val partitionsNum2 = df2.rdd.collectPartitions().length
+          val partitionsNum3 = df3.rdd.collectPartitions().length
+          val partitionsNum4 = df4.rdd.collectPartitions().length
+
+          if (enableAQE) {
+            assert(partitionsNum1 < 10)
+            assert(partitionsNum2 < 10)
+            assert(partitionsNum3 < 10)
+            assert(partitionsNum4 < 10)
+
+            checkInitialPartitionNum(df1)
+            checkInitialPartitionNum(df2)
+            checkInitialPartitionNum(df3)
+            checkInitialPartitionNum(df4)
+          } else {
+            assert(partitionsNum1 === 10)
+            assert(partitionsNum2 === 10)
+            assert(partitionsNum3 === 10)
+            assert(partitionsNum4 === 10)
+          }
+
+          // Don't coalesce partitions if the number of partitions is specified.
+          val df5 = spark.sql("SELECT /*+ REPARTITION(10, id) */ * from test")
+          val df6 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(10, id) */ * from test")
+          assert(df5.rdd.collectPartitions().length == 10)
+          assert(df6.rdd.collectPartitions().length == 10)
+        }
+      }
+    }
+  }
 }
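
Outside the suite, the helper's check can be sketched for any DataFrame under the same configs (assumes AQE enabled with initialPartitionNum = 10, as above):

    import org.apache.spark.sql.Dataset
    import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
    import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

    // Mirror of the suite's checkInitialPartitionNum: the shuffle inside the
    // adaptive plan is still planned with the initial partition number, even
    // though the final, coalesced output may expose fewer partitions.
    def shuffleInitialPartitions(df: Dataset[_]): Int = {
      val adaptive = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec]
      val shuffles = adaptive.executedPlan.collect { case s: ShuffleExchangeExec => s }
      assert(shuffles.size == 1)
      shuffles.head.outputPartitioning.numPartitions
    }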
