docs/structured-streaming-state-data-source.md (4 changes: 2 additions & 2 deletions)

@@ -96,9 +96,9 @@ Each row in the source has the following schema:
     <td></td>
   </tr>
   <tr>
-    <td>_partition_id</td>
+    <td>partition_id</td>
     <td>int</td>
-    <td>metadata column (hidden unless specified with SELECT)</td>
+    <td></td>
   </tr>
 </table>

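As context for the doc change above, here is a minimal, hedged usage sketch of the state data source with the renamed column. It assumes an active SparkSession `spark`; the checkpoint path is a placeholder, and the `statestore` format name and `path` option come from the existing state data source docs rather than this diff.

```scala
// Hedged sketch: read the state of a streaming aggregation from its checkpoint.
val stateDf = spark.read
  .format("statestore")
  .option("path", "/tmp/streaming-query/checkpoint")  // placeholder path
  .load()

stateDf.printSchema()
// Expected shape after this change (nested structs are operator-specific):
//  |-- key: struct
//  |-- value: struct
//  |-- partition_id: integer
```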
StateDataSource.scala

@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DI
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
 import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
 import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap

 /**
@@ -83,6 +83,7 @@ class StateDataSource extends TableProvider with DataSourceRegister {
       new StructType()
         .add("key", keySchema)
         .add("value", valueSchema)
+        .add("partition_id", IntegerType)
     } catch {
       case NonFatal(e) =>
         throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, e)
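For clarity, a sketch of the schema this change produces. The nested key/value structs here are illustrative placeholders; the real schemas are read from the operator's state schema files in the checkpoint. The field names `value` and `count` mirror the aggregation test further down.

```scala
import org.apache.spark.sql.types.{IntegerType, LongType, StructType}

// Illustrative only: real key/value schemas come from the checkpoint.
val keySchema = new StructType().add("value", IntegerType)
val valueSchema = new StructType().add("count", LongType)

// Shape of the inferred source schema after this change.
val sourceSchema = new StructType()
  .add("key", keySchema)
  .add("value", valueSchema)
  .add("partition_id", IntegerType)
```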
StatePartitionReader.scala

@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.state

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
@@ -99,28 +99,18 @@ class StatePartitionReader(
     }
   }

-  private val joinedRow = new JoinedRow()
-
-  private def addMetadata(row: InternalRow): InternalRow = {
-    val metadataRow = new GenericInternalRow(
-      StateTable.METADATA_COLUMNS.map(_.name()).map {
-        case "_partition_id" => partition.partition.asInstanceOf[Any]
-      }.toArray
-    )
-    joinedRow.withLeft(row).withRight(metadataRow)
-  }
-
-  override def get(): InternalRow = addMetadata(current)
+  override def get(): InternalRow = current

   override def close(): Unit = {
     current = null
     store.abort()
   }

   private def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow)): InternalRow = {
-    val row = new GenericInternalRow(2)
+    val row = new GenericInternalRow(3)
     row.update(0, pair._1)
     row.update(1, pair._2)
+    row.update(2, partition.partition)
     row
   }
 }
StateTable.scala

@@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.read.ScanBuilder
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.state.StateStoreConf
-import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
+import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.ArrayImplicits._

@@ -69,18 +69,20 @@ class StateTable(
   override def properties(): util.Map[String, String] = Map.empty[String, String].asJava

   private def isValidSchema(schema: StructType): Boolean = {
-    if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value")) {
+    if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value", "partition_id")) {
       false
     } else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) {
       false
     } else if (!SchemaUtil.getSchemaAsDataType(schema, "value").isInstanceOf[StructType]) {
       false
+    } else if (!SchemaUtil.getSchemaAsDataType(schema, "partition_id").isInstanceOf[IntegerType]) {
+      false
     } else {
       true
     }
   }

-  override def metadataColumns(): Array[MetadataColumn] = METADATA_COLUMNS.toArray
+  override def metadataColumns(): Array[MetadataColumn] = Array.empty
 }

 /**
@@ -89,18 +91,4 @@ class StateTable(
  */
 object StateTable {
   private val CAPABILITY = Set(TableCapability.BATCH_READ).asJava
-
-  val METADATA_COLUMNS: Seq[MetadataColumn] = Seq(PartitionId)
-
-  private object PartitionId extends MetadataColumn {
-    override def name(): String = "_partition_id"
-
-    override def dataType(): DataType = IntegerType
-
-    override def isNullable: Boolean = false
-
-    override def comment(): String = {
-      "Represents an ID for a physical state partition this row belongs to."
-    }
-  }
 }
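With the MetadataColumn machinery removed, `partition_id` behaves like any other column of the table. A hedged before/after sketch of the reader-facing difference, reusing the hypothetical `stateDf` from the earlier sketch:

```scala
// Before this PR: the partition id was a hidden metadata column named
// "_partition_id" and did not appear in stateDf.schema unless selected explicitly.
// After this PR: it is part of the default schema and usable directly.
val evenPartitions = stateDf
  .selectExpr("key", "value", "partition_id")
  .where("partition_id % 2 = 0")  // same filter pattern the test suite exercises
```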
StreamStreamJoinStatePartitionReader.scala

@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.state

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, JoinedRow, Literal, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, Literal, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
@@ -148,18 +148,7 @@ class StreamStreamJoinStatePartitionReader(
     }
   }

-  private val joinedRow = new JoinedRow()
-
-  private def addMetadata(row: InternalRow): InternalRow = {
-    val metadataRow = new GenericInternalRow(
-      StateTable.METADATA_COLUMNS.map(_.name()).map {
-        case "_partition_id" => partition.partition.asInstanceOf[Any]
-      }.toArray
-    )
-    joinedRow.withLeft(row).withRight(metadataRow)
-  }
-
-  override def get(): InternalRow = addMetadata(current)
+  override def get(): InternalRow = current

   override def close(): Unit = {
     current = null
@@ -169,9 +158,10 @@ }
   }

   private def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow)): InternalRow = {
-    val row = new GenericInternalRow(2)
+    val row = new GenericInternalRow(3)
     row.update(0, pair._1)
     row.update(1, pair._2)
+    row.update(2, partition.partition)
     row
   }
 }
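Because the stream-stream join reader exposes state per join side, here is a hedged sketch of reading one side with the renamed column. The option key mirrors StateSourceOptions.JOIN_SIDE as exercised in the test suite below; the path is a placeholder and an active SparkSession `spark` is assumed.

```scala
// Hedged sketch: read the left side of a stream-stream join's state.
val leftJoinState = spark.read
  .format("statestore")
  .option("path", "/tmp/join-query/checkpoint")  // placeholder path
  .option("joinSide", "left")                    // StateSourceOptions.JOIN_SIDE
  .load()

// key.field0 mirrors the projection used in the test suite below.
leftJoinState
  .selectExpr("key.field0 AS key_0", "partition_id")
  .where("partition_id % 2 = 0")
  .show()
```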
StateDataSourceReadSuite.scala

@@ -687,7 +687,7 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
     }
   }

-  test("metadata column") {
+  test("partition_id column") {
     withTempDir { tempDir =>
       import testImplicits._
       val stream = MemoryStream[Int]
@@ -712,14 +712,11 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
         // skip version and operator ID to test out functionalities
         .load()

-      assert(!stateReadDf.schema.exists(_.name == "_partition_id"),
-        "metadata column should not be exposed until it is explicitly specified!")
-
       val numShufflePartitions = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS)

       val resultDf = stateReadDf
-        .selectExpr("key.value AS key_value", "value.count AS value_count", "_partition_id")
-        .where("_partition_id % 2 = 0")
+        .selectExpr("key.value AS key_value", "value.count AS value_count", "partition_id")
+        .where("partition_id % 2 = 0")

       // NOTE: This is a hash function of distribution for stateful operator.
       val hash = HashPartitioning(
@@ -738,17 +735,12 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
     }
   }

-  test("metadata column with stream-stream join") {
+  test("partition_id column with stream-stream join") {
     val numShufflePartitions = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS)

     withTempDir { tempDir =>
       runStreamStreamJoinQueryWithOneThousandInputs(tempDir.getAbsolutePath)

-      def assertPartitionIdColumnIsNotExposedByDefault(df: DataFrame): Unit = {
-        assert(!df.schema.exists(_.name == "_partition_id"),
-          "metadata column should not be exposed until it is explicitly specified!")
-      }
-
       def assertPartitionIdColumn(df: DataFrame): Unit = {
         // NOTE: This is a hash function of distribution for stateful operator.
         // stream-stream join uses the grouping key for the equality match in the join condition.
@@ -759,8 +751,8 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
           numShufflePartitions)
         val partIdExpr = hash.partitionIdExpression

-        val dfWithPartition = df.selectExpr("key.field0 As key_0", "_partition_id")
-          .where("_partition_id % 2 = 0")
+        val dfWithPartition = df.selectExpr("key.field0 As key_0", "partition_id")
+          .where("partition_id % 2 = 0")

         checkAnswer(dfWithPartition,
           Range.inclusive(2, 1000, 2).map { idx =>
@@ -778,8 +770,6 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
          .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
          .option(StateSourceOptions.JOIN_SIDE, side)
          .load()
-
-        assertPartitionIdColumnIsNotExposedByDefault(stateReaderForLeft)
         assertPartitionIdColumn(stateReaderForLeft)

         val stateReaderForKeyToNumValues = spark.read
@@ -789,7 +779,7 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
            s"$side-keyToNumValues")
          .load()

-        assertPartitionIdColumnIsNotExposedByDefault(stateReaderForKeyToNumValues)
+
         assertPartitionIdColumn(stateReaderForKeyToNumValues)

         val stateReaderForKeyWithIndexToValue = spark.read
@@ -799,7 +789,6 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
            s"$side-keyWithIndexToValue")
          .load()

-        assertPartitionIdColumnIsNotExposedByDefault(stateReaderForKeyWithIndexToValue)
         assertPartitionIdColumn(stateReaderForKeyWithIndexToValue)
       }

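The tests above validate `partition_id` against the planner's hash distribution. A hedged sketch of that check in isolation: the attribute name and type are illustrative, while `HashPartitioning.partitionIdExpression` is the same expression the tests evaluate.

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

// Expected partition for a grouping key: pmod(murmur3hash(key), numPartitions).
val groupingKey = AttributeReference("value", IntegerType)()  // illustrative key attribute
val hash = HashPartitioning(Seq(groupingKey), numPartitions = 5)
val partIdExpr = hash.partitionIdExpression

// A row returned by the state source should satisfy:
//   row.partition_id == <partIdExpr evaluated on that row's key>
```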