Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

// Optional (no default — createOptional) config naming the StateStoreProvider subclass
// that manages state for stateful streaming queries; the class must have a zero-arg
// constructor (see the .doc text below). Marked internal(), so it is not user-facing.
val STATE_STORE_PROVIDER_CLASS =
buildConf("spark.sql.streaming.stateStore.providerClass")
.internal()
.doc(
"The class used to manage state data in stateful streaming queries. This class must " +
"be a subclass of StateStoreProvider, and must have a zero-arg constructor.")
.stringConf
.createOptional

val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT =
buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot")
.internal()
Expand Down Expand Up @@ -828,6 +837,8 @@ class SQLConf extends Serializable with Logging {

def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD)

/** Configured StateStoreProvider class name, if any ([[STATE_STORE_PROVIDER_CLASS]] has no default). */
def stateStoreProviderClass: Option[String] = getConf(STATE_STORE_PROVIDER_CLASS)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add this to StateStoreConf for consistency?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea.


/** Value of [[STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT]] (minimum delta count before a state snapshot). */
def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)

/** Optional streaming checkpoint location from [[CHECKPOINT_LOCATION]]; None when unset. */
def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,11 @@ case class FlatMapGroupsWithStateExec(
child.execute().mapPartitionsWithStateStore[InternalRow](
getStateId.checkpointLocation,
getStateId.operatorId,
storeName = "default",
getStateId.batchId,
groupingAttributes.toStructType,
stateAttributes.toStructType,
indexOrdinal = None,
sqlContext.sessionState,
Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
val updater = new StateStoreUpdater(store)
Expand Down Expand Up @@ -191,12 +193,12 @@ case class FlatMapGroupsWithStateExec(
throw new IllegalStateException(
s"Cannot filter timed out keys for $timeoutConf")
}
val timingOutKeys = store.filter { case (_, stateRow) =>
val timeoutTimestamp = getTimeoutTimestamp(stateRow)
val timingOutKeys = store.getRange(None, None).filter { rowPair =>
val timeoutTimestamp = getTimeoutTimestamp(rowPair.value)
timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold
}
timingOutKeys.flatMap { case (keyRow, stateRow) =>
callFunctionAndUpdateState(keyRow, Iterator.empty, Some(stateRow), hasTimedOut = true)
timingOutKeys.flatMap { rowPair =>
callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true)
}
} else Iterator.empty
}
Expand All @@ -205,18 +207,23 @@ case class FlatMapGroupsWithStateExec(
* Call the user function on a key's data, update the state store, and return the return data
* iterator. Note that the store updating is lazy, that is, the store will be updated only
* after the returned iterator is fully consumed.
*
* @param keyRow Row representing the key, cannot be null
* @param valueRowIter Iterator of values as rows, cannot be null, but can be empty
* @param prevStateRow Row representing the previous state, can be null
* @param hasTimedOut Whether this function is being called for a key timeout
*/
private def callFunctionAndUpdateState(
keyRow: UnsafeRow,
valueRowIter: Iterator[InternalRow],
prevStateRowOption: Option[UnsafeRow],
prevStateRow: UnsafeRow,
hasTimedOut: Boolean): Iterator[InternalRow] = {

val keyObj = getKeyObj(keyRow) // convert key to objects
val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
val stateObjOption = getStateObj(prevStateRowOption)
val stateObj = getStateObj(prevStateRow)
val keyedState = GroupStateImpl.createForStreaming(
stateObjOption,
Option(stateObj),
batchTimestampMs.getOrElse(NO_TIMESTAMP),
eventTimeWatermark.getOrElse(NO_TIMESTAMP),
timeoutConf,
Expand Down Expand Up @@ -249,14 +256,11 @@ case class FlatMapGroupsWithStateExec(
numUpdatedStateRows += 1

} else {
val previousTimeoutTimestamp = prevStateRowOption match {
case Some(row) => getTimeoutTimestamp(row)
case None => NO_TIMESTAMP
}
val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow)
val stateRowToWrite = if (keyedState.hasUpdated) {
getStateRow(keyedState.get)
} else {
prevStateRowOption.orNull
prevStateRow
}

val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp
Expand All @@ -269,7 +273,7 @@ case class FlatMapGroupsWithStateExec(
throw new IllegalStateException("Attempting to write empty state")
}
setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp)
store.put(keyRow.copy(), stateRowToWrite.copy())
store.put(keyRow, stateRowToWrite)
numUpdatedStateRows += 1
}
}
Expand All @@ -280,18 +284,21 @@ case class FlatMapGroupsWithStateExec(
}

/** Returns the state as Java object if defined */
def getStateObj(stateRowOption: Option[UnsafeRow]): Option[Any] = {
stateRowOption.map(getStateObjFromRow)
def getStateObj(stateRow: UnsafeRow): Any = {
if (stateRow != null) getStateObjFromRow(stateRow) else null
}

/**
 * Returns the UnsafeRow encoding of an updated state object.
 * The object must be non-null (asserted); null state is never converted to a row.
 */
def getStateRow(obj: Any): UnsafeRow = {
assert(obj != null)
getStateRowFromObj(obj)
}

/** Returns the timeout timestamp stored in a state row, or NO_TIMESTAMP if timeouts are disabled or the row is null */
def getTimeoutTimestamp(stateRow: UnsafeRow): Long = {
if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else NO_TIMESTAMP
if (isTimeoutEnabled && stateRow != null) {
stateRow.getLong(timeoutTimestampIndex)
} else NO_TIMESTAMP
}

/** Set the timestamp in a state row */
Expand Down
Loading