@@ -130,15 +130,15 @@ class AdaptiveQueryExecSuite
130130 assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader))
131131 }
132132
133- private def checkInitialPartitionNum (df : Dataset [_]): Unit = {
133+ private def checkInitialPartitionNum (df : Dataset [_], numPartition : Int ): Unit = {
134134 // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
135135 val plan = df.queryExecution.executedPlan
136136 assert(plan.isInstanceOf [AdaptiveSparkPlanExec ])
137137 val shuffle = plan.asInstanceOf [AdaptiveSparkPlanExec ].executedPlan.collect {
138138 case s : ShuffleExchangeExec => s
139139 }
140140 assert(shuffle.size == 1 )
141- assert(shuffle(0 ).outputPartitioning.numPartitions == 10 )
141+ assert(shuffle(0 ).outputPartitioning.numPartitions == numPartition )
142142 }
143143
144144 test(" Change merge join to broadcast join" ) {
@@ -1051,8 +1051,8 @@ class AdaptiveQueryExecSuite
10511051 assert(partitionsNum1 < 10 )
10521052 assert(partitionsNum2 < 10 )
10531053
1054- checkInitialPartitionNum(df1)
1055- checkInitialPartitionNum(df2)
1054+ checkInitialPartitionNum(df1, 10 )
1055+ checkInitialPartitionNum(df2, 10 )
10561056 } else {
10571057 assert(partitionsNum1 === 10 )
10581058 assert(partitionsNum2 === 10 )
@@ -1086,8 +1086,8 @@ class AdaptiveQueryExecSuite
10861086 assert(partitionsNum1 < 10 )
10871087 assert(partitionsNum2 < 10 )
10881088
1089- checkInitialPartitionNum(df1)
1090- checkInitialPartitionNum(df2)
1089+ checkInitialPartitionNum(df1, 10 )
1090+ checkInitialPartitionNum(df2, 10 )
10911091 } else {
10921092 assert(partitionsNum1 === 10 )
10931093 assert(partitionsNum2 === 10 )
@@ -1127,10 +1127,10 @@ class AdaptiveQueryExecSuite
11271127 assert(partitionsNum3 < 10 )
11281128 assert(partitionsNum4 < 10 )
11291129
1130- checkInitialPartitionNum(df1)
1131- checkInitialPartitionNum(df2)
1132- checkInitialPartitionNum(df3)
1133- checkInitialPartitionNum(df4)
1130+ checkInitialPartitionNum(df1, 10 )
1131+ checkInitialPartitionNum(df2, 10 )
1132+ checkInitialPartitionNum(df3, 10 )
1133+ checkInitialPartitionNum(df4, 10 )
11341134 } else {
11351135 assert(partitionsNum1 === 10 )
11361136 assert(partitionsNum2 === 10 )
0 commit comments