diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index d689a6f3c9819..fb85e36d8e66a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -353,6 +353,7 @@ case class StateStoreSaveExec(
                 finished = true
                 null
               } else {
+                numOutputRows += 1
                 removedValueRow
               }
             }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 134e61ed12a21..16e6215eaa6a6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.streaming
 import java.io.File
 import java.util.{Locale, TimeZone}
 
+import scala.collection.mutable
+
 import org.apache.commons.io.FileUtils
 import org.scalatest.Assertions
 
@@ -184,7 +186,68 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions {
     )
   }
 
-  testWithAllStateVersions("state metrics") {
+  testWithAllStateVersions("state metrics - append mode") {
+    val inputData = MemoryStream[Int]
+    val aggWithWatermark = inputData.toDF()
+      .withColumn("eventTime", $"value".cast("timestamp"))
+      .withWatermark("eventTime", "10 seconds")
+      .groupBy(window($"eventTime", "5 seconds") as 'window)
+      .agg(count("*") as 'count)
+      .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
+
+    implicit class RichStreamExecution(query: StreamExecution) {
+      // this could be either empty row batch or actual batch
+      def stateNodes: Seq[SparkPlan] = {
+        query.lastExecution.executedPlan.collect {
+          case p if p.isInstanceOf[StateStoreSaveExec] => p
+        }
+      }
+
+      def stateOperatorProgresses: Seq[StateOperatorProgress] = {
+        val operatorProgress = mutable.ArrayBuffer[StateOperatorProgress]()
+        var progress = query.recentProgress.last
+
+        operatorProgress ++= progress.stateOperators.map { op => op.copy(op.numRowsUpdated) }
+        if (progress.numInputRows == 0) {
+          // empty batch, merge metrics from previous batch as well
+          progress = query.recentProgress.takeRight(2).head
+          operatorProgress.zipWithIndex.foreach { case (sop, index) =>
+            // "numRowsUpdated" should be merged, as it could be updated in both batches.
+            // (for now it is only updated from previous batch, but things can be changed.)
+            // other metrics represent current status of state so picking up the latest values.
+            val newOperatorProgress = sop.copy(
+              sop.numRowsUpdated + progress.stateOperators(index).numRowsUpdated)
+            operatorProgress(index) = newOperatorProgress
+          }
+        }
+
+        operatorProgress
+      }
+    }
+
+    testStream(aggWithWatermark)(
+      AddData(inputData, 15),
+      CheckAnswer(), // watermark = 5
+      AssertOnQuery { _.stateNodes.size === 1 },
+      AssertOnQuery { _.stateNodes.head.metrics("numOutputRows").value === 0 },
+      AssertOnQuery { _.stateOperatorProgresses.head.numRowsUpdated === 1 },
+      AssertOnQuery { _.stateOperatorProgresses.head.numRowsTotal === 1 },
+      AddData(inputData, 10, 12, 14),
+      CheckAnswer(), // watermark = 5
+      AssertOnQuery { _.stateNodes.size === 1 },
+      AssertOnQuery { _.stateNodes.head.metrics("numOutputRows").value === 0 },
+      AssertOnQuery { _.stateOperatorProgresses.head.numRowsUpdated === 1 },
+      AssertOnQuery { _.stateOperatorProgresses.head.numRowsTotal === 2 },
+      AddData(inputData, 25),
+      CheckAnswer((10, 3)), // watermark = 15
+      AssertOnQuery { _.stateNodes.size === 1 },
+      AssertOnQuery { _.stateNodes.head.metrics("numOutputRows").value === 1 },
+      AssertOnQuery { _.stateOperatorProgresses.head.numRowsUpdated === 1 },
+      AssertOnQuery { _.stateOperatorProgresses.head.numRowsTotal === 2 }
+    )
+  }
+
+  testWithAllStateVersions("state metrics - update/complete mode") {
     val inputData = MemoryStream[Int]
     val aggregated =