@@ -19,8 +19,7 @@ package org.apache.spark.sql.streaming
 
 import java.util.{Locale, TimeZone}
 
-import org.scalatest.Assertions
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.{Assertions, BeforeAndAfterAll}
 
 import org.apache.spark.{SparkEnv, SparkException}
 import org.apache.spark.rdd.BlockRDD
@@ -54,30 +53,35 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
 
   import testImplicits._
 
-  val confAndTestNamePostfixMatrix = List(
-    (Seq(SQLConf.ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION.key -> "false"), ""),
-    (Seq(SQLConf.ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION.key -> "true"),
-      ": enable remove redundant in stateful aggregation")
-  )
+  def executeFuncWithStateVersionSQLConf(
+      stateVersion: Int,
+      confPairs: Seq[(String, String)],
+      func: => Any): Unit = {
+    withSQLConf(confPairs ++
+      Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString): _*) {
+      func
+    }
+  }
 
-  def testWithAggrOptions(testName: String, pairs: (String, String)*)(testFun: => Any): Unit = {
-    confAndTestNamePostfixMatrix.foreach {
-      case (conf, testNamePostfix) => withSQLConf(pairs ++ conf: _*) {
-        test(testName + testNamePostfix)(testFun)
+  def testWithAllStateVersions(name: String, confPairs: (String, String)*)
+      (func: => Any): Unit = {
+    for (version <- StatefulOperatorsHelper.supportedVersions) {
+      test(s"$name - state format version $version") {
+        executeFuncWithStateVersionSQLConf(version, confPairs, func)
       }
     }
   }
 
-  def testQuietlyWithAggrOptions(testName: String, pairs: (String, String)*)
-      (testFun: => Any): Unit = {
-    confAndTestNamePostfixMatrix.foreach {
-      case (conf, testNamePostfix) => withSQLConf(pairs ++ conf: _*) {
-        testQuietly(testName + testNamePostfix)(testFun)
+  def testQuietlyWithAllStateVersions(name: String, confPairs: (String, String)*)
+      (func: => Any): Unit = {
+    for (version <- StatefulOperatorsHelper.supportedVersions) {
+      testQuietly(s"$name - state format version $version") {
+        executeFuncWithStateVersionSQLConf(version, confPairs, func)
       }
     }
   }
 
-  testWithAggrOptions("simple count, update mode") {
+  testWithAllStateVersions("simple count, update mode") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
@@ -101,7 +105,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("count distinct") {
+  testWithAllStateVersions("count distinct") {
     val inputData = MemoryStream[(Int, Seq[Int])]
 
     val aggregated =
@@ -117,7 +121,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("simple count, complete mode") {
+  testWithAllStateVersions("simple count, complete mode") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
@@ -140,7 +144,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("simple count, append mode") {
+  testWithAllStateVersions("simple count, append mode") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
@@ -157,7 +161,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     }
   }
 
-  testWithAggrOptions("sort after aggregate in complete mode") {
+  testWithAllStateVersions("sort after aggregate in complete mode") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
@@ -182,7 +186,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("state metrics") {
+  testWithAllStateVersions("state metrics") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
@@ -235,7 +239,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("multiple keys") {
+  testWithAllStateVersions("multiple keys") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
@@ -252,7 +256,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testQuietlyWithAggrOptions("midbatch failure") {
+  testQuietlyWithAllStateVersions("midbatch failure") {
     val inputData = MemoryStream[Int]
     FailureSingleton.firstTime = true
     val aggregated =
@@ -278,7 +282,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("typed aggregators") {
+  testWithAllStateVersions("typed aggregators") {
     val inputData = MemoryStream[(String, Int)]
     val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2))
 
@@ -288,7 +292,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("prune results by current_time, complete mode") {
+  testWithAllStateVersions("prune results by current_time, complete mode") {
     import testImplicits._
     val clock = new StreamManualClock
     val inputData = MemoryStream[Long]
@@ -340,7 +344,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("prune results by current_date, complete mode") {
+  testWithAllStateVersions("prune results by current_date, complete mode") {
     import testImplicits._
     val clock = new StreamManualClock
     val tz = TimeZone.getDefault.getID
@@ -389,7 +393,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("SPARK-19690: do not convert batch aggregation in streaming query " +
+  testWithAllStateVersions("SPARK-19690: do not convert batch aggregation in streaming query " +
     "to streaming") {
     val streamInput = MemoryStream[Int]
     val batchDF = Seq(1, 2, 3, 4, 5)
@@ -454,7 +458,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     true
   }
 
-  testWithAggrOptions("SPARK-21977: coalesce(1) with 0 partition RDD should be " +
+  testWithAllStateVersions("SPARK-21977: coalesce(1) with 0 partition RDD should be " +
     "repartitioned to 1") {
     val inputSource = new BlockRDDBackedSource(spark)
     MockSourceProvider.withMockSources(inputSource) {
@@ -493,8 +497,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     }
   }
 
-  testWithAggrOptions("SPARK-21977: coalesce(1) with aggregation should still be repartitioned " +
-    "when it has non-empty grouping keys") {
+  testWithAllStateVersions("SPARK-21977: coalesce(1) with aggregation should still be " +
+    "repartitioned when it has non-empty grouping keys") {
     val inputSource = new BlockRDDBackedSource(spark)
     MockSourceProvider.withMockSources(inputSource) {
       withTempDir { tempDir =>
@@ -546,7 +550,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     }
   }
 
-  testWithAggrOptions("SPARK-22230: last should change with new batches") {
+  testWithAllStateVersions("SPARK-22230: last should change with new batches") {
     val input = MemoryStream[Int]
 
     val aggregated = input.toDF().agg(last('value))
@@ -562,7 +566,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("SPARK-23004: Ensure that TypedImperativeAggregate functions " +
+  testWithAllStateVersions("SPARK-23004: Ensure that TypedImperativeAggregate functions " +
     "do not throw errors", SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
     // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error
     // by ensuring the following.
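
For orientation, here is a brief sketch of how the new helpers fan out. The concrete version values are an assumption about what StatefulOperatorsHelper.supportedVersions contains; they are not spelled out in the hunks above.

// Hedged illustration only, not part of the patch.
// Assuming StatefulOperatorsHelper.supportedVersions == Seq(1, 2), a single call such as
testWithAllStateVersions("SPARK-23004: Ensure that TypedImperativeAggregate functions " +
  "do not throw errors", SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
  // test body
}
// registers two ScalaTest tests,
//   "SPARK-23004: ... do not throw errors - state format version 1"
//   "SPARK-23004: ... do not throw errors - state format version 2"
// and each of them runs the body inside withSQLConf with both the passed-in
// shuffle-partitions pair and SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key
// pinned to the corresponding version, via executeFuncWithStateVersionSQLConf.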