Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ case class FlatMapGroupsWithStateExec(
outputIterator,
{
store.commit()
longMetric("numTotalStateRows") += store.numKeys()
setStoreMetrics(store)
}
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,18 +186,10 @@ trait ProgressReporter extends Logging {
if (lastExecution == null) return Nil
// lastExecution could belong to one of the previous triggers if `!hasNewData`.
// Walking the plan again should be inexpensive.
val stateNodes = lastExecution.executedPlan.collect {
case p if p.isInstanceOf[StateStoreWriter] => p
}
stateNodes.map { node =>
val numRowsUpdated = if (hasNewData) {
node.metrics.get("numUpdatedStateRows").map(_.value).getOrElse(0L)
} else {
0L
}
new StateOperatorProgress(
numRowsTotal = node.metrics.get("numTotalStateRows").map(_.value).getOrElse(0L),
numRowsUpdated = numRowsUpdated)
lastExecution.executedPlan.collect {
case p if p.isInstanceOf[StateStoreWriter] =>
val progress = p.asInstanceOf[StateStoreWriter].getProgress()
if (hasNewData) progress else progress.copy(newNumRowsUpdated = 0)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.io.LZ4CompressionCodec
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
import org.apache.spark.util.{SizeEstimator, Utils}


/**
Expand Down Expand Up @@ -172,7 +172,8 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
}
}

override def numKeys(): Long = mapToUpdate.size()
override def metrics: StateStoreMetrics =
StateStoreMetrics(mapToUpdate.size(), SizeEstimator.estimate(mapToUpdate))

/**
* Whether all updates have been committed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ trait StateStore {

def iterator(): Iterator[UnsafeRowPair]

/** Number of keys in the state store */
def numKeys(): Long
/** Current metrics of the state store */
def metrics: StateStoreMetrics

/**
* Whether all updates have been committed
Expand All @@ -104,6 +104,8 @@ trait StateStore {
}


case class StateStoreMetrics(val numKeys: Long, val memoryUsedBytes: Long)

/**
* Trait representing a provider that provide [[StateStore]] instances representing
* versions of state data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress}
import org.apache.spark.sql.types._
import org.apache.spark.util.{CompletionIterator, NextIterator}

Expand Down Expand Up @@ -73,16 +73,39 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
"numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"),
"allUpdatesTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to update rows"),
"allRemovalsTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to remove rows"),
"commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to commit changes")
"commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to commit changes"),
"stateMemory" -> SQLMetrics.createSizeMetric(sparkContext, "memory used by state")
)

/**
* Get the progress made by this stateful operator after execution. This should be called in
* the driver after this SparkPlan has been executed and metrics have been updated.
*/
def getProgress(): StateOperatorProgress = {
new StateOperatorProgress(
numRowsTotal = longMetric("numTotalStateRows").value,
numRowsUpdated = longMetric("numUpdatedStateRows").value,
memoryUsedBytes = longMetric("stateMemory").value,
numPartitions = this.sqlContext.conf.numShufflePartitions)
}

/** Records the duration of running `body` for the next query progress update. */
protected def timeTakenMs(body: => Unit): Long = {
val startTime = System.nanoTime()
val result = body
val endTime = System.nanoTime()
math.max(NANOSECONDS.toMillis(endTime - startTime), 0)
}

/**
* Set the SQL metrics related to the state store.
* This should be called in that task after the store has been updated.
*/
protected def setStoreMetrics(store: StateStore): Unit = {
val storeMetrics = store.metrics
longMetric("numTotalStateRows") += storeMetrics.numKeys
longMetric("stateMemory") += storeMetrics.memoryUsedBytes
}
}

/** An operator that supports watermark. */
Expand Down Expand Up @@ -197,7 +220,6 @@ case class StateStoreSaveExec(
Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
val numOutputRows = longMetric("numOutputRows")
val numTotalStateRows = longMetric("numTotalStateRows")
val numUpdatedStateRows = longMetric("numUpdatedStateRows")
val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
val allRemovalsTimeMs = longMetric("allRemovalsTimeMs")
Expand All @@ -218,7 +240,7 @@ case class StateStoreSaveExec(
commitTimeMs += timeTakenMs {
store.commit()
}
numTotalStateRows += store.numKeys()
setStoreMetrics(store)
store.iterator().map { rowPair =>
numOutputRows += 1
rowPair.value
Expand Down Expand Up @@ -261,7 +283,7 @@ case class StateStoreSaveExec(
override protected def close(): Unit = {
allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs)
commitTimeMs += timeTakenMs { store.commit() }
numTotalStateRows += store.numKeys()
setStoreMetrics(store)
}
}

Expand All @@ -285,7 +307,7 @@ case class StateStoreSaveExec(
// Remove old aggregates if watermark specified
allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
commitTimeMs += timeTakenMs { store.commit() }
numTotalStateRows += store.numKeys()
setStoreMetrics(store)
false
} else {
true
Expand Down Expand Up @@ -368,7 +390,7 @@ case class StreamingDeduplicateExec(
allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
commitTimeMs += timeTakenMs { store.commit() }
numTotalStateRows += store.numKeys()
setStoreMetrics(store)
})
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,25 @@ import org.apache.spark.annotation.InterfaceStability
@InterfaceStability.Evolving
class StateOperatorProgress private[sql](
val numRowsTotal: Long,
val numRowsUpdated: Long) extends Serializable {
val numRowsUpdated: Long,
val memoryUsedBytes: Long,
val numPartitions: Long
) extends Serializable {

/** The compact JSON representation of this progress. */
def json: String = compact(render(jsonValue))

/** The pretty (i.e. indented) JSON representation of this progress. */
def prettyJson: String = pretty(render(jsonValue))

private[sql] def copy(newNumRowsUpdated: Long): StateOperatorProgress =
new StateOperatorProgress(numRowsTotal, newNumRowsUpdated, memoryUsedBytes, numPartitions)

private[sql] def jsonValue: JValue = {
("numRowsTotal" -> JInt(numRowsTotal)) ~
("numRowsUpdated" -> JInt(numRowsUpdated))
("numRowsUpdated" -> JInt(numRowsUpdated)) ~
("memoryUsedBytes" -> JInt(memoryUsedBytes)) ~
("numPartitions" -> JInt(numPartitions))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,15 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
}
}

test("reports metrics") {
val provider = newStoreProvider()
val store = provider.getStore(0)
val noDataMemoryUsed = store.metrics.memoryUsedBytes
put(store, "a", 1)
store.commit()
assert(store.metrics.memoryUsedBytes > noDataMemoryUsed)
}

test("StateStore.get") {
quietly {
val dir = newDir()
Expand Down Expand Up @@ -554,22 +563,22 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
assert(!store.hasCommitted)
assert(get(store, "a") === None)
assert(store.iterator().isEmpty)
assert(store.numKeys() === 0)
assert(store.metrics.numKeys === 0)

// Verify state after updating
put(store, "a", 1)
assert(get(store, "a") === Some(1))
assert(store.numKeys() === 1)
assert(store.metrics.numKeys === 1)

assert(store.iterator().nonEmpty)
assert(getLatestData(provider).isEmpty)

// Make updates, commit and then verify state
put(store, "b", 2)
put(store, "aa", 3)
assert(store.numKeys() === 3)
assert(store.metrics.numKeys === 3)
remove(store, _.startsWith("a"))
assert(store.numKeys() === 1)
assert(store.metrics.numKeys === 1)
assert(store.commit() === 1)

assert(store.hasCommitted)
Expand All @@ -587,9 +596,9 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
// New updates to the reloaded store with new version, and does not change old version
val reloadedProvider = newStoreProvider(store.id)
val reloadedStore = reloadedProvider.getStore(1)
assert(reloadedStore.numKeys() === 1)
assert(reloadedStore.metrics.numKeys === 1)
put(reloadedStore, "c", 4)
assert(reloadedStore.numKeys() === 2)
assert(reloadedStore.metrics.numKeys === 2)
assert(reloadedStore.commit() === 2)
assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.RDDScanExec
import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream}
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, UnsafeRowPair}
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair}
import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.sql.types.{DataType, IntegerType}
Expand Down Expand Up @@ -1077,7 +1077,7 @@ object FlatMapGroupsWithStateSuite {
override def abort(): Unit = { }
override def id: StateStoreId = null
override def version: Long = 0
override def numKeys(): Long = map.size
override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0)
override def hasCommitted: Boolean = true
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,10 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.StreamingQueryStatusAndProgressSuite._

class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
implicit class EqualsIgnoreCRLF(source: String) {

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replaces this with assertJson (see below) because this code made it harded to debug the differences between json strings. With assertJson, scala test would show nice diffs.

[info] - StreamingQueryProgress - prettyJson *** FAILED *** (137 milliseconds)
[info]   "..."numRowsUpdated" : 1[,]
[info]       "memoryUsedByte..." did not equal "..."numRowsUpdated" : 1[]
[info]       "memoryUsedByte..." (StreamingQueryStatusAndProgressSuite.scala:213)
[info]   org.scalatest.exceptions.TestFailedException:

def equalsIgnoreCRLF(target: String): Boolean = {
source.replaceAll("\r\n|\r|\n", System.lineSeparator) ===
target.replaceAll("\r\n|\r|\n", System.lineSeparator)
}
}

test("StreamingQueryProgress - prettyJson") {
val json1 = testProgress1.prettyJson
assert(json1.equalsIgnoreCRLF(
assertJson(
json1,
s"""
|{
| "id" : "${testProgress1.id.toString}",
Expand All @@ -62,7 +56,9 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
| },
| "stateOperators" : [ {
| "numRowsTotal" : 0,
| "numRowsUpdated" : 1
| "numRowsUpdated" : 1,
| "memoryUsedBytes" : 2,
| "numPartitions" : 4
| } ],
| "sources" : [ {
| "description" : "source",
Expand All @@ -75,13 +71,13 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
| "description" : "sink"
| }
|}
""".stripMargin.trim))
""".stripMargin.trim)
assert(compact(parse(json1)) === testProgress1.json)

val json2 = testProgress2.prettyJson
assert(
json2.equalsIgnoreCRLF(
s"""
assertJson(
json2,
s"""
|{
| "id" : "${testProgress2.id.toString}",
| "runId" : "${testProgress2.runId.toString}",
Expand All @@ -93,7 +89,9 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
| },
| "stateOperators" : [ {
| "numRowsTotal" : 0,
| "numRowsUpdated" : 1
| "numRowsUpdated" : 1,
| "memoryUsedBytes" : 2,
| "numPartitions" : 4
| } ],
| "sources" : [ {
| "description" : "source",
Expand All @@ -105,7 +103,7 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
| "description" : "sink"
| }
|}
""".stripMargin.trim))
""".stripMargin.trim)
assert(compact(parse(json2)) === testProgress2.json)
}

Expand All @@ -121,14 +119,15 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {

test("StreamingQueryStatus - prettyJson") {
val json = testStatus.prettyJson
assert(json.equalsIgnoreCRLF(
assertJson(
json,
"""
|{
| "message" : "active",
| "isDataAvailable" : true,
| "isTriggerActive" : false
|}
""".stripMargin.trim))
""".stripMargin.trim)
}

test("StreamingQueryStatus - json") {
Expand Down Expand Up @@ -209,6 +208,12 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
}
}
}

def assertJson(source: String, expected: String): Unit = {
assert(
source.replaceAll("\r\n|\r|\n", System.lineSeparator) ===
expected.replaceAll("\r\n|\r|\n", System.lineSeparator))
}
}

object StreamingQueryStatusAndProgressSuite {
Expand All @@ -224,7 +229,8 @@ object StreamingQueryStatusAndProgressSuite {
"min" -> "2016-12-05T20:54:20.827Z",
"avg" -> "2016-12-05T20:54:20.827Z",
"watermark" -> "2016-12-05T20:54:20.827Z").asJava),
stateOperators = Array(new StateOperatorProgress(numRowsTotal = 0, numRowsUpdated = 1)),
stateOperators = Array(new StateOperatorProgress(
numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 2, numPartitions = 4)),
sources = Array(
new SourceProgress(
description = "source",
Expand All @@ -247,7 +253,8 @@ object StreamingQueryStatusAndProgressSuite {
durationMs = new java.util.HashMap(Map("total" -> 0L).mapValues(long2Long).asJava),
// empty maps should be handled correctly
eventTime = new java.util.HashMap(Map.empty[String, String].asJava),
stateOperators = Array(new StateOperatorProgress(numRowsTotal = 0, numRowsUpdated = 1)),
stateOperators = Array(new StateOperatorProgress(
numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 2, numPartitions = 4)),
sources = Array(
new SourceProgress(
description = "source",
Expand Down