From 9d6661c829a4a82aae64ed0522c44e4c3d8f4f0b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 21 May 2017 13:07:25 -0700 Subject: [PATCH] [SPARK-20792][SS] Support same timeout operations in mapGroupsWithState function in batch queries as in streaming queries ## What changes were proposed in this pull request? Currently, in batch queries, the timeout is disabled (i.e. GroupStateTimeout.NoTimeout), which means any GroupState.setTimeout*** operation throws UnsupportedOperationException. This is awkward when converting a streaming query into a batch query by changing the input DF from a streaming DF to a batch DF: if the timeout was enabled and used, the batch query starts throwing UnsupportedOperationException. This PR creates the dummy state in batch queries with the provided timeoutConf so that the timeout operations behave the same way as in streaming queries. The code has been refactored to make it obvious whether the state is being created for a batch query or a streaming query. ## How was this patch tested? Additional tests in FlatMapGroupsWithStateSuite. Author: Tathagata Das Closes #18024 from tdas/SPARK-20792. 
--- .../spark/sql/execution/SparkStrategies.scala | 5 +- .../apache/spark/sql/execution/objects.scala | 6 +- .../FlatMapGroupsWithStateExec.scala | 2 +- .../execution/streaming/GroupStateImpl.scala | 42 +++---- .../FlatMapGroupsWithStateSuite.scala | 113 +++++++++++++----- 5 files changed, 116 insertions(+), 52 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ca2f6dd7a84b2..73541c22c6308 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -383,8 +383,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.MapGroups(f, key, value, grouping, data, objAttr, child) => execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil case logical.FlatMapGroupsWithState( - f, key, value, grouping, data, output, _, _, _, _, child) => - execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil + f, key, value, grouping, data, output, _, _, _, timeout, child) => + execution.MapGroupsExec( + f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 48c7b80bffe03..34391818f3b9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.Row import 
org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.streaming.GroupStateTimeout import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -361,8 +362,11 @@ object MapGroupsExec { groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, + timeoutConf: GroupStateTimeout, child: SparkPlan): MapGroupsExec = { - val f = (key: Any, values: Iterator[Any]) => func(key, values, new GroupStateImpl[Any](None)) + val f = (key: Any, values: Iterator[Any]) => { + func(key, values, GroupStateImpl.createForBatch(timeoutConf)) + } new MapGroupsExec(f, keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr, child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index bd8d5d7b43d3a..3ceb4cf84a413 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -215,7 +215,7 @@ case class FlatMapGroupsWithStateExec( val keyObj = getKeyObj(keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects val stateObjOption = getStateObj(prevStateRowOption) - val keyedState = new GroupStateImpl( + val keyedState = GroupStateImpl.createForStreaming( stateObjOption, batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index d4606fd5a8463..4401e86936af9 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -38,20 +38,13 @@ import org.apache.spark.unsafe.types.CalendarInterval * @param hasTimedOut Whether the key for which this state wrapped is being created is * getting timed out or not. */ -private[sql] class GroupStateImpl[S]( +private[sql] class GroupStateImpl[S] private( optionalValue: Option[S], batchProcessingTimeMs: Long, eventTimeWatermarkMs: Long, timeoutConf: GroupStateTimeout, override val hasTimedOut: Boolean) extends GroupState[S] { - // Constructor to create dummy state when using mapGroupsWithState in a batch query - def this(optionalValue: Option[S]) = this( - optionalValue, - batchProcessingTimeMs = NO_TIMESTAMP, - eventTimeWatermarkMs = NO_TIMESTAMP, - timeoutConf = GroupStateTimeout.NoTimeout, - hasTimedOut = false) private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) private var defined: Boolean = optionalValue.isDefined private var updated: Boolean = false // whether value has been updated (but not removed) @@ -102,12 +95,7 @@ private[sql] class GroupStateImpl[S]( if (durationMs <= 0) { throw new IllegalArgumentException("Timeout duration must be positive") } - if (batchProcessingTimeMs != NO_TIMESTAMP) { - timeoutTimestamp = durationMs + batchProcessingTimeMs - } else { - // This is being called in a batch query, hence no processing timestamp. - // Just ignore any attempts to set timeout. - } + timeoutTimestamp = durationMs + batchProcessingTimeMs } override def setTimeoutDuration(duration: String): Unit = { @@ -128,12 +116,7 @@ private[sql] class GroupStateImpl[S]( s"Timeout timestamp ($timestampMs) cannot be earlier than the " + s"current watermark ($eventTimeWatermarkMs)") } - if (batchProcessingTimeMs != NO_TIMESTAMP) { - timeoutTimestamp = timestampMs - } else { - // This is being called in a batch query, hence no processing timestamp. 
- // Just ignore any attempts to set timeout. - } + timeoutTimestamp = timestampMs } @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") @@ -213,4 +196,23 @@ private[sql] class GroupStateImpl[S]( private[sql] object GroupStateImpl { // Value used represent the lack of valid timestamp as a long val NO_TIMESTAMP = -1L + + def createForStreaming[S]( + optionalValue: Option[S], + batchProcessingTimeMs: Long, + eventTimeWatermarkMs: Long, + timeoutConf: GroupStateTimeout, + hasTimedOut: Boolean): GroupStateImpl[S] = { + new GroupStateImpl[S]( + optionalValue, batchProcessingTimeMs, eventTimeWatermarkMs, timeoutConf, hasTimedOut) + } + + def createForBatch(timeoutConf: GroupStateTimeout): GroupStateImpl[Any] = { + new GroupStateImpl[Any]( + optionalValue = None, + batchProcessingTimeMs = NO_TIMESTAMP, + eventTimeWatermarkMs = NO_TIMESTAMP, + timeoutConf, + hasTimedOut = false) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 10e91740eb922..6bb9408ce99ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -73,14 +73,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf assert(state.hasRemoved === shouldBeRemoved) } + // === Tests for state in streaming queries === // Updating empty state - state = new GroupStateImpl[String](None) + state = GroupStateImpl.createForStreaming(None, 1, 1, NoTimeout, hasTimedOut = false) testState(None) state.update("") testState(Some(""), shouldBeUpdated = true) // Updating exiting state - state = new GroupStateImpl[String](Some("2")) + state = GroupStateImpl.createForStreaming(Some("2"), 1, 1, NoTimeout, hasTimedOut = false) testState(Some("2")) state.update("3") testState(Some("3"), 
shouldBeUpdated = true) @@ -99,25 +100,34 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } test("GroupState - setTimeout**** with NoTimeout") { - for (initState <- Seq(None, Some(5))) { - // for different initial state - implicit val state = new GroupStateImpl(initState, 1000, 1000, NoTimeout, hasTimedOut = false) - testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + for (initValue <- Seq(None, Some(5))) { + val states = Seq( + GroupStateImpl.createForStreaming(initValue, 1000, 1000, NoTimeout, hasTimedOut = false), + GroupStateImpl.createForBatch(NoTimeout) + ) + for (state <- states) { + // for streaming queries + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + // for batch queries + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + } } } test("GroupState - setTimeout**** with ProcessingTimeTimeout") { - implicit var state: GroupStateImpl[Int] = null - - state = new GroupStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + // for streaming queries + var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( + None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) state.setTimeoutDuration(500) - assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state + assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) state.update(5) - assert(state.getTimeoutTimestamp === 1500) // does not change + assert(state.getTimeoutTimestamp === 1500) // does not change state.setTimeoutDuration(1000) assert(state.getTimeoutTimestamp === 2000) 
state.setTimeoutDuration("2 second") @@ -125,22 +135,38 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) state.remove() - assert(state.getTimeoutTimestamp === 3000) // does not change - state.setTimeoutDuration(500) // can still be set + assert(state.getTimeoutTimestamp === 3000) // does not change + state.setTimeoutDuration(500) // can still be set assert(state.getTimeoutTimestamp === 1500) testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + // for batch queries + state = GroupStateImpl.createForBatch(ProcessingTimeTimeout).asInstanceOf[GroupStateImpl[Int]] + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + state.setTimeoutDuration(500) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + state.update(5) + state.setTimeoutDuration(1000) + state.setTimeoutDuration("2 second") + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + state.remove() + state.setTimeoutDuration(500) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) } test("GroupState - setTimeout**** with EventTimeTimeout") { - implicit val state = new GroupStateImpl[Int]( - None, 1000, 1000, EventTimeTimeout, hasTimedOut = false) + var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( + None, 1000, 1000, EventTimeTimeout, false) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) state.setTimeoutTimestamp(5000) - assert(state.getTimeoutTimestamp === 5000) // can be set without initializing state + assert(state.getTimeoutTimestamp === 5000) // can be set without initializing state state.update(5) - assert(state.getTimeoutTimestamp === 5000) // does not change + assert(state.getTimeoutTimestamp === 5000) // does not change state.setTimeoutTimestamp(10000) assert(state.getTimeoutTimestamp === 10000) state.setTimeoutTimestamp(new 
Date(20000)) @@ -150,7 +176,22 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf state.remove() assert(state.getTimeoutTimestamp === 20000) state.setTimeoutTimestamp(5000) - assert(state.getTimeoutTimestamp === 5000) // can be set after removing state + assert(state.getTimeoutTimestamp === 5000) // can be set after removing state + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + + // for batch queries + state = GroupStateImpl.createForBatch(EventTimeTimeout).asInstanceOf[GroupStateImpl[Int]] + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + state.setTimeoutTimestamp(5000) + + state.update(5) + state.setTimeoutTimestamp(10000) + state.setTimeoutTimestamp(new Date(20000)) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + + state.remove() + state.setTimeoutTimestamp(5000) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) } @@ -165,7 +206,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf assert(state.getTimeoutTimestamp === NO_TIMESTAMP) } - state = new GroupStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + state = GroupStateImpl.createForStreaming( + Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) testIllegalTimeout { state.setTimeoutDuration(-1000) } @@ -182,7 +224,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf state.setTimeoutDuration("1 month -1 day") } - state = new GroupStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) + state = GroupStateImpl.createForStreaming( + Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) testIllegalTimeout { state.setTimeoutTimestamp(-10000) } @@ -211,23 +254,32 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf test("GroupState - hasTimedOut") { for (timeoutConf <- Seq(NoTimeout, 
ProcessingTimeTimeout, EventTimeTimeout)) { + // for streaming queries for (initState <- Seq(None, Some(5))) { - val state1 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = false) + val state1 = GroupStateImpl.createForStreaming( + initState, 1000, 1000, timeoutConf, hasTimedOut = false) assert(state1.hasTimedOut === false) - val state2 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = true) + + val state2 = GroupStateImpl.createForStreaming( + initState, 1000, 1000, timeoutConf, hasTimedOut = true) assert(state2.hasTimedOut === true) } + + // for batch queries + assert(GroupStateImpl.createForBatch(timeoutConf).hasTimedOut === false) } } test("GroupState - primitive type") { - var intState = new GroupStateImpl[Int](None) + var intState = GroupStateImpl.createForStreaming[Int]( + None, 1000, 1000, NoTimeout, hasTimedOut = false) intercept[NoSuchElementException] { intState.get } assert(intState.getOption === None) - intState = new GroupStateImpl[Int](Some(10)) + intState = GroupStateImpl.createForStreaming[Int]( + Some(10), 1000, 1000, NoTimeout, hasTimedOut = false) assert(intState.get == 10) intState.update(0) assert(intState.get == 0) @@ -243,7 +295,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val beforeTimeoutThreshold = 999 val afterTimeoutThreshold = 1001 - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout for (priorState <- Seq(None, Some(0))) { val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" @@ -748,15 +799,21 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } test("mapGroupsWithState - batch") { - val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + // Test the following + // - no initial state + // - timeouts operations work, does not throw any error [SPARK-20792] + // - works with primitive state type + val stateFunc = (key: 
String, values: Iterator[String], state: GroupState[Int]) => { if (state.exists) throw new IllegalArgumentException("state.exists should be false") + state.setTimeoutTimestamp(0, "1 hour") + state.update(10) (key, values.size) } checkAnswer( spark.createDataset(Seq("a", "a", "b")) .groupByKey(x => x) - .mapGroupsWithState(stateFunc) + .mapGroupsWithState(EventTimeTimeout)(stateFunc) .toDF, spark.createDataset(Seq(("a", 2), ("b", 1))).toDF) }