@@ -5,6 +5,7 @@ org.apache.spark.sql.execution.datasources.noop.NoopDataSource
org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
org.apache.spark.sql.execution.datasources.v2.parquet.ParquetDataSourceV2
org.apache.spark.sql.execution.datasources.v2.text.TextDataSourceV2
org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceV2

@HyukjinKwon (Member) commented on Apr 9, 2020:

I came from the PR you pointed out. Why is it named "state"? Can a batch query use this source?

@HeartSaVioR (Contributor Author) commented:

"state" is the one of the "terms" of "structured streaming" (not actually tied to structured streaming but tied to recent streaming technology). It's being created and used from structured streaming, but there're some cases we want to modify the state "outside" of the streaming query, like changing schema, repartitioning, etc. This data source will allow "batch query" to do it. (So the data source is not even designed to use from streaming query by intention.)

@HyukjinKwon (Member) commented:

I see. So it's designed for batch queries over the state generated by structured streaming.
@HeartSaVioR, could I ask you to post a working example in the PR or JIRA description? I think one working example would clarify what this source/PR targets.

@HeartSaVioR (Contributor Author) commented on Apr 10, 2020:

I will do that once I get an actual reviewer who is willing to shepherd this issue - so far, the only request I got for this feature was to write an SPIP.

https://github.com/HeartSaVioR/spark-state-tools

The repository above contains the entire functionality (though it is tied to Spark 2.4, and the usage is a bit awkward because Spark doesn't expose the schema information) and an explanation.
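
When the schema information file is not present (the inferSchema path below throws UnsupportedOperationException in that case), the schema would presumably have to be supplied explicitly through the reader; a rough sketch, with hypothetical key/value field names and types:

import org.apache.spark.sql.types._

val stateSchema = new StructType()
  .add("key", new StructType().add("groupKey", StringType))
  .add("value", new StructType().add("count", LongType))

val stateDf = spark.read
  .format("state")
  .schema(stateSchema)  // explicit schema, since supportsExternalMetadata() is true
  .option("checkpointLocation", "/tmp/checkpoint")  // hypothetical checkpoint path
  .load()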

@HyukjinKwon (Member) commented:

I was thinking that showing an example could clarify the importance of this source and hopefully attract more review and attention. But okay, we can wait for the review first.

org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
@@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2.state

import org.apache.spark.sql.types.{DataType, StructType}

object SchemaUtil {
def getSchemaAsDataType(schema: StructType, fieldName: String): DataType = {
schema(schema.getFieldIndex(fieldName).get).dataType
}
}
@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2.state

import java.util

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.streaming.CommitLog
import org.apache.spark.sql.execution.streaming.state.{StateSchemaFileManager, StateStore, StateStoreId}
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class StateDataSourceV2 extends TableProvider with DataSourceRegister {

import StateDataSourceV2._

lazy val session: SparkSession = SparkSession.active

override def shortName(): String = "state"

override def getTable(
schema: StructType,
partitioning: Array[Transform],
properties: util.Map[String, String]): Table = {
val checkpointLocation = Option(properties.get(PARAM_CHECKPOINT_LOCATION)).orElse {
throw new AnalysisException(s"'$PARAM_CHECKPOINT_LOCATION' must be specified.")
}.get

val version = Option(properties.get(PARAM_VERSION)).map(_.toLong).orElse {
Some(getLastCommittedBatch(checkpointLocation))
}.get

val operatorId = Option(properties.get(PARAM_OPERATOR_ID)).map(_.toInt)
.orElse(Some(0)).get

val storeName = Option(properties.get(PARAM_STORE_NAME))
.orElse(Some(StateStoreId.DEFAULT_STORE_NAME)).get

val stateCheckpointLocation = new Path(checkpointLocation, "state")
new StateTable(session, schema, stateCheckpointLocation.toString, version, operatorId,
storeName)
}

override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
val checkpointLocation = Option(options.get(PARAM_CHECKPOINT_LOCATION)).orElse {
throw new AnalysisException(s"'$PARAM_CHECKPOINT_LOCATION' must be specified.")
}.get

val operatorId = Option(options.get(PARAM_OPERATOR_ID)).map(_.toInt)
.orElse(Some(0)).get

val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
val storeName = Option(options.get(PARAM_STORE_NAME))
.orElse(Some(StateStoreId.DEFAULT_STORE_NAME)).get

val stateCheckpointLocation = new Path(checkpointLocation, "state")
val storeId = new StateStoreId(stateCheckpointLocation.toString, operatorId, partitionId,
storeName)
val manager = new StateSchemaFileManager(storeId, session.sessionState.newHadoopConf())
if (manager.fileExist()) {
val (keySchema, valueSchema) = manager.readSchema()
new StructType()
.add("key", keySchema)
.add("value", valueSchema)
} else {
throw new UnsupportedOperationException("Schema information file doesn't exist - schema " +
"should be explicitly specified.")
}
}

private def getLastCommittedBatch(checkpointLocation: String): Long = {
val commitLog = new CommitLog(session, new Path(checkpointLocation, "commits").toString)
val lastCommittedBatchId = commitLog.getLatest() match {
case Some((lastId, _)) => lastId
case None => -1
}

if (lastCommittedBatchId < 0) {
throw new AnalysisException("No committed batch found.")
}

    // The state committed for batch N is stored as state store version N + 1.
    lastCommittedBatchId.toLong + 1
}

override def supportsExternalMetadata(): Boolean = true
}

object StateDataSourceV2 {
val PARAM_CHECKPOINT_LOCATION = "checkpointLocation"
val PARAM_VERSION = "version"
val PARAM_OPERATOR_ID = "operatorId"
val PARAM_STORE_NAME = "storeName"
}
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2.state

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.connector.read.PartitionReader
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

class StatePartitionReader(
storeConf: StateStoreConf,
hadoopConf: SerializableConfiguration,
partition: StateStoreInputPartition,
schema: StructType) extends PartitionReader[InternalRow] {

private val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
private val valueSchema = SchemaUtil.getSchemaAsDataType(schema, "value").asInstanceOf[StructType]

private lazy val iter = {
val stateStoreId = StateStoreId(partition.stateCheckpointRootLocation,
partition.operatorId, partition.partition, partition.storeName)
val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)

val store = StateStore.getReadOnly(stateStoreProviderId, keySchema, valueSchema,
indexOrdinal = None, version = partition.version, storeConf = storeConf,
hadoopConf = hadoopConf.value)

store.iterator().map(pair => unifyStateRowPair((pair.key, pair.value)))
}

private var current: InternalRow = _

override def next(): Boolean = {
if (iter.hasNext) {
current = iter.next()
true
} else {
current = null
false
}
}

override def get(): InternalRow = current

override def close(): Unit = {
current = null
}

private def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow)): InternalRow = {
val row = new GenericInternalRow(2)
row.update(0, pair._1)
row.update(1, pair._2)
row
}
}
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2.state

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.streaming.state.StateStoreConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

class StatePartitionReaderFactory(
storeConf: StateStoreConf,
hadoopConf: SerializableConfiguration,
schema: StructType) extends PartitionReaderFactory {

override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
val part = partition match {
case p: StateStoreInputPartition => p
      case e => throw new IllegalStateException("Expected StateStoreInputPartition but another " +
        s"type of partition was passed - $e")
}

new StatePartitionReader(storeConf, hadoopConf, part, schema)
}
}
@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2.state

import java.util.UUID

import scala.util.Try

import org.apache.hadoop.fs.{Path, PathFilter}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder}
import org.apache.spark.sql.execution.streaming.state.StateStoreConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

class StateScanBuilder(
session: SparkSession,
schema: StructType,
stateCheckpointRootLocation: String,
version: Long,
operatorId: Long,
storeName: String) extends ScanBuilder {
override def build(): Scan = new StateScan(session, schema, stateCheckpointRootLocation,
version, operatorId, storeName)
}

class StateStoreInputPartition(
val partition: Int,
val queryId: UUID,
val stateCheckpointRootLocation: String,
val version: Long,
val operatorId: Long,
val storeName: String) extends InputPartition

class StateScan(
session: SparkSession,
schema: StructType,
stateCheckpointRootLocation: String,
version: Long,
operatorId: Long,
storeName: String) extends Scan with Batch {

// A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
private val hadoopConfBroadcast = session.sparkContext.broadcast(
new SerializableConfiguration(session.sessionState.newHadoopConf()))

private val resolvedCpLocation = {
val checkpointPath = new Path(stateCheckpointRootLocation)
val fs = checkpointPath.getFileSystem(session.sessionState.newHadoopConf())
checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString
}

override def readSchema(): StructType = schema

override def planInputPartitions(): Array[InputPartition] = {
val fs = stateCheckpointPartitionsLocation.getFileSystem(hadoopConfBroadcast.value.value)
val partitions = fs.listStatus(stateCheckpointPartitionsLocation, new PathFilter() {
override def accept(path: Path): Boolean = {
fs.isDirectory(path) && Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0
}
})

if (partitions.headOption.isEmpty) {
Array.empty[InputPartition]
} else {
      // just a dummy query id because we are not actually running a streaming query
val queryId = UUID.randomUUID()

val partitionsSorted = partitions.sortBy(fs => fs.getPath.getName.toInt)
val partitionNums = partitionsSorted.map(_.getPath.getName.toInt)
      // assuming no duplicate numbers - they're directories, hence no duplicate names
val head = partitionNums.head
val tail = partitionNums(partitionNums.length - 1)
assert(head == 0, "Partition should start with 0")
assert((tail - head + 1) == partitionNums.length,
s"No continuous partitions in state: $partitionNums")

partitionNums.map {
pn => new StateStoreInputPartition(pn, queryId, stateCheckpointRootLocation,
version, operatorId, storeName)
}.toArray
}
}

override def createReaderFactory(): PartitionReaderFactory =
new StatePartitionReaderFactory(new StateStoreConf(session.sessionState.conf),
hadoopConfBroadcast.value, schema)

override def toBatch: Batch = this

override def description(): String = s"StateScan [stateCpLocation=$stateCheckpointRootLocation]" +
s"[version=$version][operatorId=$operatorId][storeName=$storeName]"

private def stateCheckpointPartitionsLocation: Path = {
new Path(resolvedCpLocation, s"$operatorId")
}
}