Commit 63dfb5d

Change the strategy: "add new option" -> "apply by default, but keep backward compatible"
1 parent: 977428c

8 files changed: +78 −63 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 8 additions & 11 deletions
@@ -871,15 +871,15 @@ object SQLConf {
       .intConf
       .createWithDefault(2)
 
-  val ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION =
-    buildConf("spark.sql.streaming.advanced.removeRedundantInStatefulAggregation")
+  val STREAMING_AGGREGATION_STATE_FORMAT_VERSION =
+    buildConf("spark.sql.streaming.streamingAggregation.stateFormatVersion")
       .internal()
-      .doc("ADVANCED: When true, stateful aggregation tries to remove redundant data " +
-        "between key and value in state. Enabling this option helps minimizing state size, " +
-        "but no longer be compatible with state with disabling this option." +
-        "You can't change this option after starting the query.")
-      .booleanConf
-      .createWithDefault(false)
+      .doc("State format version used by streaming aggregation operations triggered " +
+        "explicitly or implicitly via agg() in a streaming query. State between versions " +
+        "tends to be incompatible, so the format version shouldn't be modified after running.")
+      .intConf
+      .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
+      .createWithDefault(2)
 
   val UNSUPPORTED_OPERATION_CHECK_ENABLED =
     buildConf("spark.sql.streaming.unsupportedOperationCheck")
 
@@ -1628,9 +1628,6 @@ class SQLConf extends Serializable with Logging {
   def advancedPartitionPredicatePushdownEnabled: Boolean =
     getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN)
 
-  def advancedRemoveRedundantInStatefulAggregation: Boolean =
-    getConf(ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION)
-
   def fallBackToHdfsForStatsEnabled: Boolean = getConf(ENABLE_FALL_BACK_TO_HDFS_FOR_STATS)
 
   def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN)
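
For context, a hedged usage sketch (not part of this commit): since the default is now version 2, the only reason to touch this internal conf is to pin a brand-new query to the legacy layout, and the value must then stay fixed for the lifetime of the checkpoint.

    // Illustrative only: pin the legacy (version 1) state format before
    // starting a brand-new query. State written by one version is not
    // readable by the other, so this must never change once the query runs.
    spark.conf.set(
      "spark.sql.streaming.streamingAggregation.stateFormatVersion", "1")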

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 3 additions & 0 deletions
@@ -328,10 +328,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             "Streaming aggregation doesn't support group aggregate pandas UDF")
         }
 
+        val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION)
+
         aggregate.AggUtils.planStreamingAggregation(
           namedGroupingExpressions,
           aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]),
           rewrittenResultExpressions,
+          stateVersion,
           planLater(child))
 
       case _ => Nil
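
Note that the version is resolved from the session conf once, at planning time, and threaded explicitly into the physical operators, replacing the old pattern in which each stateful operator looked the flag up via sqlContext.conf inside doExecute().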

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala

Lines changed: 4 additions & 1 deletion
@@ -256,6 +256,7 @@ object AggUtils {
       groupingExpressions: Seq[NamedExpression],
       functionsWithoutDistinct: Seq[AggregateExpression],
       resultExpressions: Seq[NamedExpression],
+      stateFormatVersion: Int,
       child: SparkPlan): Seq[SparkPlan] = {
 
     val groupingAttributes = groupingExpressions.map(_.toAttribute)
 
@@ -287,7 +288,8 @@
         child = partialAggregate)
     }
 
-    val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1)
+    val restored = StateStoreRestoreExec(groupingAttributes, None, stateFormatVersion,
+      partialMerged1)
 
     val partialMerged2: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
 
@@ -311,6 +313,7 @@
         stateInfo = None,
         outputMode = None,
         eventTimeWatermark = None,
+        stateFormatVersion = stateFormatVersion,
         partialMerged2)
 
     val finalAndCompleteAggregate: SparkPlan = {

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala

Lines changed: 4 additions & 2 deletions
@@ -100,19 +100,21 @@ class IncrementalExecution(
   val state = new Rule[SparkPlan] {
 
     override def apply(plan: SparkPlan): SparkPlan = plan transform {
-      case StateStoreSaveExec(keys, None, None, None,
+      case StateStoreSaveExec(keys, None, None, None, stateFormatVersion,
         UnaryExecNode(agg,
-          StateStoreRestoreExec(_, None, child))) =>
+          StateStoreRestoreExec(_, None, _, child))) =>
         val aggStateInfo = nextStatefulOperationStateInfo
         StateStoreSaveExec(
           keys,
           Some(aggStateInfo),
           Some(outputMode),
           Some(offsetSeqMetadata.batchWatermarkMs),
+          stateFormatVersion,
           agg.withNewChildren(
             StateStoreRestoreExec(
               keys,
               Some(aggStateInfo),
+              stateFormatVersion,
               child) :: Nil))
 
       case StreamingDeduplicateExec(keys, child, None, None) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala

Lines changed: 4 additions & 2 deletions
@@ -89,7 +89,7 @@ object OffsetSeqMetadata extends Logging {
   private implicit val format = Serialization.formats(NoTypeHints)
   private val relevantSQLConfs = Seq(
     SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY,
-    FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION)
+    FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION)
 
   /**
    * Default values of relevant configurations that are used for backward compatibility.
 
@@ -104,7 +104,9 @@ object OffsetSeqMetadata extends Logging {
   private val relevantSQLConfDefaultValues = Map[String, String](
     STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME,
     FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key ->
-      FlatMapGroupsWithStateExecHelper.legacyVersion.toString
+      FlatMapGroupsWithStateExecHelper.legacyVersion.toString,
+    STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key ->
+      StatefulOperatorsHelper.legacyVersion.toString
   )
 
   def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json)
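
This default-value map is what makes "apply by default, but keep backward compatible" work: new checkpoints record the conf (now defaulting to 2) in their offset log, while a checkpoint written before the conf existed has no entry and falls back to StatefulOperatorsHelper.legacyVersion, i.e. version 1, so its old state stays readable. A rough sketch of that fallback rule, assuming missing recorded confs are resolved through this map (the helper below is illustrative, not actual OffsetSeqMetadata code):

    // Hypothetical helper: a conf key absent from an old offset log resolves to
    // its backward-compatible default rather than the current session default.
    def resolveForRestart(key: String, recorded: Map[String, String]): Option[String] =
      recorded.get(key).orElse(relevantSQLConfDefaultValues.get(key))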

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorsHelper.scala

Lines changed: 10 additions & 9 deletions
@@ -22,10 +22,13 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
 import org.apache.spark.sql.execution.streaming.state.{StateStore, UnsafeRowPair}
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 
 object StatefulOperatorsHelper {
+
+  val supportedVersions = Seq(1, 2)
+  val legacyVersion = 1
+
   sealed trait StreamingAggregationStateManager extends Serializable {
     def extractKey(row: InternalRow): UnsafeRow
     def getValueExpressions: Seq[Attribute]
 
@@ -35,16 +38,14 @@ object StatefulOperatorsHelper {
   }
 
   object StreamingAggregationStateManager extends Logging {
-    def newImpl(
+    def createStateManager(
         keyExpressions: Seq[Attribute],
         childOutput: Seq[Attribute],
-        conf: SQLConf): StreamingAggregationStateManager = {
-
-      if (conf.advancedRemoveRedundantInStatefulAggregation) {
-        log.info("Advanced option removeRedundantInStatefulAggregation activated!")
-        new StreamingAggregationStateManagerImplV2(keyExpressions, childOutput)
-      } else {
-        new StreamingAggregationStateManagerImplV1(keyExpressions, childOutput)
+        stateFormatVersion: Int): StreamingAggregationStateManager = {
+      stateFormatVersion match {
+        case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, childOutput)
+        case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, childOutput)
+        case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid")
       }
     }
   }
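
The two implementations correspond to what the old boolean toggled. Going by the removed option's description ("remove redundant data between key and value in state"), version 1 keeps the grouping keys duplicated in the state value, while version 2 drops them from the value to shrink state size. A minimal sketch of the factory, with placeholder attributes standing in for a real plan's output:

    // Placeholders for a real operator's grouping keys and child output.
    val keyAttrs: Seq[Attribute] = ???
    val childOutput: Seq[Attribute] = ???

    // Version 1 (legacy): value rows duplicate the grouping keys.
    // Version 2 (default): value rows carry only the non-key columns.
    val manager = StreamingAggregationStateManager.createStateManager(
      keyAttrs, childOutput, stateFormatVersion = 2)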

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala

Lines changed: 8 additions & 5 deletions
@@ -200,13 +200,15 @@ object WatermarkSupport {
 case class StateStoreRestoreExec(
     keyExpressions: Seq[Attribute],
     stateInfo: Option[StatefulOperatorStateInfo],
+    stateFormatVersion: Int,
     child: SparkPlan)
   extends UnaryExecNode with StateStoreReader {
 
+  private[sql] val stateManager = StreamingAggregationStateManager.createStateManager(
+    keyExpressions, child.output, stateFormatVersion)
+
   override protected def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-    val stateManager = StreamingAggregationStateManager.newImpl(keyExpressions, child.output,
-      sqlContext.conf)
 
     child.execute().mapPartitionsWithStateStore(
       getStateInfo,
 
@@ -255,17 +257,18 @@ case class StateStoreSaveExec(
     stateInfo: Option[StatefulOperatorStateInfo] = None,
     outputMode: Option[OutputMode] = None,
     eventTimeWatermark: Option[Long] = None,
+    stateFormatVersion: Int,
     child: SparkPlan)
   extends UnaryExecNode with StateStoreWriter with WatermarkSupport {
 
+  private[sql] val stateManager = StreamingAggregationStateManager.createStateManager(
+    keyExpressions, child.output, stateFormatVersion)
+
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver
     assert(outputMode.nonEmpty,
       "Incorrect planning in IncrementalExecution, outputMode has not been set")
 
-    val stateManager = StreamingAggregationStateManager.newImpl(keyExpressions, child.output,
-      sqlContext.conf)
-
     child.execute().mapPartitionsWithStateStore(
       getStateInfo,
       keyExpressions.toStructType,
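
A structural point worth noting: the state manager is now a constructor-time val driven by the version carried in the plan, instead of being rebuilt inside doExecute() from sqlContext.conf; the private[sql] visibility presumably also lets other sql-internal code and tests inspect which format an operator will use.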

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala

Lines changed: 37 additions & 33 deletions
@@ -19,8 +19,7 @@ package org.apache.spark.sql.streaming
 
 import java.util.{Locale, TimeZone}
 
-import org.scalatest.Assertions
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.{Assertions, BeforeAndAfterAll}
 
 import org.apache.spark.{SparkEnv, SparkException}
 import org.apache.spark.rdd.BlockRDD
 
@@ -54,30 +53,35 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
 
   import testImplicits._
 
-  val confAndTestNamePostfixMatrix = List(
-    (Seq(SQLConf.ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION.key -> "false"), ""),
-    (Seq(SQLConf.ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION.key -> "true"),
-      " : enable remove redundant in stateful aggregation")
-  )
+  def executeFuncWithStateVersionSQLConf(
+      stateVersion: Int,
+      confPairs: Seq[(String, String)],
+      func: => Any): Unit = {
+    withSQLConf(confPairs ++
+      Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString): _*) {
+      func
+    }
+  }
 
-  def testWithAggrOptions(testName: String, pairs: (String, String)*)(testFun: => Any): Unit = {
-    confAndTestNamePostfixMatrix.foreach {
-      case (conf, testNamePostfix) => withSQLConf(pairs ++ conf: _*) {
-        test(testName + testNamePostfix)(testFun)
+  def testWithAllStateVersions(name: String, confPairs: (String, String)*)
+      (func: => Any): Unit = {
+    for (version <- StatefulOperatorsHelper.supportedVersions) {
+      test(s"$name - state format version $version") {
+        executeFuncWithStateVersionSQLConf(version, confPairs, func)
       }
     }
   }
 
-  def testQuietlyWithAggrOptions(testName: String, pairs: (String, String)*)
-      (testFun: => Any): Unit = {
-    confAndTestNamePostfixMatrix.foreach {
-      case (conf, testNamePostfix) => withSQLConf(pairs ++ conf: _*) {
-        testQuietly(testName + testNamePostfix)(testFun)
+  def testQuietlyWithAllStateVersions(name: String, confPairs: (String, String)*)
      (func: => Any): Unit = {
+    for (version <- StatefulOperatorsHelper.supportedVersions) {
+      testQuietly(s"$name - state format version $version") {
+        executeFuncWithStateVersionSQLConf(version, confPairs, func)
       }
     }
   }
 
-  testWithAggrOptions("simple count, update mode") {
+  testWithAllStateVersions("simple count, update mode") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
 
@@ -101,7 +105,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("count distinct") {
+  testWithAllStateVersions("count distinct") {
     val inputData = MemoryStream[(Int, Seq[Int])]
 
     val aggregated =
 
@@ -117,7 +121,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("simple count, complete mode") {
+  testWithAllStateVersions("simple count, complete mode") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
 
@@ -140,7 +144,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("simple count, append mode") {
+  testWithAllStateVersions("simple count, append mode") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
 
@@ -157,7 +161,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     }
   }
 
-  testWithAggrOptions("sort after aggregate in complete mode") {
+  testWithAllStateVersions("sort after aggregate in complete mode") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
 
@@ -182,7 +186,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("state metrics") {
+  testWithAllStateVersions("state metrics") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
 
@@ -235,7 +239,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("multiple keys") {
+  testWithAllStateVersions("multiple keys") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
 
@@ -252,7 +256,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testQuietlyWithAggrOptions("midbatch failure") {
+  testQuietlyWithAllStateVersions("midbatch failure") {
     val inputData = MemoryStream[Int]
     FailureSingleton.firstTime = true
     val aggregated =
 
@@ -278,7 +282,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("typed aggregators") {
+  testWithAllStateVersions("typed aggregators") {
     val inputData = MemoryStream[(String, Int)]
     val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2))
 
@@ -288,7 +292,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("prune results by current_time, complete mode") {
+  testWithAllStateVersions("prune results by current_time, complete mode") {
     import testImplicits._
     val clock = new StreamManualClock
     val inputData = MemoryStream[Long]
 
@@ -340,7 +344,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("prune results by current_date, complete mode") {
+  testWithAllStateVersions("prune results by current_date, complete mode") {
     import testImplicits._
     val clock = new StreamManualClock
     val tz = TimeZone.getDefault.getID
 
@@ -389,7 +393,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("SPARK-19690: do not convert batch aggregation in streaming query " +
+  testWithAllStateVersions("SPARK-19690: do not convert batch aggregation in streaming query " +
     "to streaming") {
     val streamInput = MemoryStream[Int]
     val batchDF = Seq(1, 2, 3, 4, 5)
 
@@ -454,7 +458,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     true
   }
 
-  testWithAggrOptions("SPARK-21977: coalesce(1) with 0 partition RDD should be " +
+  testWithAllStateVersions("SPARK-21977: coalesce(1) with 0 partition RDD should be " +
     "repartitioned to 1") {
     val inputSource = new BlockRDDBackedSource(spark)
     MockSourceProvider.withMockSources(inputSource) {
 
@@ -493,8 +497,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     }
   }
 
-  testWithAggrOptions("SPARK-21977: coalesce(1) with aggregation should still be repartitioned " +
-    "when it has non-empty grouping keys") {
+  testWithAllStateVersions("SPARK-21977: coalesce(1) with aggregation should still be " +
+    "repartitioned when it has non-empty grouping keys") {
     val inputSource = new BlockRDDBackedSource(spark)
     MockSourceProvider.withMockSources(inputSource) {
       withTempDir { tempDir =>
 
@@ -546,7 +550,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     }
   }
 
-  testWithAggrOptions("SPARK-22230: last should change with new batches") {
+  testWithAllStateVersions("SPARK-22230: last should change with new batches") {
    val input = MemoryStream[Int]
 
     val aggregated = input.toDF().agg(last('value))
 
@@ -562,7 +566,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAggrOptions("SPARK-23004: Ensure that TypedImperativeAggregate functions " +
+  testWithAllStateVersions("SPARK-23004: Ensure that TypedImperativeAggregate functions " +
     "do not throw errors", SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
     // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error
     // by ensuring the following.
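
The net effect of the new helpers is that every aggregation test now runs once per supported state format instead of opting into a conf matrix. Conceptually, a single declaration expands into two named tests (a sketch using the suite's own helper):

    // One declaration...
    testWithAllStateVersions("simple count, update mode") { /* test body */ }
    // ...registers "simple count, update mode - state format version 1" and
    // "simple count, update mode - state format version 2", each executed with
    // spark.sql.streaming.streamingAggregation.stateFormatVersion set to match.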
