diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 41ba6d34b5499..c2685625cf1d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -84,10 +84,10 @@ object UnsupportedOperationChecker { " or the output mode is not append on a streaming DataFrames/Datasets")(plan) } - // Disallow multiple streaming aggregations val aggregates = collectStreamingAggregates(plan) - if (aggregates.size > 1) { + // multiple aggregates are supported only in append mode + if (outputMode != InternalOutputModes.Append && aggregates.size > 1) { throwError( "Multiple streaming aggregations are not supported with " + "streaming DataFrames/Datasets")(plan) @@ -96,19 +96,19 @@ object UnsupportedOperationChecker { // Disallow some output mode outputMode match { case InternalOutputModes.Append if aggregates.nonEmpty => - val aggregate = aggregates.head - - // Find any attributes that are associated with an eventTime watermark. - val watermarkAttributes = aggregate.groupingExpressions.collect { - case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a - } + aggregates.foreach { aggregate => + // Find any attributes that are associated with an eventTime watermark. + val watermarkAttributes = aggregate.groupingExpressions.collect { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a + } - // We can append rows to the sink once the group is under the watermark. Without this - // watermark a group is never "finished" so we would never output anything. 
- if (watermarkAttributes.isEmpty) { - throwError( - s"$outputMode output mode not supported when there are streaming aggregations on " + + // We can append rows to the sink once the group is under the watermark. Without this + // watermark a group is never "finished" so we would never output anything. + if (watermarkAttributes.isEmpty) { + throwError( + s"$outputMode output mode not supported when there are streaming aggregations on " + s"streaming DataFrames/DataSets without watermark")(plan) + } } case InternalOutputModes.Complete if aggregates.isEmpty => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 28a164b5d0cad..632deb4a56a7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -101,11 +101,17 @@ class UnsupportedOperationsSuite extends SparkFunSuite { Update) assertNotSupportedInStreamingPlan( - "aggregate - multiple streaming aggregations", + "aggregate - multiple streaming aggregations in update mode", Aggregate(Nil, aggExprs("c"), Aggregate(Nil, aggExprs("d"), streamRelation)), outputMode = Update, expectedMsgs = Seq("multiple streaming aggregations")) + assertNotSupportedInStreamingPlan( + "aggregate - multiple streaming aggregations in complete mode", + Aggregate(Nil, aggExprs("c"), Aggregate(Nil, aggExprs("d"), streamRelation)), + outputMode = Complete, + expectedMsgs = Seq("multiple streaming aggregations")) + assertSupportedInStreamingPlan( "aggregate - streaming aggregations in update mode", Aggregate(Nil, aggExprs("d"), streamRelation), @@ -127,6 +133,32 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Append, expectedMsgs = Seq("streaming aggregations", "without watermark")) + 
assertSupportedInStreamingPlan( + "aggregate - multiple streaming aggregations in append mode with watermark", + Aggregate(Seq(attributeWithWatermark), aggExprs("c"), + Aggregate(Seq(attributeWithWatermark), aggExprs("d"), streamRelation)), + outputMode = Append) + + assertNotSupportedInStreamingPlan( + "aggregate - multiple streaming aggregations without watermark in append mode", + Aggregate(Nil, aggExprs("c"), Aggregate(Nil, aggExprs("d"), streamRelation)), + outputMode = Append, + expectedMsgs = Seq("streaming aggregations", "without watermark")) + + assertNotSupportedInStreamingPlan( + "aggregate - multiple streaming aggregations, watermark for second aggregate in append mode", + Aggregate(Nil, aggExprs("c"), + Aggregate(Seq(attributeWithWatermark), aggExprs("d"), streamRelation)), + outputMode = Append, + expectedMsgs = Seq("streaming aggregations", "without watermark")) + + assertNotSupportedInStreamingPlan( + "aggregate - multiple streaming aggregations, watermark for first aggregate in append mode", + Aggregate(Seq(attributeWithWatermark), aggExprs("c"), + Aggregate(Nil, aggExprs("d"), streamRelation)), + outputMode = Append, + expectedMsgs = Seq("streaming aggregations", "without watermark")) + // Aggregation: Distinct aggregates not supported on streaming relation val distinctAggExprs = Seq(Count("*").toAggregateExpression(isDistinct = true).as("c")) assertSupportedInStreamingPlan( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala index 0063318db332d..1d94eecefcdc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala @@ -77,7 +77,8 @@ object CommitLog { } -case class CommitMetadata(nextBatchWatermarkMs: Long = 0) { +case class CommitMetadata(nextBatchWatermarkMs: Long = 0, + operatorWatermarks: Map[Long, 
Long] = Map.empty) { def json: String = Serialization.write(this)(CommitMetadata.format) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index bfe7d00f56048..eef7fbe11a230 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -85,7 +85,7 @@ case class FlatMapGroupsWithStateExec( true // Always run batches to process timeouts case EventTimeTimeout => // Process another non-data batch only if the watermark has changed in this executed plan - eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + eventTimeWatermark.isDefined && getWatermark(newMetadata) > eventTimeWatermark.get case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index af52af0d1d7e6..58d8e1fd384f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -101,6 +101,12 @@ class IncrementalExecution( numStateStores) } + // the watermark for the stateful operation if available or the batch watermark + private def getWatermark(operatorId: Long): Long = { + offsetSeqMetadata.operatorWatermarks + .getOrElse(operatorId, offsetSeqMetadata.batchWatermarkMs) + } + /** Locates save/restore pairs surrounding aggregation. 
*/ val state = new Rule[SparkPlan] { @@ -108,12 +114,12 @@ class IncrementalExecution( case StateStoreSaveExec(keys, None, None, None, stateFormatVersion, UnaryExecNode(agg, StateStoreRestoreExec(_, None, _, child))) => - val aggStateInfo = nextStatefulOperationStateInfo + val aggStateInfo = nextStatefulOperationStateInfo() StateStoreSaveExec( keys, Some(aggStateInfo), Some(outputMode), - Some(offsetSeqMetadata.batchWatermarkMs), + Some(getWatermark(aggStateInfo.operatorId)), stateFormatVersion, agg.withNewChildren( StateStoreRestoreExec( @@ -123,30 +129,33 @@ class IncrementalExecution( child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None) => + val stateInfo = nextStatefulOperationStateInfo() StreamingDeduplicateExec( keys, child, - Some(nextStatefulOperationStateInfo), - Some(offsetSeqMetadata.batchWatermarkMs)) + Some(stateInfo), + Some(getWatermark(stateInfo.operatorId))) case m: FlatMapGroupsWithStateExec => + val stateInfo = nextStatefulOperationStateInfo() m.copy( - stateInfo = Some(nextStatefulOperationStateInfo), + stateInfo = Some(stateInfo), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), - eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs)) + eventTimeWatermark = Some(getWatermark(stateInfo.operatorId))) case j: StreamingSymmetricHashJoinExec => + val stateInfo = nextStatefulOperationStateInfo() j.copy( - stateInfo = Some(nextStatefulOperationStateInfo), - eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs), + stateInfo = Some(stateInfo), + eventTimeWatermark = Some(getWatermark(stateInfo.operatorId)), stateWatermarkPredicates = StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, - Some(offsetSeqMetadata.batchWatermarkMs))) + Some(getWatermark(stateInfo.operatorId)))) case l: StreamingGlobalLimitExec => l.copy( - stateInfo = Some(nextStatefulOperationStateInfo), + stateInfo = Some(nextStatefulOperationStateInfo()), 
outputMode = Some(outputMode)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 38ecb0dd12daa..78d2caf6f945d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -268,9 +268,11 @@ class MicroBatchExecution( nextOffsets.metadata.foreach { metadata => OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) offsetSeqMetadata = OffsetSeqMetadata( - metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) + metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf, + metadata.operatorWatermarks) watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) watermarkTracker.setWatermark(metadata.batchWatermarkMs) + watermarkTracker.setOperatorWatermarks(metadata.operatorWatermarks) } /* identify the current batch id: if commit log indicates we successfully processed the @@ -297,6 +299,7 @@ class MicroBatchExecution( committedOffsets ++= availableOffsets watermarkTracker.setWatermark( math.max(watermarkTracker.currentWatermark, commitMetadata.nextBatchWatermarkMs)) + watermarkTracker.setOperatorWatermarks(commitMetadata.operatorWatermarks) } else if (latestCommittedBatchId < latestBatchId - 1) { logWarning(s"Batch completion log latest batch id is " + s"${latestCommittedBatchId}, which is not trailing " + @@ -369,7 +372,8 @@ class MicroBatchExecution( // Update the query metadata offsetSeqMetadata = offsetSeqMetadata.copy( batchWatermarkMs = watermarkTracker.currentWatermark, - batchTimestampMs = triggerClock.getTimeMillis()) + batchTimestampMs = triggerClock.getTimeMillis(), + operatorWatermarks = watermarkTracker.currentOperatorWatermarks) // Check whether next batch should be constructed 
val lastExecutionRequiresAnotherBatch = noDataBatchesEnabled && @@ -559,7 +563,8 @@ class MicroBatchExecution( withProgressLocked { sinkCommitProgress = batchSinkProgress watermarkTracker.updateWatermark(lastExecution.executedPlan) - commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark)) + commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark, + watermarkTracker.currentOperatorWatermarks)) committedOffsets ++= availableOffsets } logDebug(s"Completed batch ${currentBatchId}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 73cf355dbe758..8645f7ff0579e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -77,11 +77,13 @@ object OffsetSeq { * @param batchTimestampMs: The current batch processing timestamp. * Time unit: milliseconds * @param conf: Additional conf_s to be persisted across batches, e.g. number of shuffle partitions. + * @param operatorWatermarks: the watermarks at the individual stateful operators. 
*/ case class OffsetSeqMetadata( batchWatermarkMs: Long = 0, batchTimestampMs: Long = 0, - conf: Map[String, String] = Map.empty) { + conf: Map[String, String] = Map.empty, + operatorWatermarks: Map[Long, Long] = Map.empty) { def json: String = Serialization.write(this)(OffsetSeqMetadata.format) } @@ -114,9 +116,10 @@ object OffsetSeqMetadata extends Logging { def apply( batchWatermarkMs: Long, batchTimestampMs: Long, - sessionConf: RuntimeConfig): OffsetSeqMetadata = { + sessionConf: RuntimeConfig, + operatorWatermarks: Map[Long, Long]): OffsetSeqMetadata = { val confs = relevantSQLConfs.map { conf => conf.key -> sessionConf.get(conf.key) }.toMap - OffsetSeqMetadata(batchWatermarkMs, batchTimestampMs, confs) + OffsetSeqMetadata(batchWatermarkMs, batchTimestampMs, confs, operatorWatermarks) } /** Set the SparkSession configuration with the values in the metadata */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 83824f40ab90b..5d46fa9e18fb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -134,7 +134,8 @@ abstract class StreamExecution( /** Metadata associated with the offset seq of a batch in the query. 
*/ protected var offsetSeqMetadata = OffsetSeqMetadata( - batchWatermarkMs = 0, batchTimestampMs = 0, sparkSession.conf) + batchWatermarkMs = 0, batchTimestampMs = 0, sessionConf = sparkSession.conf, + operatorWatermarks = Map.empty) /** * A map of current watermarks, keyed by the position of the watermark operator in the @@ -277,7 +278,8 @@ abstract class StreamExecution( // Disable cost-based join optimization as we do not want stateful operations to be rearranged sparkSessionForStream.conf.set(SQLConf.CBO_ENABLED.key, "false") offsetSeqMetadata = OffsetSeqMetadata( - batchWatermarkMs = 0, batchTimestampMs = 0, sparkSessionForStream.conf) + batchWatermarkMs = 0, batchTimestampMs = 0, sessionConf = sparkSessionForStream.conf, + operatorWatermarks = Map.empty) if (state.compareAndSet(INITIALIZING, ACTIVE)) { // Unblock `awaitInitialization` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 50cf971e4ec3c..4c40f7fe1dd58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -193,7 +193,7 @@ case class StreamingSymmetricHashJoinExec( // Latest watermark value is more than that used in this previous executed plan val watermarkHasChanged = - eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + eventTimeWatermark.isDefined && getWatermark(newMetadata) > eventTimeWatermark.get watermarkUsedForStateCleanup && watermarkHasChanged } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala index 76ab1284633b1..9651b4dc06d2a 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.execution.streaming import java.util.Locale +import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.internal.SQLConf /** @@ -80,32 +81,98 @@ case object MaxWatermark extends MultipleWatermarkPolicy { /** Tracks the watermark value of a streaming query based on a given `policy` */ case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { private val operatorToWatermarkMap = mutable.HashMap[Int, Long]() + private val statefulOperatorToWatermark = mutable.HashMap[Long, Long]() + private val statefulOperatorToEventTimeMap = mutable.HashMap[Long, mutable.HashMap[Int, Long]]() + private var globalWatermarkMs: Long = 0 + private def updateWaterMarkMap( + eventTimeExec: EventTimeWatermarkExec, + index: Int, + map: mutable.HashMap[Int, Long]): Option[Long] = { + if (eventTimeExec.eventTimeStats.value.count > 0) { + logDebug(s"Observed event time stats $index: ${eventTimeExec.eventTimeStats.value}") + val newWatermarkMs = eventTimeExec.eventTimeStats.value.max - eventTimeExec.delayMs + val prevWatermarkMs = map.get(index) + if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { + map.put(index, newWatermarkMs) + } + } else { + // Populate 0 if we haven't seen any data yet for this watermark node. 
+ if (!map.isDefinedAt(index)) { + map.put(index, 0) + } + } + map.get(index) + } + + private def updateWaterMarkMap( + eventTimeExecs: Seq[EventTimeWatermarkExec], + map: mutable.HashMap[Int, Long]): Unit = { + eventTimeExecs.zipWithIndex.foreach { + case (e, i) => updateWaterMarkMap(e, i, map) + } + } + def setWatermark(newWatermarkMs: Long): Unit = synchronized { globalWatermarkMs = newWatermarkMs } + def setOperatorWatermarks(operatorWatermarks: Map[Long, Long]): Unit = synchronized { + statefulOperatorToWatermark ++= operatorWatermarks + } + def updateWatermark(executedPlan: SparkPlan): Unit = synchronized { val watermarkOperators = executedPlan.collect { case e: EventTimeWatermarkExec => e } if (watermarkOperators.isEmpty) return - watermarkOperators.zipWithIndex.foreach { - case (e, index) if e.eventTimeStats.value.count > 0 => - logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") - val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs - val prevWatermarkMs = operatorToWatermarkMap.get(index) - if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { - operatorToWatermarkMap.put(index, newWatermarkMs) - } + updateWaterMarkMap(watermarkOperators, operatorToWatermarkMap) + + // compute watermark of an operator node + def computeWatermark( + node: SparkPlan, + stateOperatorId: Long, + etId: AtomicInteger): Option[Long] = { + node match { + case ws: WatermarkSupport => + ws.eventTimeWatermark + case et: EventTimeWatermarkExec => + updateWaterMarkMap(et, etId.getAndIncrement(), + statefulOperatorToEventTimeMap.getOrElseUpdate(stateOperatorId, + new mutable.HashMap[Int, Long]())) + case other => + // min of available watermarks + val watermarks = other.children + .map(c => computeWatermark(c, stateOperatorId, etId)) + .filter(_.isDefined) + .map(_.get) + if (watermarks.isEmpty) None else Some(watermarks.min) + } + } - // Populate 0 if we haven't seen any data yet for this watermark node. 
- case (_, index) => - if (!operatorToWatermarkMap.isDefinedAt(index)) { - operatorToWatermarkMap.put(index, 0) + // compute the per stateful operator watermark + val statefulOperators = executedPlan.collect { + case s: StatefulOperator => s + } + statefulOperators.foreach { statefulOp => + statefulOp.stateInfo.foreach { stateInfo => + val newWatermarkMs = if (statefulOp.children.isEmpty) { + 0 + } else { + val etId = new AtomicInteger(0) + val watermarks = statefulOp.children + .map(c => computeWatermark(c, stateInfo.operatorId, etId)) + .filter(_.isDefined) + .map(_.get) + if (watermarks.isEmpty) 0 else watermarks.min + } + val prevWatermarkMs = statefulOperatorToWatermark.get(stateInfo.operatorId) + if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { + statefulOperatorToWatermark.put(stateInfo.operatorId, newWatermarkMs) } + } } // Update the global watermark to the minimum of all watermark nodes. @@ -121,7 +188,16 @@ case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { } } + def statefulOperatorWatermark(id: Long): Option[Long] = synchronized { + statefulOperatorToWatermark.get(id) + } + def currentWatermark: Long = synchronized { globalWatermarkMs } + + def currentOperatorWatermarks: Map[Long, Long] = synchronized { + statefulOperatorToWatermark.toMap + } + } object WatermarkTracker { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index c11af345b0248..2a3b8a6e39a21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -120,6 +120,13 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => } } + /** + * Gets watermark associated with the state operator + */ + protected def getWatermark(metadata: OffsetSeqMetadata): 
Long = stateInfo + .flatMap(s => metadata.operatorWatermarks.get(s.operatorId)) + .getOrElse(metadata.batchWatermarkMs) + private def stateStoreCustomMetrics: Map[String, SQLMetric] = { val provider = StateStoreProvider.create(sqlContext.conf.stateStoreProviderClass) provider.supportedCustomMetrics.map { @@ -420,7 +427,7 @@ case class StateStoreSaveExec( override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { (outputMode.contains(Append) || outputMode.contains(Update)) && eventTimeWatermark.isDefined && - newMetadata.batchWatermarkMs > eventTimeWatermark.get + getWatermark(newMetadata) > eventTimeWatermark.get } } @@ -490,7 +497,7 @@ case class StreamingDeduplicateExec( override def outputPartitioning: Partitioning = child.outputPartitioning override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { - eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + eventTimeWatermark.isDefined && getWatermark(newMetadata) > eventTimeWatermark.get } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index e6cdc063c4e9f..13a4ce30a4556 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -55,10 +55,16 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { assert(OffsetSeqMetadata(0, 2, getConfWith(shufflePartitions = 3)) === OffsetSeqMetadata(s"""{"batchTimestampMs":2,"conf": {"$key":3}}""")) - // All set + // Three set assert(OffsetSeqMetadata(1, 2, getConfWith(shufflePartitions = 3)) === OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"batchTimestampMs":2,"conf": {"$key":3}}""")) + // All set + assert(OffsetSeqMetadata(1, 2, getConfWith(shufflePartitions = 3), Map(0L -> 1000L)) === + 
OffsetSeqMetadata( + s"""{"batchWatermarkMs":1,"batchTimestampMs":2,"conf": {"$key":3}, + |"operatorWatermarks": {"0": 1000}}""".stripMargin)) + // Drop unknown fields assert(OffsetSeqMetadata(1, 2, getConfWith(shufflePartitions = 3)) === OffsetSeqMetadata( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index c696204cecc2c..27657e6db173a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -302,6 +302,132 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche ) } + test("multiple aggregates in append mode") { + val inputData = MemoryStream[(Int, String)] + + // compute per-user event counts over 5 sec windows + // and sum this to global counts over 30 second windows + val windowedAggregation = inputData.toDS() + .select($"_1".cast("timestamp").as("inputtime"), $"_2".as("user")) + .withWatermark("inputtime", "2 seconds") + .groupBy(window($"inputtime", "5 seconds") as 'window1, $"user").count() + .select($"window1.end".as("windowtime"), $"count") + .withWatermark("windowtime", "3 seconds") + .groupBy(window($"windowtime", "30 seconds") as 'window2).sum("count") + .select($"window2.start".cast("long").as[Long], + $"window2.end".cast("long").as[Long], $"sum(count)") + + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + testStream(windowedAggregation)( + AddData(inputData, (1, "A"), (2, "B"), (4, "A"), (1, "B"), (3, "B")), + CheckNewAnswer(), + AddData(inputData, (8, "A")), // window1 [0 - 5] -> (A, 2), (B, 3) emitted down + // window1 [5 - 10] -> (A, 1) retained in state1 + // window2 [0 - 30] -> 5 is retained in state2 + CheckNewAnswer(), + AddData(inputData, (9, "A"), (9, "B"), (34, "B")), + // window1 [5 - 10] -> (A, 2), (B, 1) emitted down + // window1 [30 - 
35] -> (B, 1) retained in state1 + // window2 [0 - 30] -> 8 retained in state2 + CheckNewAnswer(), + AddData(inputData, (39, "B")), // window1 [30 - 35] -> (B, 1) is emitted down + // window1 [35 - 40] -> (B, 1) is retained in state1 + // window2 [0 - 30] -> 8 is emitted out + // window2 [30 - 60] -> 1 is retained in state2 + CheckNewAnswer((0, 30, 8)) + ) + } + } + + test("multiple aggregates in append mode with groups in second window") { + val inputData = MemoryStream[Int] + + // compute a count of timestamps over 5 sec windows + // and count of counts over 5 sec windows in the second aggregate + val windowedAggregation = inputData.toDF() + .withColumn("inputtime", $"value".cast("timestamp")) + .withWatermark("inputtime", "10 seconds") + .groupBy(window($"inputtime", "5 seconds") as 'window1, $"inputtime").count() + .select($"window1.end".as("windowtime"), $"count".as("num")) + .withWatermark("windowtime", "5 seconds") + .groupBy(window($"windowtime", "5 seconds") as 'window2, $"num").count() + .select($"window2.start".cast("long").as[Long], $"num", $"count") + + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + testStream(windowedAggregation)( + AddData(inputData, 10, 11, 11, 12, 12), + CheckNewAnswer(), + AddData(inputData, 25), // watermark -> group1 = 15, group2 = 10 + // window1 [10 - 15] -> (10,1) (11,2), (12,2) is emitted + // downstream since watermark of group1 is at 15 + // window1 [25 - 30] -> (25,1) is retained in state1 + // window2 [15 - 20] -> (1,1), (2,2) is retained in state2 + // since watermark of group2 is at 10 + CheckNewAnswer(), + assertNumTotalStateRows(3), // [25 - 30] -> (25,1) in state1 and + // [15 - 20] -> (1,1), (2,2) in state2 + AddData(inputData, 26, 26, 27), + CheckNewAnswer(), + AddData(inputData, 40), // watermark -> group1 = 30 , group2 = 25 + // window1 [25 - 30] -> (25,1), (26,2), (27,1) is emitted down + // window1 [40 - 45] -> (40, 1) is retained in state1 + // window2 [15 - 20] -> (1,1), (2,2) is now emitted out + // 
window2 [30 - 35] -> (1,2), (2,1) is retained in state2 + CheckNewAnswer((15, 1, 1), (15, 2, 2)) + ) + } + } + + test("multiple aggregates in append mode recovery") { + val inputData = MemoryStream[Int] + + // compute a count of timestamps over 5 sec windows + // and count of counts over 5 sec windows in the second aggregate + val windowedAggregation = inputData.toDF() + .withColumn("inputtime", $"value".cast("timestamp")) + .withWatermark("inputtime", "10 seconds") + .groupBy(window($"inputtime", "5 seconds") as 'window1, $"inputtime").count() + .select($"window1.end".as("windowtime"), $"count".as("num")) + .withWatermark("windowtime", "5 seconds") + .groupBy(window($"windowtime", "5 seconds") as 'window2, $"num").count() + .select($"window2.start".cast("long").as[Long], $"num", $"count") + + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + testStream(windowedAggregation)( + AddData(inputData, 10, 11, 11, 12, 12), + CheckNewAnswer(), + AddData(inputData, 25), // watermark -> group1 = 15, group2 = 10 + // window1 [10 - 15] -> (10,1) (11,2), (12,2) is emitted + // downstream since watermark of group1 is at 15 + // window1 [25 - 30] -> (25,1) is retained in state1 + // window2 [15 - 20] -> (1,1), (2,2) is retained in state2 + CheckNewAnswer(), + AddData(inputData, 26, 26, 27), + CheckNewAnswer(), + AddData(inputData, 40), // watermark -> group1 = 30 , group2 = 25 + // window1 [25 - 30] -> (25,1), (26,2), (27,1) is emitted down + // window1 [40 - 45] -> (40, 1) is retained in state1 + // window2 [15 - 20] -> (1,1), (2,2) is now emitted out + // window2 [30 - 35] -> (1,2), (2,1) is retained in state2 + CheckNewAnswer((15, 1, 1), (15, 2, 2)), + StopStream, + AssertOnQuery { q => // purge commit and clear the sink + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + q.commitLog.purge(commit) + q.sink.asInstanceOf[MemorySink].clear() + true + }, + StartStream(), + AddData(inputData, 55), // watermark -> group1 = 45, group2 = 40 + // window1 [40 - 45] -> 
(40,1) is emitted down + // window1 [55 - 60] -> (55, 1) is retained in state1 + // window2 [30 - 35] -> (1,2), (2,1) is emitted out + // window2 [45 - 50] -> (1, 1) is retained in state2 + CheckAnswer((30, 1, 2), (30, 2, 1)) + ) + } + } + test("delay in months and years handled correctly") { val currentTimeMs = System.currentTimeMillis val currentTime = new Date(currentTimeMs) @@ -727,6 +853,13 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche true } + private def assertNumTotalStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => + q.processAllAvailable() + val progressWithData = q.recentProgress.lastOption.get + assert(progressWithData.stateOperators.map(_.numRowsTotal).sum === numTotalRows) + true + } + /** Assert event stats generated on that last batch with data in it */ private def assertEventStats(body: ju.Map[String, String] => Unit): AssertOnQuery = { Execute("AssertEventStats") { q =>