diff --git a/docs/structured-streaming-state-data-source.md b/docs/structured-streaming-state-data-source.md
index ae323f6b0c14..986699130669 100644
--- a/docs/structured-streaming-state-data-source.md
+++ b/docs/structured-streaming-state-data-source.md
@@ -96,9 +96,9 @@ Each row in the source has the following schema:
-| _partition_id | int | metadata column (hidden unless specified with SELECT) |
+| partition_id | int | |
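
With `partition_id` promoted to a regular column, a state read can select and filter on it directly, with no metadata-column handling. A hedged usage sketch (not part of the patch); the checkpoint location is a placeholder:

```scala
import org.apache.spark.sql.SparkSession

object StatePartitionIdExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("state-source-partition-id")
      .master("local[*]")
      .getOrCreate()

    // "statestore" is the short name registered by StateDataSource;
    // the checkpoint location below is a placeholder.
    val stateDf = spark.read
      .format("statestore")
      .option("path", "/tmp/checkpoint")
      .load()

    // partition_id is now an ordinary column, so it can be selected and
    // filtered like key and value.
    stateDf
      .selectExpr("key", "value", "partition_id")
      .where("partition_id % 2 = 0")
      .show()

    spark.stop()
  }
}
```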
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index 1192accaabef..1a8f444042c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DI
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
@@ -83,6 +83,7 @@ class StateDataSource extends TableProvider with DataSourceRegister {
new StructType()
.add("key", keySchema)
.add("value", valueSchema)
+ .add("partition_id", IntegerType)
} catch {
case NonFatal(e) =>
throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, e)
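
For reference, a minimal spark-shell style sketch of the schema this now yields, with placeholder key/value schemas standing in for the ones read from the checkpoint's state schema file:

```scala
import org.apache.spark.sql.types.{IntegerType, LongType, StructType}

// Placeholder key/value schemas; the real ones come from the checkpoint.
val keySchema = new StructType().add("value", IntegerType)
val valueSchema = new StructType().add("count", LongType)

// partition_id is now a third top-level field rather than a hidden
// metadata column.
val sourceSchema = new StructType()
  .add("key", keySchema)
  .add("value", valueSchema)
  .add("partition_id", IntegerType)

// sourceSchema.simpleString ==
//   "struct<key:struct<value:int>,value:struct<count:bigint>,partition_id:int>"
```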
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index 1e5f7216e8bf..ef8d7bf628bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.state
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
@@ -99,18 +99,7 @@ class StatePartitionReader(
}
}
- private val joinedRow = new JoinedRow()
-
- private def addMetadata(row: InternalRow): InternalRow = {
- val metadataRow = new GenericInternalRow(
- StateTable.METADATA_COLUMNS.map(_.name()).map {
- case "_partition_id" => partition.partition.asInstanceOf[Any]
- }.toArray
- )
- joinedRow.withLeft(row).withRight(metadataRow)
- }
-
- override def get(): InternalRow = addMetadata(current)
+ override def get(): InternalRow = current
override def close(): Unit = {
current = null
@@ -118,9 +107,10 @@ class StatePartitionReader(
}
private def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow)): InternalRow = {
- val row = new GenericInternalRow(2)
+ val row = new GenericInternalRow(3)
row.update(0, pair._1)
row.update(1, pair._2)
+ row.update(2, partition.partition)
row
}
}
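
A spark-shell style sketch of the new row layout, with `partitionId` standing in for `partition.partition`: each (key, value) pair becomes a single three-field row, so the previous JoinedRow-plus-metadata-row path is no longer needed.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}

// partitionId stands in for partition.partition in the real reader.
def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow), partitionId: Int): InternalRow = {
  val row = new GenericInternalRow(3)
  row.update(0, pair._1)     // key as a nested struct
  row.update(1, pair._2)     // value as a nested struct
  row.update(2, partitionId) // physical state partition this row came from
  row
}
```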
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
index 96c1c01cede2..824968e709ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.state.StateStoreConf
-import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
+import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ArrayImplicits._
@@ -69,18 +69,20 @@ class StateTable(
override def properties(): util.Map[String, String] = Map.empty[String, String].asJava
private def isValidSchema(schema: StructType): Boolean = {
- if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value")) {
+ if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value", "partition_id")) {
false
} else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) {
false
} else if (!SchemaUtil.getSchemaAsDataType(schema, "value").isInstanceOf[StructType]) {
false
+ } else if (!SchemaUtil.getSchemaAsDataType(schema, "partition_id").isInstanceOf[IntegerType]) {
+ false
} else {
true
}
}
- override def metadataColumns(): Array[MetadataColumn] = METADATA_COLUMNS.toArray
+ override def metadataColumns(): Array[MetadataColumn] = Array.empty
}
/**
@@ -89,18 +91,4 @@ class StateTable(
*/
object StateTable {
private val CAPABILITY = Set(TableCapability.BATCH_READ).asJava
-
- val METADATA_COLUMNS: Seq[MetadataColumn] = Seq(PartitionId)
-
- private object PartitionId extends MetadataColumn {
- override def name(): String = "_partition_id"
-
- override def dataType(): DataType = IntegerType
-
- override def isNullable: Boolean = false
-
- override def comment(): String = {
- "Represents an ID for a physical state partition this row belongs to."
- }
- }
}
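
A sketch of the equivalent validation written with plain StructType lookups, assuming the fixed three-column layout; SchemaUtil.getSchemaAsDataType is replaced with direct field access purely for illustration:

```scala
import org.apache.spark.sql.types.{IntegerType, StructType}

// The schema must be exactly (key: struct, value: struct, partition_id: int).
def isValidStateSchema(schema: StructType): Boolean = {
  schema.fieldNames.toSeq == Seq("key", "value", "partition_id") &&
    schema("key").dataType.isInstanceOf[StructType] &&
    schema("value").dataType.isInstanceOf[StructType] &&
    schema("partition_id").dataType == IntegerType
}
```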
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
index 26492f8790c4..d0dd6cb7d1b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.state
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, JoinedRow, Literal, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, Literal, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
@@ -148,18 +148,7 @@ class StreamStreamJoinStatePartitionReader(
}
}
- private val joinedRow = new JoinedRow()
-
- private def addMetadata(row: InternalRow): InternalRow = {
- val metadataRow = new GenericInternalRow(
- StateTable.METADATA_COLUMNS.map(_.name()).map {
- case "_partition_id" => partition.partition.asInstanceOf[Any]
- }.toArray
- )
- joinedRow.withLeft(row).withRight(metadataRow)
- }
-
- override def get(): InternalRow = addMetadata(current)
+ override def get(): InternalRow = current
override def close(): Unit = {
current = null
@@ -169,9 +158,10 @@ class StreamStreamJoinStatePartitionReader(
}
private def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow)): InternalRow = {
- val row = new GenericInternalRow(2)
+ val row = new GenericInternalRow(3)
row.update(0, pair._1)
row.update(1, pair._2)
+ row.update(2, partition.partition)
row
}
}
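
A hedged spark-shell style sketch for the stream-stream join case; the checkpoint path is a placeholder, but the point is that `partition_id` is read the same way as for other stateful operators:

```scala
// "joinSide" selects which side of the stream-stream join state to read.
val leftJoinState = spark.read
  .format("statestore")
  .option("path", "/tmp/checkpoint")
  .option("joinSide", "left")
  .load()

// partition_id is part of the regular schema for join state as well.
leftJoinState
  .select("key", "value", "partition_id")
  .where("partition_id % 2 = 0")
  .show()
```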
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
index 86c3ab70af68..c800168b507a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
@@ -687,7 +687,7 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
}
}
- test("metadata column") {
+ test("partition_id column") {
withTempDir { tempDir =>
import testImplicits._
val stream = MemoryStream[Int]
@@ -712,14 +712,11 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
// skip version and operator ID to test out functionalities
.load()
- assert(!stateReadDf.schema.exists(_.name == "_partition_id"),
- "metadata column should not be exposed until it is explicitly specified!")
-
val numShufflePartitions = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS)
val resultDf = stateReadDf
- .selectExpr("key.value AS key_value", "value.count AS value_count", "_partition_id")
- .where("_partition_id % 2 = 0")
+ .selectExpr("key.value AS key_value", "value.count AS value_count", "partition_id")
+ .where("partition_id % 2 = 0")
// NOTE: This is a hash function of distribution for stateful operator.
val hash = HashPartitioning(
@@ -738,17 +735,12 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
}
}
- test("metadata column with stream-stream join") {
+ test("partition_id column with stream-stream join") {
val numShufflePartitions = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS)
withTempDir { tempDir =>
runStreamStreamJoinQueryWithOneThousandInputs(tempDir.getAbsolutePath)
- def assertPartitionIdColumnIsNotExposedByDefault(df: DataFrame): Unit = {
- assert(!df.schema.exists(_.name == "_partition_id"),
- "metadata column should not be exposed until it is explicitly specified!")
- }
-
def assertPartitionIdColumn(df: DataFrame): Unit = {
// NOTE: This is a hash function of distribution for stateful operator.
// stream-stream join uses the grouping key for the equality match in the join condition.
@@ -759,8 +751,8 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
numShufflePartitions)
val partIdExpr = hash.partitionIdExpression
- val dfWithPartition = df.selectExpr("key.field0 As key_0", "_partition_id")
- .where("_partition_id % 2 = 0")
+ val dfWithPartition = df.selectExpr("key.field0 As key_0", "partition_id")
+ .where("partition_id % 2 = 0")
checkAnswer(dfWithPartition,
Range.inclusive(2, 1000, 2).map { idx =>
@@ -778,8 +770,6 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
.option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
.option(StateSourceOptions.JOIN_SIDE, side)
.load()
-
- assertPartitionIdColumnIsNotExposedByDefault(stateReaderForLeft)
assertPartitionIdColumn(stateReaderForLeft)
val stateReaderForKeyToNumValues = spark.read
@@ -789,7 +779,7 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
s"$side-keyToNumValues")
.load()
- assertPartitionIdColumnIsNotExposedByDefault(stateReaderForKeyToNumValues)
+
assertPartitionIdColumn(stateReaderForKeyToNumValues)
val stateReaderForKeyWithIndexToValue = spark.read
@@ -799,7 +789,6 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
s"$side-keyWithIndexToValue")
.load()
- assertPartitionIdColumnIsNotExposedByDefault(stateReaderForKeyWithIndexToValue)
assertPartitionIdColumn(stateReaderForKeyWithIndexToValue)
}
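
The tests above compare `partition_id` against the hash partitioning the stateful operator uses for its grouping key. A minimal spark-shell style sketch of that idea, with a hypothetical single-int grouping key and a placeholder partition count:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

// Hypothetical grouping key attribute and partition count; in the suite these
// come from the query's key schema and spark.sql.shuffle.partitions.
val keyAttr = AttributeReference("value", IntegerType)()
val numShufflePartitions = 5

// The same hash partitioning a stateful operator uses to distribute state;
// partitionIdExpression yields the expected partition id for a given key.
val hash = HashPartitioning(Seq(keyAttr), numShufflePartitions)
val partIdExpr = hash.partitionIdExpression
// Evaluating partIdExpr against each key should match the partition_id
// column read back from the state data source.
```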