Skip to content

Commit d230232

Browse files
[SPARK-49745][SS] Add change to read registered timers through state data source reader
### What changes were proposed in this pull request? Add change to read registered timers through state data source reader ### Why are the changes needed? Without this, users cannot read registered timers per grouping key within the transformWithState operator ### Does this PR introduce _any_ user-facing change? Yes Users can now read registered timers using the following query: ``` val stateReaderDf = spark.read .format("statestore") .option(StateSourceOptions.PATH, <checkpoint_loc>) .option(StateSourceOptions.READ_REGISTERED_TIMERS, true) .load() ``` ### How was this patch tested? Added unit tests ``` [info] Run completed in 20 seconds, 834 milliseconds. [info] Total number of tests run: 4 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 4, failed 0, canceled 0, ignored 0, pending 0 [info] All tests passed. ``` ### Was this patch authored or co-authored using generative AI tooling? No Closes #48205 from anishshri-db/task/SPARK-49745. Lead-authored-by: Anish Shrigondekar <[email protected]> Co-authored-by: Jungtaek Lim <[email protected]> Signed-off-by: Jungtaek Lim <[email protected]>
1 parent 0ccf53a commit d230232

File tree

11 files changed

+263
-20
lines changed

11 files changed

+263
-20
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,16 @@ import org.apache.spark.sql.SparkSession
2929
import org.apache.spark.sql.catalyst.DataSourceOptions
3030
import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
3131
import org.apache.spark.sql.connector.expressions.Transform
32-
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues, STATE_VAR_NAME}
32+
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues, READ_REGISTERED_TIMERS, STATE_VAR_NAME}
3333
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues
3434
import org.apache.spark.sql.execution.datasources.v2.state.metadata.{StateMetadataPartitionReader, StateMetadataTableEntry}
3535
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
36-
import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata, TransformWithStateOperatorProperties, TransformWithStateVariableInfo}
36+
import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata, TimerStateUtils, TransformWithStateOperatorProperties, TransformWithStateVariableInfo}
3737
import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE}
3838
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
3939
import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId}
4040
import org.apache.spark.sql.sources.DataSourceRegister
41+
import org.apache.spark.sql.streaming.TimeMode
4142
import org.apache.spark.sql.types.StructType
4243
import org.apache.spark.sql.util.CaseInsensitiveStringMap
4344
import org.apache.spark.util.SerializableConfiguration
@@ -132,7 +133,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
132133
sourceOptions: StateSourceOptions,
133134
stateStoreMetadata: Array[StateMetadataTableEntry]): Unit = {
134135
val twsShortName = "transformWithStateExec"
135-
if (sourceOptions.stateVarName.isDefined) {
136+
if (sourceOptions.stateVarName.isDefined || sourceOptions.readRegisteredTimers) {
136137
// Perform checks for transformWithState operator in case state variable name is provided
137138
require(stateStoreMetadata.size == 1)
138139
val opMetadata = stateStoreMetadata.head
@@ -153,10 +154,21 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
153154
"No state variable names are defined for the transformWithState operator")
154155
}
155156

157+
val twsOperatorProperties = TransformWithStateOperatorProperties.fromJson(operatorProperties)
158+
val timeMode = twsOperatorProperties.timeMode
159+
if (sourceOptions.readRegisteredTimers && timeMode == TimeMode.None().toString) {
160+
throw StateDataSourceErrors.invalidOptionValue(READ_REGISTERED_TIMERS,
161+
"Registered timers are not available in TimeMode=None.")
162+
}
163+
156164
// if the state variable is not one of the defined/available state variables, then we
157165
// fail the query
158-
val stateVarName = sourceOptions.stateVarName.get
159-
val twsOperatorProperties = TransformWithStateOperatorProperties.fromJson(operatorProperties)
166+
val stateVarName = if (sourceOptions.readRegisteredTimers) {
167+
TimerStateUtils.getTimerStateVarName(timeMode)
168+
} else {
169+
sourceOptions.stateVarName.get
170+
}
171+
160172
val stateVars = twsOperatorProperties.stateVariables
161173
if (stateVars.filter(stateVar => stateVar.stateName == stateVarName).size != 1) {
162174
throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME,
@@ -196,9 +208,10 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
196208
var keyStateEncoderSpecOpt: Option[KeyStateEncoderSpec] = None
197209
var stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema] = None
198210
var transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo] = None
211+
var timeMode: String = TimeMode.None.toString
199212

200213
if (sourceOptions.joinSide == JoinSideValues.none) {
201-
val stateVarName = sourceOptions.stateVarName
214+
var stateVarName = sourceOptions.stateVarName
202215
.getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)
203216

204217
// Read the schema file path from operator metadata version v2 onwards
@@ -208,6 +221,12 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
208221
val storeMetadataEntry = storeMetadata.head
209222
val operatorProperties = TransformWithStateOperatorProperties.fromJson(
210223
storeMetadataEntry.operatorPropertiesJson)
224+
timeMode = operatorProperties.timeMode
225+
226+
if (sourceOptions.readRegisteredTimers) {
227+
stateVarName = TimerStateUtils.getTimerStateVarName(timeMode)
228+
}
229+
211230
val stateVarInfoList = operatorProperties.stateVariables
212231
.filter(stateVar => stateVar.stateName == stateVarName)
213232
require(stateVarInfoList.size == 1, s"Failed to find unique state variable info " +
@@ -304,6 +323,7 @@ case class StateSourceOptions(
304323
fromSnapshotOptions: Option[FromSnapshotOptions],
305324
readChangeFeedOptions: Option[ReadChangeFeedOptions],
306325
stateVarName: Option[String],
326+
readRegisteredTimers: Boolean,
307327
flattenCollectionTypes: Boolean) {
308328
def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)
309329

@@ -336,6 +356,7 @@ object StateSourceOptions extends DataSourceOptions {
336356
val CHANGE_START_BATCH_ID = newOption("changeStartBatchId")
337357
val CHANGE_END_BATCH_ID = newOption("changeEndBatchId")
338358
val STATE_VAR_NAME = newOption("stateVarName")
359+
val READ_REGISTERED_TIMERS = newOption("readRegisteredTimers")
339360
val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes")
340361

341362
object JoinSideValues extends Enumeration {
@@ -377,6 +398,19 @@ object StateSourceOptions extends DataSourceOptions {
377398
val stateVarName = Option(options.get(STATE_VAR_NAME))
378399
.map(_.trim)
379400

401+
val readRegisteredTimers = try {
402+
Option(options.get(READ_REGISTERED_TIMERS))
403+
.map(_.toBoolean).getOrElse(false)
404+
} catch {
405+
case _: IllegalArgumentException =>
406+
throw StateDataSourceErrors.invalidOptionValue(READ_REGISTERED_TIMERS,
407+
"Boolean value is expected")
408+
}
409+
410+
if (readRegisteredTimers && stateVarName.isDefined) {
411+
throw StateDataSourceErrors.conflictOptions(Seq(READ_REGISTERED_TIMERS, STATE_VAR_NAME))
412+
}
413+
380414
val flattenCollectionTypes = try {
381415
Option(options.get(FLATTEN_COLLECTION_TYPES))
382416
.map(_.toBoolean).getOrElse(true)
@@ -489,8 +523,8 @@ object StateSourceOptions extends DataSourceOptions {
489523

490524
StateSourceOptions(
491525
resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
492-
readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, stateVarName,
493-
flattenCollectionTypes)
526+
readChangeFeed, fromSnapshotOptions, readChangeFeedOptions,
527+
stateVarName, readRegisteredTimers, flattenCollectionTypes)
494528
}
495529

496530
private def resolvedCheckpointLocation(

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ abstract class StatePartitionReaderBase(
107107
useColumnFamilies = useColFamilies, storeConf, hadoopConf.value,
108108
useMultipleValuesPerKey = useMultipleValuesPerKey)
109109

110+
val isInternal = partition.sourceOptions.readRegisteredTimers
111+
110112
if (useColFamilies) {
111113
val store = provider.getStore(partition.sourceOptions.batchId + 1)
112114
require(stateStoreColFamilySchemaOpt.isDefined)
@@ -117,7 +119,8 @@ abstract class StatePartitionReaderBase(
117119
stateStoreColFamilySchema.keySchema,
118120
stateStoreColFamilySchema.valueSchema,
119121
stateStoreColFamilySchema.keyStateEncoderSpec.get,
120-
useMultipleValuesPerKey = useMultipleValuesPerKey)
122+
useMultipleValuesPerKey = useMultipleValuesPerKey,
123+
isInternal = isInternal)
121124
}
122125
provider
123126
}

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ object SchemaUtil {
230230
"map_value" -> classOf[MapType],
231231
"user_map_key" -> classOf[StructType],
232232
"user_map_value" -> classOf[StructType],
233+
"expiration_timestamp_ms" -> classOf[LongType],
233234
"partition_id" -> classOf[IntegerType])
234235

235236
val expectedFieldNames = if (sourceOptions.readChangeFeed) {
@@ -256,6 +257,9 @@ object SchemaUtil {
256257
Seq("key", "map_value", "partition_id")
257258
}
258259

260+
case TimerState =>
261+
Seq("key", "expiration_timestamp_ms", "partition_id")
262+
259263
case _ =>
260264
throw StateDataSourceErrors
261265
.internalError(s"Unsupported state variable type $stateVarType")
@@ -322,6 +326,14 @@ object SchemaUtil {
322326
.add("partition_id", IntegerType)
323327
}
324328

329+
case TimerState =>
330+
val groupingKeySchema = SchemaUtil.getSchemaAsDataType(
331+
stateStoreColFamilySchema.keySchema, "key")
332+
new StructType()
333+
.add("key", groupingKeySchema)
334+
.add("expiration_timestamp_ms", LongType)
335+
.add("partition_id", IntegerType)
336+
325337
case _ =>
326338
throw StateDataSourceErrors.internalError(s"Unsupported state variable type $stateVarType")
327339
}
@@ -407,9 +419,30 @@ object SchemaUtil {
407419
unifyMapStateRowPair(store.iterator(stateVarName),
408420
compositeKeySchema, partitionId, stateSourceOptions)
409421

422+
case StateVariableType.TimerState =>
423+
store
424+
.iterator(stateVarName)
425+
.map { pair =>
426+
unifyTimerRow(pair.key, compositeKeySchema, partitionId)
427+
}
428+
410429
case _ =>
411430
throw new IllegalStateException(
412431
s"Unsupported state variable type: $stateVarType")
413432
}
414433
}
434+
435+
private def unifyTimerRow(
436+
rowKey: UnsafeRow,
437+
groupingKeySchema: StructType,
438+
partitionId: Int): InternalRow = {
439+
val groupingKey = rowKey.get(0, groupingKeySchema).asInstanceOf[UnsafeRow]
440+
val expirationTimestamp = rowKey.getLong(1)
441+
442+
val row = new GenericInternalRow(3)
443+
row.update(0, groupingKey)
444+
row.update(1, expirationTimestamp)
445+
row.update(2, partitionId)
446+
row
447+
}
415448
}

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import org.apache.spark.sql.Encoder
2020
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
2121
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
2222
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStoreColFamilySchema}
23+
import org.apache.spark.sql.types.StructType
2324

2425
object StateStoreColumnFamilySchemaUtils {
2526

@@ -61,4 +62,15 @@ object StateStoreColumnFamilySchemaUtils {
6162
Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)),
6263
Some(userKeyEnc.schema))
6364
}
65+
66+
def getTimerStateSchema(
67+
stateName: String,
68+
keySchema: StructType,
69+
valSchema: StructType): StateStoreColFamilySchema = {
70+
StateStoreColFamilySchema(
71+
stateName,
72+
keySchema,
73+
valSchema,
74+
Some(PrefixKeyScanStateEncoderSpec(keySchema, 1)))
75+
}
6476
}

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,9 @@ class TimerKeyEncoder(keyExprEnc: ExpressionEncoder[Any]) {
288288
.add("key", new StructType(keyExprEnc.schema.fields))
289289
.add("expiryTimestampMs", LongType, nullable = false)
290290

291+
val schemaForValueRow: StructType =
292+
StructType(Array(StructField("__dummy__", NullType)))
293+
291294
private val keySerializer = keyExprEnc.createSerializer()
292295
private val keyDeserializer = keyExprEnc.resolveAndBind().createDeserializer()
293296
private val prefixKeyProjection = UnsafeProjection.create(schemaForPrefixKey)

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,12 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
308308
private val stateVariableInfos: mutable.Map[String, TransformWithStateVariableInfo] =
309309
new mutable.HashMap[String, TransformWithStateVariableInfo]()
310310

311+
// If timeMode is not None, add a timer column family schema to the operator metadata so that
312+
// registered timers can be read using the state data source reader.
313+
if (timeMode != TimeMode.None()) {
314+
addTimerColFamily()
315+
}
316+
311317
def getColumnFamilySchemas: Map[String, StateStoreColFamilySchema] = columnFamilySchemas.toMap
312318

313319
def getStateVariableInfos: Map[String, TransformWithStateVariableInfo] = stateVariableInfos.toMap
@@ -318,6 +324,16 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
318324
}
319325
}
320326

327+
private def addTimerColFamily(): Unit = {
328+
val stateName = TimerStateUtils.getTimerStateVarName(timeMode.toString)
329+
val timerEncoder = new TimerKeyEncoder(keyExprEnc)
330+
val colFamilySchema = StateStoreColumnFamilySchemaUtils.
331+
getTimerStateSchema(stateName, timerEncoder.schemaForKeyRow, timerEncoder.schemaForValueRow)
332+
columnFamilySchemas.put(stateName, colFamilySchema)
333+
val stateVariableInfo = TransformWithStateVariableUtils.getTimerState(stateName)
334+
stateVariableInfos.put(stateName, stateVariableInfo)
335+
}
336+
321337
override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = {
322338
verifyStateVarOperations("get_value_state", PRE_INIT)
323339
val colFamilySchema = StateStoreColumnFamilySchemaUtils.

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ object TimerStateUtils {
3434
val EVENT_TIMERS_STATE_NAME = "$eventTimers"
3535
val KEY_TO_TIMESTAMP_CF = "_keyToTimestamp"
3636
val TIMESTAMP_TO_KEY_CF = "_timestampToKey"
37+
38+
def getTimerStateVarName(timeMode: String): String = {
39+
assert(timeMode == TimeMode.EventTime.toString || timeMode == TimeMode.ProcessingTime.toString)
40+
if (timeMode == TimeMode.EventTime.toString) {
41+
TimerStateUtils.EVENT_TIMERS_STATE_NAME + TimerStateUtils.KEY_TO_TIMESTAMP_CF
42+
} else {
43+
TimerStateUtils.PROC_TIMERS_STATE_NAME + TimerStateUtils.KEY_TO_TIMESTAMP_CF
44+
}
45+
}
3746
}
3847

3948
/**

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,16 @@ object TransformWithStateVariableUtils {
4343
def getMapState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo = {
4444
TransformWithStateVariableInfo(stateName, StateVariableType.MapState, ttlEnabled)
4545
}
46+
47+
def getTimerState(stateName: String): TransformWithStateVariableInfo = {
48+
TransformWithStateVariableInfo(stateName, StateVariableType.TimerState, ttlEnabled = false)
49+
}
4650
}
4751

4852
// Enum of possible State Variable types
4953
object StateVariableType extends Enumeration {
5054
type StateVariableType = Value
51-
val ValueState, ListState, MapState = Value
55+
val ValueState, ListState, MapState, TimerState = Value
5256
}
5357

5458
case class TransformWithStateVariableInfo(

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,25 @@ class StateDataSourceNegativeTestSuite extends StateDataSourceTestBase {
288288
}
289289
}
290290

291+
test("ERROR: trying to specify state variable name along with " +
292+
"readRegisteredTimers should fail") {
293+
withTempDir { tempDir =>
294+
val exc = intercept[StateDataSourceConflictOptions] {
295+
spark.read.format("statestore")
296+
// trick to bypass getting the last committed batch before validating operator ID
297+
.option(StateSourceOptions.BATCH_ID, 0)
298+
.option(StateSourceOptions.STATE_VAR_NAME, "test")
299+
.option(StateSourceOptions.READ_REGISTERED_TIMERS, true)
300+
.load(tempDir.getAbsolutePath)
301+
}
302+
checkError(exc, "STDS_CONFLICT_OPTIONS", "42613",
303+
Map("options" ->
304+
s"['${
305+
StateSourceOptions.READ_REGISTERED_TIMERS
306+
}', '${StateSourceOptions.STATE_VAR_NAME}']"))
307+
}
308+
}
309+
291310
test("ERROR: trying to specify non boolean value for " +
292311
"flattenCollectionTypes") {
293312
withTempDir { tempDir =>

0 commit comments

Comments
 (0)