@@ -23,7 +23,7 @@ import java.net.URI
 import org.apache.log4j.Level
 
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
-import org.apache.spark.sql.{QueryTest, Row, SparkSession, Strategy}
+import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
@@ -130,6 +130,17 @@ class AdaptiveQueryExecSuite
     assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader))
   }
 
+  private def checkInitialPartitionNum(df: Dataset[_]): Unit = {
+    // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
+    val plan = df.queryExecution.executedPlan
+    assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
+    val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
+      case s: ShuffleExchangeExec => s
+    }
+    assert(shuffle.size == 1)
+    assert(shuffle(0).outputPartitioning.numPartitions == 10)
+  }
+
   test("Change merge join to broadcast join") {
     withSQLConf(
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
@@ -1040,14 +1051,8 @@ class AdaptiveQueryExecSuite
           assert(partitionsNum1 < 10)
           assert(partitionsNum2 < 10)
 
-          // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
-          val plan = df1.queryExecution.executedPlan
-          assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
-          val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
-            case s: ShuffleExchangeExec => s
-          }
-          assert(shuffle.size == 1)
-          assert(shuffle(0).outputPartitioning.numPartitions == 10)
+          checkInitialPartitionNum(df1)
+          checkInitialPartitionNum(df2)
         } else {
           assert(partitionsNum1 === 10)
           assert(partitionsNum2 === 10)
@@ -1081,14 +1086,8 @@ class AdaptiveQueryExecSuite
           assert(partitionsNum1 < 10)
           assert(partitionsNum2 < 10)
 
-          // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
-          val plan = df1.queryExecution.executedPlan
-          assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
-          val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
-            case s: ShuffleExchangeExec => s
-          }
-          assert(shuffle.size == 1)
-          assert(shuffle(0).outputPartitioning.numPartitions == 10)
+          checkInitialPartitionNum(df1)
+          checkInitialPartitionNum(df2)
         } else {
           assert(partitionsNum1 === 10)
           assert(partitionsNum2 === 10)
@@ -1100,4 +1099,52 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-31220, SPARK-32056: repartition using sql and hint with AQE") {
+    Seq(true, false).foreach { enableAQE =>
+      withTempView("test") {
+        withSQLConf(
+          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
+          SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+          SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10",
+          SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+
+          spark.range(10).toDF.createTempView("test")
+
+          val df1 = spark.sql("SELECT /*+ REPARTITION(id) */ * from test")
+          val df2 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(id) */ * from test")
+          val df3 = spark.sql("SELECT * from test DISTRIBUTE BY id")
+          val df4 = spark.sql("SELECT * from test CLUSTER BY id")
+
+          val partitionsNum1 = df1.rdd.collectPartitions().length
+          val partitionsNum2 = df2.rdd.collectPartitions().length
+          val partitionsNum3 = df3.rdd.collectPartitions().length
+          val partitionsNum4 = df4.rdd.collectPartitions().length
+
+          if (enableAQE) {
+            assert(partitionsNum1 < 10)
+            assert(partitionsNum2 < 10)
+            assert(partitionsNum3 < 10)
+            assert(partitionsNum4 < 10)
+
+            checkInitialPartitionNum(df1)
+            checkInitialPartitionNum(df2)
+            checkInitialPartitionNum(df3)
+            checkInitialPartitionNum(df4)
+          } else {
+            assert(partitionsNum1 === 10)
+            assert(partitionsNum2 === 10)
+            assert(partitionsNum3 === 10)
+            assert(partitionsNum4 === 10)
+          }
+
+          // Don't coalesce partitions if the number of partitions is specified.
+          val df5 = spark.sql("SELECT /*+ REPARTITION(10, id) */ * from test")
+          val df6 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(10, id) */ * from test")
+          assert(df5.rdd.collectPartitions().length == 10)
+          assert(df6.rdd.collectPartitions().length == 10)
+        }
+      }
+    }
+  }
 }