From c1a0cd665319144ef116d6fe314ad09d0aa21323 Mon Sep 17 00:00:00 2001 From: Arun Mahadevan Date: Wed, 16 Jan 2019 17:40:12 -0800 Subject: [PATCH 1/4] Support multiple aggregates in append mode This patch proposes to add support for multiple aggregates in append mode. In append mode, the aggregates are emitted only after the watermark passes the threshold (e.g. the window boundary) and the emitted value is not affected by further late data. This allows to chain multiple aggregates in 'Append' output mode without worrying about retractions etc. However the current event time watermarks in structured streaming are tracked at a global level and this does not work when aggregates are chained. The downstream watermarks usually lags the ones before and the global (min or max) watermarks will not let the stages make progress independently. The patch tracks the watermarks at each (stateful) operator so that the aggregate outputs are generated when the watermark passes the thresholds at the corresponding stateful operator. The values are also saved into the commit/offset logs (similar to global watermark) Each aggregate should have a corresponding watermark defined while creating the query (E.g. via withWatermark) and this is used to track the progress of event time corresponding to the stateful operator. --- .../UnsupportedOperationChecker.scala | 28 +++--- .../analysis/UnsupportedOperationsSuite.scala | 34 +++++++- .../sql/execution/streaming/CommitLog.scala | 3 +- .../FlatMapGroupsWithStateExec.scala | 2 +- .../streaming/IncrementalExecution.scala | 29 ++++--- .../streaming/MicroBatchExecution.scala | 11 ++- .../sql/execution/streaming/OffsetSeq.scala | 9 +- .../execution/streaming/StreamExecution.scala | 6 +- .../StreamingSymmetricHashJoinExec.scala | 2 +- .../streaming/WatermarkTracker.scala | 87 ++++++++++++++++--- .../streaming/statefulOperators.scala | 11 ++- .../streaming/OffsetSeqLogSuite.scala | 8 +- .../streaming/EventTimeWatermarkSuite.scala | 72 +++++++++++++++ 13 files changed, 249 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 41ba6d34b5499..4ca1084762254 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -84,10 +84,10 @@ object UnsupportedOperationChecker { " or the output mode is not append on a streaming DataFrames/Datasets")(plan) } - // Disallow multiple streaming aggregations val aggregates = collectStreamingAggregates(plan) - if (aggregates.size > 1) { + // multiple aggregates are supported only in append mode + if (outputMode != InternalOutputModes.Append && aggregates.size > 1) { throwError( "Multiple streaming aggregations are not supported with " + "streaming DataFrames/Datasets")(plan) @@ -96,20 +96,20 @@ object UnsupportedOperationChecker { // Disallow some output mode outputMode match { case InternalOutputModes.Append if aggregates.nonEmpty => - val aggregate = aggregates.head - - // Find any attributes that are associated with an eventTime watermark. - val watermarkAttributes = aggregate.groupingExpressions.collect { - case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a - } + aggregates.foreach(aggregate => { + // Find any attributes that are associated with an eventTime watermark. + val watermarkAttributes = aggregate.groupingExpressions.collect { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a + } - // We can append rows to the sink once the group is under the watermark. Without this - // watermark a group is never "finished" so we would never output anything. - if (watermarkAttributes.isEmpty) { - throwError( - s"$outputMode output mode not supported when there are streaming aggregations on " + + // We can append rows to the sink once the group is under the watermark. Without this + // watermark a group is never "finished" so we would never output anything. + if (watermarkAttributes.isEmpty) { + throwError( + s"$outputMode output mode not supported when there are streaming aggregations on " + s"streaming DataFrames/DataSets without watermark")(plan) - } + } + }) case InternalOutputModes.Complete if aggregates.isEmpty => throwError( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 28a164b5d0cad..632deb4a56a7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -101,11 +101,17 @@ class UnsupportedOperationsSuite extends SparkFunSuite { Update) assertNotSupportedInStreamingPlan( - "aggregate - multiple streaming aggregations", + "aggregate - multiple streaming aggregations in update mode", Aggregate(Nil, aggExprs("c"), Aggregate(Nil, aggExprs("d"), streamRelation)), outputMode = Update, expectedMsgs = Seq("multiple streaming aggregations")) + assertNotSupportedInStreamingPlan( + "aggregate - multiple streaming aggregations in complete mode", + Aggregate(Nil, aggExprs("c"), Aggregate(Nil, aggExprs("d"), streamRelation)), + outputMode = Complete, + expectedMsgs = Seq("multiple streaming aggregations")) + assertSupportedInStreamingPlan( "aggregate - streaming aggregations in update mode", Aggregate(Nil, aggExprs("d"), streamRelation), @@ -127,6 +133,32 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Append, expectedMsgs = Seq("streaming aggregations", "without watermark")) + assertSupportedInStreamingPlan( + "aggregate - multiple streaming aggregations in append mode with watermark", + Aggregate(Seq(attributeWithWatermark), aggExprs("c"), + Aggregate(Seq(attributeWithWatermark), aggExprs("d"), streamRelation)), + outputMode = Append) + + assertNotSupportedInStreamingPlan( + "aggregate - multiple streaming aggregations without watermark in append mode", + Aggregate(Nil, aggExprs("c"), Aggregate(Nil, aggExprs("d"), streamRelation)), + outputMode = Append, + expectedMsgs = Seq("streaming aggregations", "without watermark")) + + assertNotSupportedInStreamingPlan( + "aggregate - multiple streaming aggregations, watermark for second aggregate in append mode", + Aggregate(Nil, aggExprs("c"), + Aggregate(Seq(attributeWithWatermark), aggExprs("d"), streamRelation)), + outputMode = Append, + expectedMsgs = Seq("streaming aggregations", "without watermark")) + + assertNotSupportedInStreamingPlan( + "aggregate - multiple streaming aggregations, watermark for first aggregate in append mode", + Aggregate(Seq(attributeWithWatermark), aggExprs("c"), + Aggregate(Nil, aggExprs("d"), streamRelation)), + outputMode = Append, + expectedMsgs = Seq("streaming aggregations", "without watermark")) + // Aggregation: Distinct aggregates not supported on streaming relation val distinctAggExprs = Seq(Count("*").toAggregateExpression(isDistinct = true).as("c")) assertSupportedInStreamingPlan( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala index 0063318db332d..1d94eecefcdc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala @@ -77,7 +77,8 @@ object CommitLog { } -case class CommitMetadata(nextBatchWatermarkMs: Long = 0) { +case class CommitMetadata(nextBatchWatermarkMs: Long = 0, + operatorWatermarks: Map[Long, Long] = Map.empty) { def json: String = Serialization.write(this)(CommitMetadata.format) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index bfe7d00f56048..eef7fbe11a230 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -85,7 +85,7 @@ case class FlatMapGroupsWithStateExec( true // Always run batches to process timeouts case EventTimeTimeout => // Process another non-data batch only if the watermark has changed in this executed plan - eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + eventTimeWatermark.isDefined && getWatermark(newMetadata) > eventTimeWatermark.get case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index af52af0d1d7e6..58d8e1fd384f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -101,6 +101,12 @@ class IncrementalExecution( numStateStores) } + // the watermark for the stateful operation if available or the batch watermark + private def getWatermark(operatorId: Long): Long = { + offsetSeqMetadata.operatorWatermarks + .getOrElse(operatorId, offsetSeqMetadata.batchWatermarkMs) + } + /** Locates save/restore pairs surrounding aggregation. */ val state = new Rule[SparkPlan] { @@ -108,12 +114,12 @@ class IncrementalExecution( case StateStoreSaveExec(keys, None, None, None, stateFormatVersion, UnaryExecNode(agg, StateStoreRestoreExec(_, None, _, child))) => - val aggStateInfo = nextStatefulOperationStateInfo + val aggStateInfo = nextStatefulOperationStateInfo() StateStoreSaveExec( keys, Some(aggStateInfo), Some(outputMode), - Some(offsetSeqMetadata.batchWatermarkMs), + Some(getWatermark(aggStateInfo.operatorId)), stateFormatVersion, agg.withNewChildren( StateStoreRestoreExec( @@ -123,30 +129,33 @@ class IncrementalExecution( child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None) => + val stateInfo = nextStatefulOperationStateInfo() StreamingDeduplicateExec( keys, child, - Some(nextStatefulOperationStateInfo), - Some(offsetSeqMetadata.batchWatermarkMs)) + Some(stateInfo), + Some(getWatermark(stateInfo.operatorId))) case m: FlatMapGroupsWithStateExec => + val stateInfo = nextStatefulOperationStateInfo() m.copy( - stateInfo = Some(nextStatefulOperationStateInfo), + stateInfo = Some(stateInfo), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), - eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs)) + eventTimeWatermark = Some(getWatermark(stateInfo.operatorId))) case j: StreamingSymmetricHashJoinExec => + val stateInfo = nextStatefulOperationStateInfo() j.copy( - stateInfo = Some(nextStatefulOperationStateInfo), - eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs), + stateInfo = Some(stateInfo), + eventTimeWatermark = Some(getWatermark(stateInfo.operatorId)), stateWatermarkPredicates = StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, - Some(offsetSeqMetadata.batchWatermarkMs))) + Some(getWatermark(stateInfo.operatorId)))) case l: StreamingGlobalLimitExec => l.copy( - stateInfo = Some(nextStatefulOperationStateInfo), + stateInfo = Some(nextStatefulOperationStateInfo()), outputMode = Some(outputMode)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 38ecb0dd12daa..78d2caf6f945d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -268,9 +268,11 @@ class MicroBatchExecution( nextOffsets.metadata.foreach { metadata => OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) offsetSeqMetadata = OffsetSeqMetadata( - metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) + metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf, + metadata.operatorWatermarks) watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) watermarkTracker.setWatermark(metadata.batchWatermarkMs) + watermarkTracker.setOperatorWatermarks(metadata.operatorWatermarks) } /* identify the current batch id: if commit log indicates we successfully processed the @@ -297,6 +299,7 @@ class MicroBatchExecution( committedOffsets ++= availableOffsets watermarkTracker.setWatermark( math.max(watermarkTracker.currentWatermark, commitMetadata.nextBatchWatermarkMs)) + watermarkTracker.setOperatorWatermarks(commitMetadata.operatorWatermarks) } else if (latestCommittedBatchId < latestBatchId - 1) { logWarning(s"Batch completion log latest batch id is " + s"${latestCommittedBatchId}, which is not trailing " + @@ -369,7 +372,8 @@ class MicroBatchExecution( // Update the query metadata offsetSeqMetadata = offsetSeqMetadata.copy( batchWatermarkMs = watermarkTracker.currentWatermark, - batchTimestampMs = triggerClock.getTimeMillis()) + batchTimestampMs = triggerClock.getTimeMillis(), + operatorWatermarks = watermarkTracker.currentOperatorWatermarks) // Check whether next batch should be constructed val lastExecutionRequiresAnotherBatch = noDataBatchesEnabled && @@ -559,7 +563,8 @@ class MicroBatchExecution( withProgressLocked { sinkCommitProgress = batchSinkProgress watermarkTracker.updateWatermark(lastExecution.executedPlan) - commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark)) + commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark, + watermarkTracker.currentOperatorWatermarks)) committedOffsets ++= availableOffsets } logDebug(s"Completed batch ${currentBatchId}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 73cf355dbe758..8645f7ff0579e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -77,11 +77,13 @@ object OffsetSeq { * @param batchTimestampMs: The current batch processing timestamp. * Time unit: milliseconds * @param conf: Additional conf_s to be persisted across batches, e.g. number of shuffle partitions. + * @param operatorWatermarks: the watermarks at the individual stateful operators. */ case class OffsetSeqMetadata( batchWatermarkMs: Long = 0, batchTimestampMs: Long = 0, - conf: Map[String, String] = Map.empty) { + conf: Map[String, String] = Map.empty, + operatorWatermarks: Map[Long, Long] = Map.empty) { def json: String = Serialization.write(this)(OffsetSeqMetadata.format) } @@ -114,9 +116,10 @@ object OffsetSeqMetadata extends Logging { def apply( batchWatermarkMs: Long, batchTimestampMs: Long, - sessionConf: RuntimeConfig): OffsetSeqMetadata = { + sessionConf: RuntimeConfig, + operatorWatermarks: Map[Long, Long]): OffsetSeqMetadata = { val confs = relevantSQLConfs.map { conf => conf.key -> sessionConf.get(conf.key) }.toMap - OffsetSeqMetadata(batchWatermarkMs, batchTimestampMs, confs) + OffsetSeqMetadata(batchWatermarkMs, batchTimestampMs, confs, operatorWatermarks) } /** Set the SparkSession configuration with the values in the metadata */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 83824f40ab90b..5d46fa9e18fb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -134,7 +134,8 @@ abstract class StreamExecution( /** Metadata associated with the offset seq of a batch in the query. */ protected var offsetSeqMetadata = OffsetSeqMetadata( - batchWatermarkMs = 0, batchTimestampMs = 0, sparkSession.conf) + batchWatermarkMs = 0, batchTimestampMs = 0, sessionConf = sparkSession.conf, + operatorWatermarks = Map.empty) /** * A map of current watermarks, keyed by the position of the watermark operator in the @@ -277,7 +278,8 @@ abstract class StreamExecution( // Disable cost-based join optimization as we do not want stateful operations to be rearranged sparkSessionForStream.conf.set(SQLConf.CBO_ENABLED.key, "false") offsetSeqMetadata = OffsetSeqMetadata( - batchWatermarkMs = 0, batchTimestampMs = 0, sparkSessionForStream.conf) + batchWatermarkMs = 0, batchTimestampMs = 0, sessionConf = sparkSessionForStream.conf, + operatorWatermarks = Map.empty) if (state.compareAndSet(INITIALIZING, ACTIVE)) { // Unblock `awaitInitialization` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 50cf971e4ec3c..b82a0905bfa33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -193,7 +193,7 @@ case class StreamingSymmetricHashJoinExec( // Latest watermark value is more than that used in this previous executed plan val watermarkHasChanged = - eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + eventTimeWatermark.isDefined && getWatermark(newMetadata) > eventTimeWatermark.get watermarkUsedForStateCleanup && watermarkHasChanged } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala index 76ab1284633b1..8a5770d7b07be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.internal.SQLConf /** @@ -80,34 +80,84 @@ case object MaxWatermark extends MultipleWatermarkPolicy { /** Tracks the watermark value of a streaming query based on a given `policy` */ case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { private val operatorToWatermarkMap = mutable.HashMap[Int, Long]() + private val statefulOperatorToWatermark = mutable.HashMap[Long, Long]() + private val statefulOperatorToEventTimeMap = mutable.HashMap[Long, mutable.HashMap[Int, Long]]() + private var globalWatermarkMs: Long = 0 + private def updateWaterMarkMap(eventTimeExecs: Seq[EventTimeWatermarkExec], + map: mutable.HashMap[Int, Long]): Unit = { + eventTimeExecs.zipWithIndex.foreach { + case (e, index) if e.eventTimeStats.value.count > 0 => + logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") + val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs + val prevWatermarkMs = map.get(index) + if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { + map.put(index, newWatermarkMs) + } + + // Populate 0 if we haven't seen any data yet for this watermark node. + case (_, index) => + if (!map.isDefinedAt(index)) { + map.put(index, 0) + } + } + } + def setWatermark(newWatermarkMs: Long): Unit = synchronized { globalWatermarkMs = newWatermarkMs } + def setOperatorWatermarks(operatorWatermarks: Map[Long, Long]): Unit = synchronized { + statefulOperatorToWatermark ++= operatorWatermarks + } + def updateWatermark(executedPlan: SparkPlan): Unit = synchronized { val watermarkOperators = executedPlan.collect { case e: EventTimeWatermarkExec => e } if (watermarkOperators.isEmpty) return - watermarkOperators.zipWithIndex.foreach { - case (e, index) if e.eventTimeStats.value.count > 0 => - logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") - val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs - val prevWatermarkMs = operatorToWatermarkMap.get(index) - if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { - operatorToWatermarkMap.put(index, newWatermarkMs) - } + updateWaterMarkMap(watermarkOperators, operatorToWatermarkMap) - // Populate 0 if we haven't seen any data yet for this watermark node. - case (_, index) => - if (!operatorToWatermarkMap.isDefinedAt(index)) { - operatorToWatermarkMap.put(index, 0) - } + // compute the per stateful operator watermark + val statefulOperators = executedPlan.collect { + case s: StatefulOperator => s } + statefulOperators.foreach(statefulOperator => { + // find the first event time child node(s) + val eventTimeExecs = statefulOperator match { + case op: UnaryExecNode => + op.collectFirst { + case e: EventTimeWatermarkExec => e + }.map(Seq(_)).getOrElse(Seq()) + case op: BinaryExecNode => + val left = op.left.collectFirst { + case e: EventTimeWatermarkExec => e + }.map(Seq(_)).getOrElse(Seq()) + val right = op.right.collectFirst { + case e: EventTimeWatermarkExec => e + }.map(Seq(_)).getOrElse(Seq()) + left ++ right + } + + // compute watermark for the stateful operator node + statefulOperator.stateInfo.foreach(state => { + if (eventTimeExecs.nonEmpty) { + updateWaterMarkMap(eventTimeExecs, + statefulOperatorToEventTimeMap.getOrElseUpdate(state.operatorId, + new mutable.HashMap[Int, Long]())) + val newWatermarkMs = statefulOperatorToEventTimeMap(state.operatorId).values.toSeq.min + val prevWatermarkMs = statefulOperatorToWatermark.get(state.operatorId) + if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { + statefulOperatorToWatermark.put(state.operatorId, newWatermarkMs) + } + } + }) + }) + + // Update the global watermark to the minimum of all watermark nodes. // This is the safest option, because only the global watermark is fault-tolerant. Making // it the minimum of all individual watermarks guarantees it will never advance past where @@ -121,7 +171,16 @@ case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { } } + def statefulOperatorWatermark(id: Long): Option[Long] = synchronized { + statefulOperatorToWatermark.get(id) + } + def currentWatermark: Long = synchronized { globalWatermarkMs } + + def currentOperatorWatermarks: Map[Long, Long] = synchronized { + statefulOperatorToWatermark.toMap + } + } object WatermarkTracker { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index c11af345b0248..2a3b8a6e39a21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -120,6 +120,13 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => } } + /** + * Gets watermark associated with the state operator + */ + protected def getWatermark(metadata: OffsetSeqMetadata): Long = stateInfo + .flatMap(s => metadata.operatorWatermarks.get(s.operatorId)) + .getOrElse(metadata.batchWatermarkMs) + private def stateStoreCustomMetrics: Map[String, SQLMetric] = { val provider = StateStoreProvider.create(sqlContext.conf.stateStoreProviderClass) provider.supportedCustomMetrics.map { @@ -420,7 +427,7 @@ case class StateStoreSaveExec( override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { (outputMode.contains(Append) || outputMode.contains(Update)) && eventTimeWatermark.isDefined && - newMetadata.batchWatermarkMs > eventTimeWatermark.get + getWatermark(newMetadata) > eventTimeWatermark.get } } @@ -490,7 +497,7 @@ case class StreamingDeduplicateExec( override def outputPartitioning: Partitioning = child.outputPartitioning override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { - eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + eventTimeWatermark.isDefined && getWatermark(newMetadata) > eventTimeWatermark.get } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index e6cdc063c4e9f..13a4ce30a4556 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -55,10 +55,16 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { assert(OffsetSeqMetadata(0, 2, getConfWith(shufflePartitions = 3)) === OffsetSeqMetadata(s"""{"batchTimestampMs":2,"conf": {"$key":3}}""")) - // All set + // Three set assert(OffsetSeqMetadata(1, 2, getConfWith(shufflePartitions = 3)) === OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"batchTimestampMs":2,"conf": {"$key":3}}""")) + // All set + assert(OffsetSeqMetadata(1, 2, getConfWith(shufflePartitions = 3), Map(0L -> 1000L)) === + OffsetSeqMetadata( + s"""{"batchWatermarkMs":1,"batchTimestampMs":2,"conf": {"$key":3}, + |"operatorWatermarks": {"0": 1000}}""".stripMargin)) + // Drop unknown fields assert(OffsetSeqMetadata(1, 2, getConfWith(shufflePartitions = 3)) === OffsetSeqMetadata( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index c696204cecc2c..ffa55bd99bc7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -302,6 +302,71 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche ) } + test("multiple aggregates in append mode") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("inputtime", $"value".cast("timestamp")) + .withWatermark("inputtime", "10 seconds") + .groupBy(window($"inputtime", "5 seconds") as 'window1, $"inputtime").count() + .select($"window1.end".as("windowtime"), $"count".as("num")) + .withWatermark("windowtime", "5 seconds") + .groupBy(window($"windowtime", "5 seconds") as 'window2, $"num").count() + .select($"window2.start".cast("long").as[Long], $"num", $"count") + + testStream(windowedAggregation)( + AddData(inputData, 10, 11, 11, 12, 12), + CheckNewAnswer(), + AddData(inputData, 25), // watermark -> group1 = 15, group2 = 10 + CheckNewAnswer(), + assertNumTotalStateRows(3), + AddData(inputData, 26, 26, 27), + CheckNewAnswer(), + AddData(inputData, 40), // watermark -> group1 = 30 , group2 = 25 + CheckNewAnswer((15, 1, 1), (15, 2, 2)) + ) + } + + test("multiple aggregates in append mode recovery") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("inputtime", $"value".cast("timestamp")) + .withWatermark("inputtime", "10 seconds") + .groupBy(window($"inputtime", "5 seconds") as 'window1, $"inputtime").count() + .select($"window1.end".as("windowtime"), $"count".as("num")) + .withWatermark("windowtime", "5 seconds") + .groupBy(window($"windowtime", "5 seconds") as 'window2, $"num").count() + .select($"window2.start".cast("long").as[Long], $"num", $"count") + + testStream(windowedAggregation)( + AddData(inputData, 10, 11, 11, 12, 12), + CheckNewAnswer(), + AddData(inputData, 25), // watermark -> group1 = 15, group2 = 10 + // window1 [10-15] (10, 1) (11, 2), (12, 2) + // window2 [15-20] (1, 1), (2, 2) + CheckNewAnswer(), + AddData(inputData, 26, 26, 27), + CheckNewAnswer(), + AddData(inputData, 40), // watermark -> group1 = 30 , group2 = 25 + // window1 [25-30] (25, 1), (26, 2), (27, 1) + // window2 [30-35] (1, 2), (2, 1) + CheckNewAnswer((15, 1, 1), (15, 2, 2)), + StopStream, + AssertOnQuery { q => // purge commit and clear the sink + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + q.commitLog.purge(commit) + q.sink.asInstanceOf[MemorySink].clear() + true + }, + StartStream(), + AddData(inputData, 55), // watermark -> group1 = 45, group2 = 40 + // window1 -> [40-45] (40, 1) + // window2 -> [45-50] (1, 1) + CheckAnswer((30, 1, 2), (30, 2, 1)) + ) + } + test("delay in months and years handled correctly") { val currentTimeMs = System.currentTimeMillis val currentTime = new Date(currentTimeMs) @@ -727,6 +792,13 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche true } + private def assertNumTotalStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => + q.processAllAvailable() + val progressWithData = q.recentProgress.lastOption.get + assert(progressWithData.stateOperators.map(_.numRowsTotal).sum === numTotalRows) + true + } + /** Assert event stats generated on that last batch with data in it */ private def assertEventStats(body: ju.Map[String, String] => Unit): AssertOnQuery = { Execute("AssertEventStats") { q => From 3dc918c23a43074b260afe3fd6125f049820b60b Mon Sep 17 00:00:00 2001 From: Arun Mahadevan Date: Fri, 18 Jan 2019 15:03:31 -0800 Subject: [PATCH 2/4] Address review comments --- .../StreamingSymmetricHashJoinExec.scala | 2 +- .../streaming/EventTimeWatermarkSuite.scala | 135 +++++++++++++----- 2 files changed, 99 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index b82a0905bfa33..4c40f7fe1dd58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -193,7 +193,7 @@ case class StreamingSymmetricHashJoinExec( // Latest watermark value is more than that used in this previous executed plan val watermarkHasChanged = - eventTimeWatermark.isDefined && getWatermark(newMetadata) > eventTimeWatermark.get + eventTimeWatermark.isDefined && getWatermark(newMetadata) > eventTimeWatermark.get watermarkUsedForStateCleanup && watermarkHasChanged } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index ffa55bd99bc7b..e38ba002859d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -303,8 +303,47 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche } test("multiple aggregates in append mode") { + val inputData = MemoryStream[(Int, String)] + + // compute per-user event counts over 5 sec windows + // and sum this to global counts over 30 second windows + val windowedAggregation = inputData.toDS() + .select($"_1".cast("timestamp").as("inputtime"), $"_2".as("user")) + .withWatermark("inputtime", "2 seconds") + .groupBy(window($"inputtime", "5 seconds") as 'window1, $"user").count() + .select($"window1.end".as("windowtime"), $"count") + .withWatermark("windowtime", "3 seconds") + .groupBy(window($"windowtime", "30 seconds") as 'window2).sum("count") + .select($"window2.start".cast("long").as[Long], + $"window2.end".cast("long").as[Long], $"sum(count)") + + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + testStream(windowedAggregation)( + AddData(inputData, (1, "A"), (2, "B"), (4, "A"), (1, "B"), (3, "B")), + CheckNewAnswer(), + AddData(inputData, (8, "A")), // window1 [0 - 5] -> (A, 2), (B, 3) emitted down + // window1 [5 - 10] -> (A, 1) retained in state1 + // window2 [0 - 30] -> 5 is retained in state2 + CheckNewAnswer(), + AddData(inputData, (9, "A"), (9, "B"), (34, "B")), + // window1 [5 - 10] -> (A, 2), (B, 1) emitted down + // window1 [30 - 35] -> (B, 1) retained in state1 + // window2 [0 - 30] -> 8 retained in state2 + CheckNewAnswer(), + AddData(inputData, (39, "B")), // window1 [30 - 35] -> (B, 1) is emitted down + // window1 [35 - 40] -> (B, 1) is retained in state1 + // window2 [0 - 30] -> 8 is emitted out + // window2 [30 - 60] -> 1 is retained in state2 + CheckNewAnswer((0, 30, 8)) + ) + } + } + + test("multiple aggregates in append mode with groups in second window") { val inputData = MemoryStream[Int] + // compute a count of timestamps over 5 sec windows + // and count of counts over 5 sec windows in the second aggregate val windowedAggregation = inputData.toDF() .withColumn("inputtime", $"value".cast("timestamp")) .withWatermark("inputtime", "10 seconds") @@ -314,22 +353,36 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche .groupBy(window($"windowtime", "5 seconds") as 'window2, $"num").count() .select($"window2.start".cast("long").as[Long], $"num", $"count") - testStream(windowedAggregation)( - AddData(inputData, 10, 11, 11, 12, 12), - CheckNewAnswer(), - AddData(inputData, 25), // watermark -> group1 = 15, group2 = 10 - CheckNewAnswer(), - assertNumTotalStateRows(3), - AddData(inputData, 26, 26, 27), - CheckNewAnswer(), - AddData(inputData, 40), // watermark -> group1 = 30 , group2 = 25 - CheckNewAnswer((15, 1, 1), (15, 2, 2)) - ) + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + testStream(windowedAggregation)( + AddData(inputData, 10, 11, 11, 12, 12), + CheckNewAnswer(), + AddData(inputData, 25), // watermark -> group1 = 15, group2 = 10 + // window1 [10 - 15] -> (10,1) (11,2), (12,2) is emitted + // downstream since watermark of group1 is at 15 + // window1 [25 - 30] -> (25,1) is retained in state1 + // window2 [15 - 20] -> (1,1), (2,2) is retained in state2 + // since watermark of group2 is at 10 + CheckNewAnswer(), + assertNumTotalStateRows(3), // {[25-30],25} -> 1 in state1 and + // {[30-35],1} -> 1, {[30-35],1} -> 1 {[30-35],2} -> 2 in state2 + AddData(inputData, 26, 26, 27), + CheckNewAnswer(), + AddData(inputData, 40), // watermark -> group1 = 30 , group2 = 25 + // window1 [25 - 30] -> (25,1), (26,2), (27,1) is emitted down + // window1 [40 - 45] -> (40, 1) is retained in state1 + // window2 [15 - 20] -> (1,1), (2,2) is now emitted out + // window2 [30 - 35] -> (1,2), (2,1) is retained in state2 + CheckNewAnswer((15, 1, 1), (15, 2, 2)) + ) + } } test("multiple aggregates in append mode recovery") { val inputData = MemoryStream[Int] + // compute a count of timestamps over 5 sec windows + // and count of counts over 5 sec windows in the second aggregate val windowedAggregation = inputData.toDF() .withColumn("inputtime", $"value".cast("timestamp")) .withWatermark("inputtime", "10 seconds") @@ -339,32 +392,40 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche .groupBy(window($"windowtime", "5 seconds") as 'window2, $"num").count() .select($"window2.start".cast("long").as[Long], $"num", $"count") - testStream(windowedAggregation)( - AddData(inputData, 10, 11, 11, 12, 12), - CheckNewAnswer(), - AddData(inputData, 25), // watermark -> group1 = 15, group2 = 10 - // window1 [10-15] (10, 1) (11, 2), (12, 2) - // window2 [15-20] (1, 1), (2, 2) - CheckNewAnswer(), - AddData(inputData, 26, 26, 27), - CheckNewAnswer(), - AddData(inputData, 40), // watermark -> group1 = 30 , group2 = 25 - // window1 [25-30] (25, 1), (26, 2), (27, 1) - // window2 [30-35] (1, 2), (2, 1) - CheckNewAnswer((15, 1, 1), (15, 2, 2)), - StopStream, - AssertOnQuery { q => // purge commit and clear the sink - val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) - q.commitLog.purge(commit) - q.sink.asInstanceOf[MemorySink].clear() - true - }, - StartStream(), - AddData(inputData, 55), // watermark -> group1 = 45, group2 = 40 - // window1 -> [40-45] (40, 1) - // window2 -> [45-50] (1, 1) - CheckAnswer((30, 1, 2), (30, 2, 1)) - ) + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + testStream(windowedAggregation)( + AddData(inputData, 10, 11, 11, 12, 12), + CheckNewAnswer(), + AddData(inputData, 25), // watermark -> group1 = 15, group2 = 10 + // window1 [10 - 15] -> (10,1) (11,2), (12,2) is emitted + // downstream since watermark of group1 is at 15 + // window1 [25 - 30] -> (25,1) is retained in state1 + // window2 [15 - 20] -> (1,1), (2,2) is retained in state2 + CheckNewAnswer(), + AddData(inputData, 26, 26, 27), + CheckNewAnswer(), + AddData(inputData, 40), // watermark -> group1 = 30 , group2 = 25 + // window1 [25 - 30] -> (25,1), (26,2), (27,1) is emitted down + // window1 [40 - 45] -> (40, 1) is retained in state1 + // window2 [15 - 20] -> (1,1), (2,2) is now emitted out + // window2 [30 - 35] -> (1,2), (2,1) is retained in state2 + CheckNewAnswer((15, 1, 1), (15, 2, 2)), + StopStream, + AssertOnQuery { q => // purge commit and clear the sink + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + q.commitLog.purge(commit) + q.sink.asInstanceOf[MemorySink].clear() + true + }, + StartStream(), + AddData(inputData, 55), // watermark -> group1 = 45, group2 = 40 + // window1 [40 - 45] -> (40,1) is emitted down + // window1 [55 - 60] -> (55, 1) is retained in state1 + // window2 [30 - 35] -> (1,2), (2,1) is emitted out + // window2 [40 - 45] -> (40, 1) is retained in state2 + CheckAnswer((30, 1, 2), (30, 2, 1)) + ) + } } test("delay in months and years handled correctly") { From 91046cd5190ce5afaa8776be3d649b933df3acd9 Mon Sep 17 00:00:00 2001 From: Arun Mahadevan Date: Thu, 24 Jan 2019 15:21:37 -0800 Subject: [PATCH 3/4] Review comments --- .../catalyst/analysis/UnsupportedOperationChecker.scala | 4 ++-- .../spark/sql/execution/streaming/WatermarkTracker.scala | 8 ++++---- .../spark/sql/streaming/EventTimeWatermarkSuite.scala | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 4ca1084762254..c2685625cf1d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -96,7 +96,7 @@ object UnsupportedOperationChecker { // Disallow some output mode outputMode match { case InternalOutputModes.Append if aggregates.nonEmpty => - aggregates.foreach(aggregate => { + aggregates.foreach { aggregate => // Find any attributes that are associated with an eventTime watermark. val watermarkAttributes = aggregate.groupingExpressions.collect { case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a @@ -109,7 +109,7 @@ object UnsupportedOperationChecker { s"$outputMode output mode not supported when there are streaming aggregations on " + s"streaming DataFrames/DataSets without watermark")(plan) } - }) + } case InternalOutputModes.Complete if aggregates.isEmpty => throwError( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala index 8a5770d7b07be..c4d07b1b0e68c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -125,7 +125,7 @@ case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { case s: StatefulOperator => s } - statefulOperators.foreach(statefulOperator => { + statefulOperators.foreach { statefulOperator => // find the first event time child node(s) val eventTimeExecs = statefulOperator match { case op: UnaryExecNode => @@ -143,7 +143,7 @@ case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { } // compute watermark for the stateful operator node - statefulOperator.stateInfo.foreach(state => { + statefulOperator.stateInfo.foreach { state => if (eventTimeExecs.nonEmpty) { updateWaterMarkMap(eventTimeExecs, statefulOperatorToEventTimeMap.getOrElseUpdate(state.operatorId, @@ -154,8 +154,8 @@ case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { statefulOperatorToWatermark.put(state.operatorId, newWatermarkMs) } } - }) - }) + } + } // Update the global watermark to the minimum of all watermark nodes. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index e38ba002859d9..27657e6db173a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -364,8 +364,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche // window2 [15 - 20] -> (1,1), (2,2) is retained in state2 // since watermark of group2 is at 10 CheckNewAnswer(), - assertNumTotalStateRows(3), // {[25-30],25} -> 1 in state1 and - // {[30-35],1} -> 1, {[30-35],1} -> 1 {[30-35],2} -> 2 in state2 + assertNumTotalStateRows(3), // [25 - 30] -> (25,1) in state1 and + // [15 - 20] -> (1,1), (2,2) in state2 AddData(inputData, 26, 26, 27), CheckNewAnswer(), AddData(inputData, 40), // watermark -> group1 = 30 , group2 = 25 @@ -422,7 +422,7 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche // window1 [40 - 45] -> (40,1) is emitted down // window1 [55 - 60] -> (55, 1) is retained in state1 // window2 [30 - 35] -> (1,2), (2,1) is emitted out - // window2 [40 - 45] -> (40, 1) is retained in state2 + // window2 [40 - 45] -> (1, 1) is retained in state2 CheckAnswer((30, 1, 2), (30, 2, 1)) ) } From 8579fb614317b4dd8a21ac9f50297cd864bcb20c Mon Sep 17 00:00:00 2001 From: Arun Mahadevan Date: Fri, 15 Feb 2019 16:21:58 -0800 Subject: [PATCH 4/4] Track downstream watermark based on upstream --- .../streaming/WatermarkTracker.scala | 107 ++++++++++-------- 1 file changed, 62 insertions(+), 45 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala index c4d07b1b0e68c..9651b4dc06d2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming import java.util.Locale +import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable @@ -85,22 +86,31 @@ case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { private var globalWatermarkMs: Long = 0 - private def updateWaterMarkMap(eventTimeExecs: Seq[EventTimeWatermarkExec], - map: mutable.HashMap[Int, Long]): Unit = { - eventTimeExecs.zipWithIndex.foreach { - case (e, index) if e.eventTimeStats.value.count > 0 => - logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") - val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs - val prevWatermarkMs = map.get(index) - if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { - map.put(index, newWatermarkMs) - } - + private def updateWaterMarkMap( + eventTimeExec: EventTimeWatermarkExec, + index: Int, + map: mutable.HashMap[Int, Long]): Option[Long] = { + if (eventTimeExec.eventTimeStats.value.count > 0) { + logDebug(s"Observed event time stats $index: ${eventTimeExec.eventTimeStats.value}") + val newWatermarkMs = eventTimeExec.eventTimeStats.value.max - eventTimeExec.delayMs + val prevWatermarkMs = map.get(index) + if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { + map.put(index, newWatermarkMs) + } + } else { // Populate 0 if we haven't seen any data yet for this watermark node. - case (_, index) => - if (!map.isDefinedAt(index)) { - map.put(index, 0) - } + if (!map.isDefinedAt(index)) { + map.put(index, 0) + } + } + map.get(index) + } + + private def updateWaterMarkMap( + eventTimeExecs: Seq[EventTimeWatermarkExec], + map: mutable.HashMap[Int, Long]): Unit = { + eventTimeExecs.zipWithIndex.foreach { + case (e, i) => updateWaterMarkMap(e, i, map) } } @@ -120,44 +130,51 @@ case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { updateWaterMarkMap(watermarkOperators, operatorToWatermarkMap) + // compute watermark of an operator node + def computeWatermark( + node: SparkPlan, + stateOperatorId: Long, + etId: AtomicInteger): Option[Long] = { + node match { + case ws: WatermarkSupport => + ws.eventTimeWatermark + case et: EventTimeWatermarkExec => + updateWaterMarkMap(et, etId.getAndIncrement(), + statefulOperatorToEventTimeMap.getOrElseUpdate(stateOperatorId, + new mutable.HashMap[Int, Long]())) + case other => + // min of available watermarks + val watermarks = other.children + .map(c => computeWatermark(c, stateOperatorId, etId)) + .filter(_.isDefined) + .map(_.get) + if (watermarks.isEmpty) None else Some(watermarks.min) + } + } + // compute the per stateful operator watermark val statefulOperators = executedPlan.collect { case s: StatefulOperator => s } - - statefulOperators.foreach { statefulOperator => - // find the first event time child node(s) - val eventTimeExecs = statefulOperator match { - case op: UnaryExecNode => - op.collectFirst { - case e: EventTimeWatermarkExec => e - }.map(Seq(_)).getOrElse(Seq()) - case op: BinaryExecNode => - val left = op.left.collectFirst { - case e: EventTimeWatermarkExec => e - }.map(Seq(_)).getOrElse(Seq()) - val right = op.right.collectFirst { - case e: EventTimeWatermarkExec => e - }.map(Seq(_)).getOrElse(Seq()) - left ++ right - } - - // compute watermark for the stateful operator node - statefulOperator.stateInfo.foreach { state => - if (eventTimeExecs.nonEmpty) { - updateWaterMarkMap(eventTimeExecs, - statefulOperatorToEventTimeMap.getOrElseUpdate(state.operatorId, - new mutable.HashMap[Int, Long]())) - val newWatermarkMs = statefulOperatorToEventTimeMap(state.operatorId).values.toSeq.min - val prevWatermarkMs = statefulOperatorToWatermark.get(state.operatorId) - if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { - statefulOperatorToWatermark.put(state.operatorId, newWatermarkMs) - } + statefulOperators.foreach { statefulOp => + statefulOp.stateInfo.foreach { stateInfo => + val newWatermarkMs = if (statefulOp.children.isEmpty) { + 0 + } else { + val etId = new AtomicInteger(0) + val watermarks = statefulOp.children + .map(c => computeWatermark(c, stateInfo.operatorId, etId)) + .filter(_.isDefined) + .map(_.get) + if (watermarks.isEmpty) 0 else watermarks.min + } + val prevWatermarkMs = statefulOperatorToWatermark.get(stateInfo.operatorId) + if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { + statefulOperatorToWatermark.put(stateInfo.operatorId, newWatermarkMs) } } } - // Update the global watermark to the minimum of all watermark nodes. // This is the safest option, because only the global watermark is fault-tolerant. Making // it the minimum of all individual watermarks guarantees it will never advance past where