From 7f2b74eaef983f8cab0c340efb507e09d6db7a85 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Thu, 27 Jun 2019 21:05:17 +0900 Subject: [PATCH 1/5] [SPARK-28191][SS] New data source - state - reader part --- ...pache.spark.sql.sources.DataSourceRegister | 1 + .../datasources/v2/state/CheckpointUtil.scala | 109 ++++++ .../datasources/v2/state/SchemaUtil.scala | 25 ++ .../v2/state/StateDataSourceV2.scala | 71 ++++ .../v2/state/StatePartitionReader.scala | 71 ++++ .../state/StatePartitionReaderFactory.scala | 39 +++ .../datasources/v2/state/StateScan.scala | 112 +++++++ .../v2/state/StateSchemaExtractor.scala | 87 +++++ .../datasources/v2/state/StateTable.scala | 74 +++++ .../v2/state/StateDataSourceV2ReadSuite.scala | 162 +++++++++ .../v2/state/StateSchemaExtractorSuite.scala | 109 ++++++ .../sources/v2/state/StateStoreTestBase.scala | 311 ++++++++++++++++++ 12 files changed, 1171 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/CheckpointUtil.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/SchemaUtil.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceV2.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderFactory.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScan.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateSchemaExtractor.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateDataSourceV2ReadSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateSchemaExtractorSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateStoreTestBase.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index c0b8b270bab1f..7891d1c047eac 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,6 +5,7 @@ org.apache.spark.sql.execution.datasources.noop.NoopDataSource org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.v2.parquet.ParquetDataSourceV2 org.apache.spark.sql.execution.datasources.v2.text.TextDataSourceV2 +org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceV2 org.apache.spark.sql.execution.streaming.ConsoleSinkProvider org.apache.spark.sql.execution.streaming.sources.RateStreamProvider org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/CheckpointUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/CheckpointUtil.scala new file mode 100644 index 0000000000000..8c3256cac0d6c --- /dev/null +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/CheckpointUtil.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2.state + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileUtil, Path} + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata} + +/** + * Utilities for dealing with checkpoints, such as creating a savepoint. + */ +object CheckpointUtil { + + /** + * Creates a savepoint from an existing checkpoint. + * OffsetLog and CommitLog will be purged based on newLastBatchId. + * Use `additionalMetadataConf` to modify the metadata configuration: you may want to modify it + * when rescaling state or migrating the state format version. + * e.g. when rescaling, pass Map(SQLConf.SHUFFLE_PARTITIONS.key -> newShufflePartitions.toString) + * + * @param sparkSession spark session + * @param checkpointRoot the root path of the existing checkpoint + * @param newCheckpointRoot the root path of the new savepoint - the target directory should be empty + * @param newLastBatchId the new last batch ID - it must be a committed batch ID + * @param additionalMetadataConf the configuration to add to the existing metadata configuration + * @param excludeState whether to exclude the state directory + */ + def createSavePoint( + sparkSession: SparkSession, + checkpointRoot: String, + newCheckpointRoot: String, + newLastBatchId: Long, + additionalMetadataConf: Map[String, String], + excludeState: Boolean = false): Unit = { + val hadoopConf = sparkSession.sessionState.newHadoopConf() + + val src = new Path(resolve(hadoopConf, checkpointRoot)) + val srcFs = src.getFileSystem(hadoopConf) + val dst = new Path(resolve(hadoopConf, newCheckpointRoot)) + val dstFs = dst.getFileSystem(hadoopConf) + + if (dstFs.exists(dst) && dstFs.listFiles(dst, false).hasNext) { + throw new IllegalArgumentException("Destination directory should be empty.") + } + + dstFs.mkdirs(dst) + + // copy the content of the src directory to the dst directory + srcFs.listStatus(src).foreach { fs => + val path = fs.getPath + val fileName = path.getName + if (fileName == "state" && excludeState) { + // pass + } else { + FileUtil.copy(srcFs, path, dstFs, new Path(dst, fileName), + false, false, hadoopConf) + } + } + + val offsetLog = new OffsetSeqLog(sparkSession, new Path(dst, "offsets").toString) + val logForBatch = offsetLog.get(newLastBatchId) match { + case Some(log) => log + case None => throw new IllegalStateException(s"Offset log for batch $newLastBatchId does not exist") + } + + val newMetadata = logForBatch.metadata match { + case Some(md) => + val newMap = md.conf ++ additionalMetadataConf + Some(md.copy(conf = newMap)) + case None =>
Some(OffsetSeqMetadata(conf = additionalMetadataConf)) + } + + val newLogForBatch = logForBatch.copy(metadata = newMetadata) + + // we will restart from last batch + 1: overwrite the last batch with new configuration + offsetLog.purgeAfter(newLastBatchId - 1) + offsetLog.add(newLastBatchId, newLogForBatch) + + val commitLog = new CommitLog(sparkSession, new Path(dst, "commits").toString) + commitLog.purgeAfter(newLastBatchId) + + // state doesn't expose purge mechanism as its interface + // assuming state would work with overwriting batch files when it replays previous batch + } + + private def resolve(hadoopConf: Configuration, cpLocation: String): String = { + val checkpointPath = new Path(cpLocation) + val fs = checkpointPath.getFileSystem(hadoopConf) + checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/SchemaUtil.scala new file mode 100644 index 0000000000000..5dcbdbb89e90d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/SchemaUtil.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.state + +import org.apache.spark.sql.types.{DataType, StructType} + +object SchemaUtil { + def getSchemaAsDataType(schema: StructType, fieldName: String): DataType = { + schema(schema.getFieldIndex(fieldName).get).dataType + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceV2.scala new file mode 100644 index 0000000000000..cbed217d105a1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceV2.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
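A brief usage sketch of CheckpointUtil.createSavePoint above, assuming an existing checkpoint at /tmp/query/checkpoint with batch 10 already committed; the paths and the new shuffle-partition count are placeholders, not values taken from this patch.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.v2.state.CheckpointUtil
import org.apache.spark.sql.internal.SQLConf

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Copy the checkpoint into a fresh directory, keep batch 10 as the last batch, and bump
// the shuffle partition count so that a restarted query picks up the new partition count.
CheckpointUtil.createSavePoint(
  spark,
  checkpointRoot = "/tmp/query/checkpoint",               // existing checkpoint (placeholder)
  newCheckpointRoot = "/tmp/query/checkpoint-rescaled",   // must be empty (placeholder)
  newLastBatchId = 10L,                                   // must be a committed batch ID
  additionalMetadataConf = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "50"),
  excludeState = false)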
+ */ +package org.apache.spark.sql.execution.datasources.v2.state + +import java.util +import java.util.Map + +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.connector.catalog.{Table, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.execution.streaming.state.StateStoreId +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class StateDataSourceV2 extends TableProvider with DataSourceRegister { + + import StateDataSourceV2._ + + lazy val session = SparkSession.active + + override def shortName(): String = "state" + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + val checkpointLocation = Option(properties.get(PARAM_CHECKPOINT_LOCATION)).orElse { + throw new AnalysisException(s"'$PARAM_CHECKPOINT_LOCATION' must be specified.") + }.get + + val version = Option(properties.get(PARAM_VERSION)).map(_.toInt).orElse { + throw new AnalysisException(s"'$PARAM_VERSION' must be specified.") + }.get + + val operatorId = Option(properties.get(PARAM_OPERATOR_ID)).map(_.toInt).orElse { + throw new AnalysisException(s"'$PARAM_OPERATOR_ID' must be specified.") + }.get + + val storeName = Option(properties.get(PARAM_STORE_NAME)) + .orElse(Some(StateStoreId.DEFAULT_STORE_NAME)).get + + new StateTable(session, schema, checkpointLocation, version, operatorId, storeName) + } + + override def inferSchema(options: CaseInsensitiveStringMap): StructType = + throw new UnsupportedOperationException("Schema should be explicitly specified.") + + override def supportsExternalMetadata(): Boolean = true +} + +object StateDataSourceV2 { + val PARAM_CHECKPOINT_LOCATION = "checkpointLocation" + val PARAM_VERSION = "version" + val PARAM_OPERATOR_ID = "operatorId" + val PARAM_STORE_NAME = "storeName" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala new file mode 100644 index 0000000000000..91fa102387c00 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
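As a usage sketch of the options defined above, mirroring the read path exercised in the test suite later in this patch; the checkpoint path and the key/value schema are placeholders and must match the state written by the actual streaming query.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, LongType, StructType}

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// The table expects a top-level schema of exactly two struct fields: "key" and "value".
val keySchema = new StructType().add("groupKey", IntegerType)
val valueSchema = new StructType()
  .add("cnt", LongType)
  .add("sum", LongType)
  .add("max", IntegerType)
  .add("min", IntegerType)
val stateSchema = new StructType().add("key", keySchema).add("value", valueSchema)

val stateDf = spark.read
  .format("state")                                              // shortName() registered above
  .schema(stateSchema)
  .option("checkpointLocation", "/tmp/query/checkpoint/state")  // state root under the checkpoint
  .option("version", 2)                                         // last committed batch ID + 1
  .option("operatorId", 0)
  .load()

stateDf.selectExpr("key.groupKey", "value.cnt", "value.sum").show()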
+ */ +package org.apache.spark.sql.execution.datasources.v2.state + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProviderId} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +class StatePartitionReader( + storeConf: StateStoreConf, + hadoopConf: SerializableConfiguration, + partition: StateStoreInputPartition, + schema: StructType) extends PartitionReader[InternalRow] { + + private val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] + private val valueSchema = SchemaUtil.getSchemaAsDataType(schema, "value").asInstanceOf[StructType] + + private lazy val iter = { + val stateStoreId = StateStoreId(partition.stateCheckpointRootLocation, + partition.operatorId, partition.partition, partition.storeName) + val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId) + + val store = StateStore.get(stateStoreProviderId, keySchema, valueSchema, + indexOrdinal = None, version = partition.version, storeConf = storeConf, + hadoopConf = hadoopConf.value) + + store.iterator().map(pair => unifyStateRowPair((pair.key, pair.value))) + } + + private var current: InternalRow = _ + + override def next(): Boolean = { + if (iter.hasNext) { + current = iter.next() + true + } else { + current = null + false + } + } + + override def get(): InternalRow = current + + override def close(): Unit = { + current = null + } + + private def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow)): InternalRow = { + val row = new GenericInternalRow(2) + row.update(0, pair._1) + row.update(1, pair._2) + row + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderFactory.scala new file mode 100644 index 0000000000000..f4cab378aa8eb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderFactory.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.state + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} +import org.apache.spark.sql.execution.streaming.state.StateStoreConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +class StatePartitionReaderFactory( + storeConf: StateStoreConf, + hadoopConf: SerializableConfiguration, + schema: StructType) extends PartitionReaderFactory { + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val part = partition match { + case p: StateStoreInputPartition => p + case e => throw new IllegalStateException("Expected StateStoreInputPartition but a different " + + s"type of partition was passed - $e") + } + + new StatePartitionReader(storeConf, hadoopConf, part, schema) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScan.scala new file mode 100644 index 0000000000000..9f4c9b45ba5f2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScan.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.spark.sql.execution.datasources.v2.state + +import java.util.UUID + +import scala.util.Try + +import org.apache.hadoop.fs.{Path, PathFilter} + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder} +import org.apache.spark.sql.execution.streaming.state.StateStoreConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +class StateScanBuilder( + session: SparkSession, + schema: StructType, + stateCheckpointRootLocation: String, + version: Long, + operatorId: Long, + storeName: String) extends ScanBuilder { + override def build(): Scan = new StateScan(session, schema, stateCheckpointRootLocation, + version, operatorId, storeName) +} + +class StateStoreInputPartition( + val partition: Int, + val queryId: UUID, + val stateCheckpointRootLocation: String, + val version: Long, + val operatorId: Long, + val storeName: String) extends InputPartition + +class StateScan( + session: SparkSession, + schema: StructType, + stateCheckpointRootLocation: String, + version: Long, + operatorId: Long, + storeName: String) extends Scan with Batch { + + // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it + private val hadoopConfBroadcast = session.sparkContext.broadcast( + new SerializableConfiguration(session.sessionState.newHadoopConf())) + + private val resolvedCpLocation = { + val checkpointPath = new Path(stateCheckpointRootLocation) + val fs = checkpointPath.getFileSystem(session.sessionState.newHadoopConf()) + checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString + } + + override def readSchema(): StructType = schema + + override def planInputPartitions(): Array[InputPartition] = { + val fs = stateCheckpointPartitionsLocation.getFileSystem(hadoopConfBroadcast.value.value) + val partitions = fs.listStatus(stateCheckpointPartitionsLocation, new PathFilter() { + override def accept(path: Path): Boolean = { + fs.isDirectory(path) && Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0 + } + }) + + if (partitions.headOption.isEmpty) { + Array.empty[InputPartition] + } else { + // just a dummy query id because we are actually not running streaming query + val queryId = UUID.randomUUID() + + val partitionsSorted = partitions.sortBy(fs => fs.getPath.getName.toInt) + val partitionNums = partitionsSorted.map(_.getPath.getName.toInt) + // assuming no same number - they're directories hence no same name + val head = partitionNums.head + val tail = partitionNums(partitionNums.length - 1) + assert(head == 0, "Partition should start with 0") + assert((tail - head + 1) == partitionNums.length, + s"No continuous partitions in state: $partitionNums") + + partitionNums.map { + pn => new StateStoreInputPartition(pn, queryId, stateCheckpointRootLocation, + version, operatorId, storeName) + }.toArray + } + } + + override def createReaderFactory(): PartitionReaderFactory = + new StatePartitionReaderFactory(new StateStoreConf(session.sessionState.conf), + hadoopConfBroadcast.value, schema) + + override def toBatch: Batch = this + + override def description(): String = s"StateScan [stateCpLocation=$stateCheckpointRootLocation]" + + s"[version=$version][operatorId=$operatorId][storeName=$storeName]" + + private def stateCheckpointPartitionsLocation: Path = { + new Path(resolvedCpLocation, s"$operatorId") + } +} diff --git 
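The partition planning above discovers partitions by listing numeric directory names under <resolved state checkpoint root>/<operatorId> and requires them to be contiguous starting at 0. A standalone sketch of that invariant check, for illustration only (not part of the data source):

// Partition directories are expected to be named 0, 1, ..., N-1 with no gaps or duplicates.
def assertContiguousPartitions(partitionNums: Seq[Int]): Unit = {
  if (partitionNums.nonEmpty) {
    val sorted = partitionNums.sorted
    assert(sorted.head == 0, "Partition should start with 0")
    assert(sorted.last - sorted.head + 1 == sorted.length,
      s"Partition numbers are not contiguous: $sorted")
  }
}

assertContiguousPartitions(Seq(2, 0, 1))   // passes: 0, 1, 2
// assertContiguousPartitions(Seq(0, 2))   // would fail: partition 1 is missing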
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateSchemaExtractor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateSchemaExtractor.scala new file mode 100644 index 0000000000000..b470b377b4192 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateSchemaExtractor.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.state + +import java.util.UUID + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.execution.datasources.v2.state.StateSchemaExtractor.{StateKind, StateSchemaInfo} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + +/** + * This class extracts the state schema and its format version by analyzing + * the streaming query. The query should contain its stateful operators but exclude the sink(s). + * + * Note that it only returns schemas which this class knows how to extract, so the number of + * stateful operators in the given query may not match the number of returned schema entries. + */ +class StateSchemaExtractor(spark: SparkSession) extends Logging { + + def extract(query: DataFrame): Seq[StateSchemaInfo] = { + require(query.isStreaming, "Given query is not a streaming query!") + + val queryExecution = new IncrementalExecution(spark, query.logicalPlan, + OutputMode.Update(), "", UUID.randomUUID(), UUID.randomUUID(), + 0, OffsetSeqMetadata()) + + // TODO: handle Streaming Join (if possible), etc.
+ queryExecution.executedPlan.collect { + case store: StateStoreSaveExec => + val stateFormatVersion = store.stateFormatVersion + val keySchema = store.keyExpressions.toStructType + val valueSchema = store.stateManager.getStateValueSchema + store.stateInfo match { + case Some(stInfo) => + val operatorId = stInfo.operatorId + StateSchemaInfo(operatorId, StateKind.StreamingAggregation, + stateFormatVersion, keySchema, valueSchema) + + case None => throw new IllegalStateException("State information not set!") + } + + case store: FlatMapGroupsWithStateExec => + val stateFormatVersion = store.stateFormatVersion + val keySchema = store.groupingAttributes.toStructType + val valueSchema = store.stateManager.stateSchema + store.stateInfo match { + case Some(stInfo) => + val operatorId = stInfo.operatorId + StateSchemaInfo(operatorId, StateKind.FlatMapGroupsWithState, + stateFormatVersion, keySchema, valueSchema) + + case None => throw new IllegalStateException("State information not set!") + } + } + } + +} + +object StateSchemaExtractor { + object StateKind extends Enumeration { + val StreamingAggregation, StreamingJoin, FlatMapGroupsWithState = Value + } + + case class StateSchemaInfo( + opId: Long, + stateKind: StateKind.Value, + formatVersion: Int, + keySchema: StructType, + valueSchema: StructType) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala new file mode 100644 index 0000000000000..79688b1fe685f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
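A hedged sketch of how the extractor above could be used: build a streaming DataFrame with a stateful operator (without starting the query) and inspect the reported schema entries. The rate source and the aggregation are illustrative only.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.v2.state.StateSchemaExtractor
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Any streaming DataFrame with a stateful operator works; the source here is illustrative.
val streamingDf = spark.readStream
  .format("rate")
  .load()
  .groupBy(($"value" % 10).as("groupKey"))
  .agg(count("*").as("cnt"), sum("value").as("sum"))

new StateSchemaExtractor(spark).extract(streamingDf).foreach { info =>
  println(s"operator=${info.opId} kind=${info.stateKind} formatVersion=${info.formatVersion}")
  println(s"key schema:\n${info.keySchema.treeString}")
  println(s"value schema:\n${info.valueSchema.treeString}")
}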
+ */ +package org.apache.spark.sql.execution.datasources.v2.state + +import java.util + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability} +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class StateTable( + session: SparkSession, + override val schema: StructType, + checkpointLocation: String, + version: Int, + operatorId: Int, + storeName: String) + extends Table with SupportsRead { + + import StateTable._ + + if (!isValidSchema(schema)) { + throw new AnalysisException("The fields of schema should be 'key' and 'value', " + + "and each field should have corresponding fields (they should be a StructType)") + } + + override def name(): String = + s"state-table-cp-$checkpointLocation-ver-$version-operator-$operatorId-store-$storeName" + + override def capabilities(): util.Set[TableCapability] = CAPABILITY + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = + new StateScanBuilder(session, schema, checkpointLocation, version, operatorId, storeName) + + override def properties(): util.Map[String, String] = Map( + "checkpointLocation" -> checkpointLocation, + "version" -> version.toString, + "operatorId" -> operatorId.toString, + "storeName" -> storeName).asJava + + private def isValidSchema(schema: StructType): Boolean = { + if (schema.fieldNames.toSeq != Seq("key", "value")) { + false + } else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) { + false + } else if (!SchemaUtil.getSchemaAsDataType(schema, "value").isInstanceOf[StructType]) { + false + } else { + true + } + } +} + +object StateTable { + private val CAPABILITY = Set(TableCapability.BATCH_READ).asJava +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateDataSourceV2ReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateDataSourceV2ReadSuite.scala new file mode 100644 index 0000000000000..e06fa9ed810cc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateDataSourceV2ReadSuite.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2.state + +import java.io.File + +import org.scalatest.{Assertions, BeforeAndAfterAll} + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceV2 +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.internal.SQLConf + +class StateDataSourceV2ReadSuite + extends StateStoreTestBase + with BeforeAndAfterAll + with Assertions { + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() + } + + test("reading state from simple aggregation - state format version 1") { + withSQLConf(Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "1"): _*) { + withTempDir { tempDir => + runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + + val stateSchema = getSchemaForLargeDataStreamingAggregationQuery(1) + + val operatorId = 0 + val batchId = 1 + + val stateReadDf = spark.read + .format("state") + .schema(stateSchema) + .option(StateDataSourceV2.PARAM_CHECKPOINT_LOCATION, + new File(tempDir, "state").getAbsolutePath) + .option(StateDataSourceV2.PARAM_VERSION, batchId + 1) + .option(StateDataSourceV2.PARAM_OPERATOR_ID, operatorId) + .load() + + logInfo(s"Schema: ${stateReadDf.schema.treeString}") + + checkAnswer( + stateReadDf + .selectExpr("key.groupKey AS key_groupKey", "value.groupKey AS value_groupKey", + "value.cnt AS value_cnt", "value.sum AS value_sum", "value.max AS value_max", + "value.min AS value_min"), + Seq( + Row(0, 0, 4, 60, 30, 0), // 0, 10, 20, 30 + Row(1, 1, 4, 64, 31, 1), // 1, 11, 21, 31 + Row(2, 2, 4, 68, 32, 2), // 2, 12, 22, 32 + Row(3, 3, 4, 72, 33, 3), // 3, 13, 23, 33 + Row(4, 4, 4, 76, 34, 4), // 4, 14, 24, 34 + Row(5, 5, 4, 80, 35, 5), // 5, 15, 25, 35 + Row(6, 6, 4, 84, 36, 6), // 6, 16, 26, 36 + Row(7, 7, 4, 88, 37, 7), // 7, 17, 27, 37 + Row(8, 8, 4, 92, 38, 8), // 8, 18, 28, 38 + Row(9, 9, 4, 96, 39, 9) // 9, 19, 29, 39 + ) + ) + } + } + } + + test("reading state from simple aggregation - state format version 2") { + withSQLConf(Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2"): _*) { + withTempDir { tempDir => + runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + + val stateSchema = getSchemaForLargeDataStreamingAggregationQuery(2) + + val operatorId = 0 + val batchId = 1 + + val stateReadDf = spark.read + .format("state") + .schema(stateSchema) + .option(StateDataSourceV2.PARAM_CHECKPOINT_LOCATION, + new File(tempDir, "state").getAbsolutePath) + .option(StateDataSourceV2.PARAM_VERSION, batchId + 1) + .option(StateDataSourceV2.PARAM_OPERATOR_ID, operatorId) + .load() + + logInfo(s"Schema: ${stateReadDf.schema.treeString}") + + checkAnswer( + stateReadDf + .selectExpr("key.groupKey AS key_groupKey", "value.cnt AS value_cnt", + "value.sum AS value_sum", "value.max AS value_max", "value.min AS value_min"), + Seq( + Row(0, 4, 60, 30, 0), // 0, 10, 20, 30 + Row(1, 4, 64, 31, 1), // 1, 11, 21, 31 + Row(2, 4, 68, 32, 2), // 2, 12, 22, 32 + Row(3, 4, 72, 33, 3), // 3, 13, 23, 33 + Row(4, 4, 76, 34, 4), // 4, 14, 24, 34 + Row(5, 4, 80, 35, 5), // 5, 15, 25, 35 + Row(6, 4, 84, 36, 6), // 6, 16, 26, 36 + Row(7, 4, 88, 37, 7), // 7, 17, 27, 37 + Row(8, 4, 92, 38, 8), // 8, 18, 28, 38 + Row(9, 4, 96, 39, 9) // 9, 19, 29, 39 + ) + ) + } + } + } + + test("reading state from simple aggregation - composite key") { + withSQLConf(Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2"): _*) { + withTempDir { tempDir => + 
runCompositeKeyStreamingAggregationQuery(tempDir.getAbsolutePath) + + val stateSchema = getSchemaForCompositeKeyStreamingAggregationQuery(2) + + val operatorId = 0 + val batchId = 1 + + val stateReadDf = spark.read + .format("state") + .schema(stateSchema) + .option(StateDataSourceV2.PARAM_CHECKPOINT_LOCATION, + new File(tempDir, "state").getAbsolutePath) + .option(StateDataSourceV2.PARAM_VERSION, batchId + 1) + .option(StateDataSourceV2.PARAM_OPERATOR_ID, operatorId) + .load() + + logInfo(s"Schema: ${stateReadDf.schema.treeString}") + + checkAnswer( + stateReadDf + .selectExpr("key.groupKey AS key_groupKey", "key.fruit AS key_fruit", + "value.cnt AS value_cnt", "value.sum AS value_sum", "value.max AS value_max", + "value.min AS value_min"), + Seq( + Row(0, "Apple", 2, 6, 6, 0), + Row(1, "Banana", 2, 8, 7, 1), + Row(0, "Strawberry", 2, 10, 8, 2), + Row(1, "Apple", 2, 12, 9, 3), + Row(0, "Banana", 2, 14, 10, 4), + Row(1, "Strawberry", 1, 5, 5, 5) + ) + ) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateSchemaExtractorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateSchemaExtractorSuite.scala new file mode 100644 index 0000000000000..fe28eb121509a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateSchemaExtractorSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2.state + +import org.scalatest.{Assertions, BeforeAndAfterAll} + +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.execution.datasources.v2.state.{SchemaUtil, StateSchemaExtractor} +import org.apache.spark.sql.execution.datasources.v2.state.StateSchemaExtractor.StateKind +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} + +class StateSchemaExtractorSuite + extends StateStoreTestBase + with BeforeAndAfterAll + with Assertions { + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() + } + + Seq(1, 2).foreach { ver => + test(s"extract schema from streaming aggregation query - state format v$ver") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> ver.toString) { + val aggregated = getCompositeKeyStreamingAggregationQuery + + val stateSchema = getSchemaForCompositeKeyStreamingAggregationQuery(ver) + val expectedKeySchema = SchemaUtil.getSchemaAsDataType(stateSchema, "key") + .asInstanceOf[StructType] + val expectedValueSchema = SchemaUtil.getSchemaAsDataType(stateSchema, "value") + .asInstanceOf[StructType] + + val schemaInfos = new StateSchemaExtractor(spark).extract(aggregated.toDF()) + assert(schemaInfos.length === 1) + val schemaInfo = schemaInfos.head + assert(schemaInfo.opId === 0) + assert(schemaInfo.formatVersion === ver) + assert(schemaInfo.stateKind === StateKind.StreamingAggregation) + + assert(compareSchemaWithoutName(schemaInfo.keySchema, expectedKeySchema), + s"Even without column names, ${schemaInfo.keySchema} did not equal $expectedKeySchema") + assert(compareSchemaWithoutName(schemaInfo.valueSchema, expectedValueSchema), + s"Even without column names, ${schemaInfo.valueSchema} did not equal " + + s"$expectedValueSchema") + } + } + } + + Seq(1, 2).foreach { ver => + test(s"extract schema from flatMapGroupsWithState query - state format v$ver") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> ver.toString) { + // This is borrowed from StateStoreTest, runFlatMapGroupsWithStateQuery + val aggregated = getFlatMapGroupsWithStateQuery + + val expectedKeySchema = new StructType().add("value", StringType, nullable = true) + + val expectedValueSchema = if (ver == 1) { + Encoders.product[SessionInfo].schema + .add("timeoutTimestamp", IntegerType, nullable = false) + } else { + // ver == 2 + new StructType() + .add("groupState", Encoders.product[SessionInfo].schema) + .add("timeoutTimestamp", LongType, nullable = false) + } + + val schemaInfos = new StateSchemaExtractor(spark).extract(aggregated.toDF()) + assert(schemaInfos.length === 1) + val schemaInfo = schemaInfos.head + assert(schemaInfo.opId === 0) + assert(schemaInfo.stateKind === StateKind.FlatMapGroupsWithState) + assert(schemaInfo.formatVersion === ver) + + assert(compareSchemaWithoutName(schemaInfo.keySchema, expectedKeySchema), + s"Even without column names, ${schemaInfo.keySchema} did not equal $expectedKeySchema") + assert(compareSchemaWithoutName(schemaInfo.valueSchema, expectedValueSchema), + s"Even without column names, ${schemaInfo.valueSchema} did not equal " + + s"$expectedValueSchema") + } + } + } + + private def compareSchemaWithoutName(s1: StructType, s2: StructType): Boolean = { + if (s1.length != s2.length) { + false + } else { + s1.zip(s2).forall { case (column1, column2) => + column1.dataType == column2.dataType && column1.nullable == 
column2.nullable + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateStoreTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateStoreTestBase.scala new file mode 100644 index 0000000000000..232f3e06d118e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateStoreTestBase.scala @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.state + +import java.io.File +import java.sql.Timestamp + +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} +import org.apache.spark.util.Utils + +trait StateStoreTestBase extends StreamTest { + import testImplicits._ + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() + } + + protected def withTempCheckpoints(body: (File, File) => Unit) { + val src = Utils.createTempDir(namePrefix = "streaming.old") + val tmp = Utils.createTempDir(namePrefix = "streaming.new") + try { + body(src, tmp) + } finally { + Utils.deleteRecursively(src) + Utils.deleteRecursively(tmp) + } + } + + protected def runCompositeKeyStreamingAggregationQuery( + checkpointRoot: String): Unit = { + val inputData = MemoryStream[Int] + val aggregated = getCompositeKeyStreamingAggregationQuery(inputData) + + testStream(aggregated, OutputMode.Update)( + StartStream(checkpointLocation = checkpointRoot), + // batch 0 + AddData(inputData, 0 to 5: _*), + CheckLastBatch( + (0, "Apple", 1, 0, 0, 0), + (1, "Banana", 1, 1, 1, 1), + (0, "Strawberry", 1, 2, 2, 2), + (1, "Apple", 1, 3, 3, 3), + (0, "Banana", 1, 4, 4, 4), + (1, "Strawberry", 1, 5, 5, 5) + ), + // batch 1 + AddData(inputData, 6 to 10: _*), + // state also contains (1, "Strawberry", 1, 5, 5, 5) but not updated here + CheckLastBatch( + (0, "Apple", 2, 6, 6, 0), // 0, 6 + (1, "Banana", 2, 8, 7, 1), // 1, 7 + (0, "Strawberry", 2, 10, 8, 2), // 2, 8 + (1, "Apple", 2, 12, 9, 3), // 3, 9 + (0, "Banana", 2, 14, 10, 4) // 4, 10 + ), + StopStream, + StartStream(checkpointLocation = checkpointRoot), + // batch 2 + AddData(inputData, 3, 2, 1), + CheckLastBatch( + (1, "Banana", 3, 9, 7, 1), // 1, 7, 1 + (0, "Strawberry", 3, 12, 8, 2), // 2, 8, 2 + (1, "Apple", 3, 15, 9, 3) // 3, 9, 3 + ) + ) + } + + protected def getCompositeKeyStreamingAggregationQuery + : Dataset[(Int, String, Long, Long, Int, Int)] = { + getCompositeKeyStreamingAggregationQuery(MemoryStream[Int]) + } + + protected def 
getCompositeKeyStreamingAggregationQuery( + inputData: MemoryStream[Int]): Dataset[(Int, String, Long, Long, Int, Int)] = { + inputData.toDF() + .selectExpr("value", "value % 2 AS groupKey", + "(CASE value % 3 WHEN 0 THEN 'Apple' WHEN 1 THEN 'Banana' ELSE 'Strawberry' END) AS fruit") + .groupBy($"groupKey", $"fruit") + .agg( + count("*").as("cnt"), + sum("value").as("sum"), + max("value").as("max"), + min("value").as("min") + ) + .as[(Int, String, Long, Long, Int, Int)] + } + + protected def getSchemaForCompositeKeyStreamingAggregationQuery( + formatVersion: Int): StructType = { + val stateKeySchema = new StructType() + .add("groupKey", IntegerType) + .add("fruit", StringType, nullable = false) + + var stateValueSchema = formatVersion match { + case 1 => + new StructType().add("groupKey", IntegerType).add("fruit", StringType, nullable = false) + case 2 => new StructType() + case v => throw new IllegalArgumentException(s"Not valid format version $v") + } + + stateValueSchema = stateValueSchema + .add("cnt", LongType, nullable = false) + .add("sum", LongType) + .add("max", IntegerType) + .add("min", IntegerType) + + new StructType() + .add("key", stateKeySchema) + .add("value", stateValueSchema) + } + + protected def runLargeDataStreamingAggregationQuery( + checkpointRoot: String): Unit = { + val inputData = MemoryStream[Int] + val aggregated = getLargeDataStreamingAggregationQuery(inputData) + + // check with more data - leverage full partitions + testStream(aggregated, OutputMode.Update)( + StartStream(checkpointLocation = checkpointRoot), + // batch 0 + AddData(inputData, 0 until 20: _*), + CheckLastBatch( + (0, 2, 10, 10, 0), // 0, 10 + (1, 2, 12, 11, 1), // 1, 11 + (2, 2, 14, 12, 2), // 2, 12 + (3, 2, 16, 13, 3), // 3, 13 + (4, 2, 18, 14, 4), // 4, 14 + (5, 2, 20, 15, 5), // 5, 15 + (6, 2, 22, 16, 6), // 6, 16 + (7, 2, 24, 17, 7), // 7, 17 + (8, 2, 26, 18, 8), // 8, 18 + (9, 2, 28, 19, 9) // 9, 19 + ), + // batch 1 + AddData(inputData, 20 until 40: _*), + CheckLastBatch( + (0, 4, 60, 30, 0), // 0, 10, 20, 30 + (1, 4, 64, 31, 1), // 1, 11, 21, 31 + (2, 4, 68, 32, 2), // 2, 12, 22, 32 + (3, 4, 72, 33, 3), // 3, 13, 23, 33 + (4, 4, 76, 34, 4), // 4, 14, 24, 34 + (5, 4, 80, 35, 5), // 5, 15, 25, 35 + (6, 4, 84, 36, 6), // 6, 16, 26, 36 + (7, 4, 88, 37, 7), // 7, 17, 27, 37 + (8, 4, 92, 38, 8), // 8, 18, 28, 38 + (9, 4, 96, 39, 9) // 9, 19, 29, 39 + ), + StopStream, + StartStream(checkpointLocation = checkpointRoot), + // batch 2 + AddData(inputData, 0, 1, 2), + CheckLastBatch( + (0, 5, 60, 30, 0), // 0, 10, 20, 30, 0 + (1, 5, 65, 31, 1), // 1, 11, 21, 31, 1 + (2, 5, 70, 32, 2) // 2, 12, 22, 32, 2 + ) + ) + } + + protected def getLargeDataStreamingAggregationQuery: Dataset[(Int, Long, Long, Int, Int)] = { + getLargeDataStreamingAggregationQuery(MemoryStream[Int]) + } + + protected def getLargeDataStreamingAggregationQuery( + inputData: MemoryStream[Int]): Dataset[(Int, Long, Long, Int, Int)] = { + inputData.toDF() + .selectExpr("value", "value % 10 AS groupKey") + .groupBy($"groupKey") + .agg( + count("*").as("cnt"), + sum("value").as("sum"), + max("value").as("max"), + min("value").as("min") + ) + .as[(Int, Long, Long, Int, Int)] + } + + protected def getSchemaForLargeDataStreamingAggregationQuery(formatVersion: Int): StructType = { + val stateKeySchema = new StructType() + .add("groupKey", IntegerType) + + var stateValueSchema = formatVersion match { + case 1 => new StructType().add("groupKey", IntegerType) + case 2 => new StructType() + case v => throw new IllegalArgumentException(s"Not 
valid format version $v") + } + + stateValueSchema = stateValueSchema + .add("cnt", LongType) + .add("sum", LongType) + .add("max", IntegerType) + .add("min", IntegerType) + + new StructType() + .add("key", stateKeySchema) + .add("value", stateValueSchema) + } + + protected def runFlatMapGroupsWithStateQuery(checkpointRoot: String): Unit = { + val clock = new StreamManualClock + + val inputData = MemoryStream[(String, Long)] + val remapped = getFlatMapGroupsWithStateQuery(inputData) + + testStream(remapped, OutputMode.Update)( + // batch 0 + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock, + checkpointLocation = checkpointRoot), + AddData(inputData, ("hello world", 1L), ("hello scala", 2L)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer( + ("hello", 2, 1000, false), + ("world", 1, 0, false), + ("scala", 1, 0, false) + ), + // batch 1 + AddData(inputData, ("hello world", 3L), ("hello scala", 4L)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer( + ("hello", 4, 3000, false), + ("world", 2, 2000, false), + ("scala", 2, 2000, false) + ) + ) + } + + protected def getFlatMapGroupsWithStateQuery: Dataset[(String, Int, Long, Boolean)] = { + getFlatMapGroupsWithStateQuery(MemoryStream[(String, Long)]) + } + + protected def getFlatMapGroupsWithStateQuery( + inputData: MemoryStream[(String, Long)]): Dataset[(String, Int, Long, Boolean)] = { + // scalastyle:off line.size.limit + // This test code is borrowed from Sessionization example, with modification a bit to run with testStream + // https://github.com/apache/spark/blob/v2.4.1/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala + // scalastyle:on + + val events = inputData.toDF() + .as[(String, Timestamp)] + .flatMap { case (line, timestamp) => + line.split(" ").map(word => Event(sessionId = word, timestamp)) + } + + val sessionUpdates = events + .groupByKey(event => event.sessionId) + .mapGroupsWithState[SessionInfo, SessionUpdate](GroupStateTimeout.ProcessingTimeTimeout) { + + case (sessionId: String, events: Iterator[Event], state: GroupState[SessionInfo]) => + if (state.hasTimedOut) { + val finalUpdate = + SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = true) + state.remove() + finalUpdate + } else { + val timestamps = events.map(_.timestamp.getTime).toSeq + val updatedSession = if (state.exists) { + val oldSession = state.get + SessionInfo( + oldSession.numEvents + timestamps.size, + oldSession.startTimestampMs, + math.max(oldSession.endTimestampMs, timestamps.max)) + } else { + SessionInfo(timestamps.size, timestamps.min, timestamps.max) + } + state.update(updatedSession) + + state.setTimeoutDuration("10 seconds") + SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = false) + } + } + + sessionUpdates.map(si => (si.id, si.numEvents, si.durationMs, si.expired)) + } +} + +case class Event(sessionId: String, timestamp: Timestamp) + +case class SessionInfo( + numEvents: Int, + startTimestampMs: Long, + endTimestampMs: Long) { + def durationMs: Long = endTimestampMs - startTimestampMs +} + +case class SessionUpdate( + id: String, + durationMs: Long, + numEvents: Int, + expired: Boolean) From cd7a74a0bf40c72bbdb3590e6bdbd50fb3f71139 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Fri, 11 Dec 2020 16:56:25 +0900 Subject: [PATCH 2/5] Incorporate the schema information from SPARK-27237 --- .../v2/state/StateDataSourceV2.scala | 32 ++++- .../StateSchemaCompatibilityChecker.scala | 59 +-------- 
.../state/StateSchemaFileManager.scala | 81 ++++++++++++ .../v2/state/StateDataSourceV2ReadSuite.scala | 125 ++++++++++-------- 4 files changed, 183 insertions(+), 114 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaFileManager.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceV2.scala index cbed217d105a1..b73289d207093 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceV2.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.execution.datasources.v2.state import java.util -import java.util.Map import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.connector.catalog.{Table, TableProvider} import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.execution.streaming.state.StateStoreId +import org.apache.spark.sql.execution.streaming.state.{StateSchemaFileManager, StateStore, StateStoreId} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap class StateDataSourceV2 extends TableProvider with DataSourceRegister { @@ -57,8 +56,31 @@ class StateDataSourceV2 extends TableProvider with DataSourceRegister { new StateTable(session, schema, checkpointLocation, version, operatorId, storeName) } - override def inferSchema(options: CaseInsensitiveStringMap): StructType = - throw new UnsupportedOperationException("Schema should be explicitly specified.") + override def inferSchema(options: CaseInsensitiveStringMap): StructType = { + val checkpointLocation = Option(options.get(PARAM_CHECKPOINT_LOCATION)).orElse { + throw new AnalysisException(s"'$PARAM_CHECKPOINT_LOCATION' must be specified.") + }.get + + val operatorId = Option(options.get(PARAM_OPERATOR_ID)).map(_.toInt).orElse { + throw new AnalysisException(s"'$PARAM_OPERATOR_ID' must be specified.") + }.get + + val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA + val storeName = Option(options.get(PARAM_STORE_NAME)) + .orElse(Some(StateStoreId.DEFAULT_STORE_NAME)).get + + val storeId = new StateStoreId(checkpointLocation, operatorId, partitionId, storeName) + val manager = new StateSchemaFileManager(storeId, session.sessionState.newHadoopConf()) + if (manager.fileExist()) { + val (keySchema, valueSchema) = manager.readSchema() + new StructType() + .add("key", keySchema) + .add("value", valueSchema) + } else { + throw new UnsupportedOperationException("Schema information file doesn't exist - schema " + + "should be explicitly specified.") + } + } override def supportsExternalMetadata(): Boolean = true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 4ac12c089c0d3..4fde8a1d16ba1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -18,10 +18,8 @@ package 
org.apache.spark.sql.execution.streaming.state import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging -import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} @@ -31,16 +29,12 @@ class StateSchemaCompatibilityChecker( providerId: StateStoreProviderId, hadoopConf: Configuration) extends Logging { - private val storeCpLocation = providerId.storeId.storeCheckpointLocation() - private val fm = CheckpointFileManager.create(storeCpLocation, hadoopConf) - private val schemaFileLocation = schemaFile(storeCpLocation) - - fm.mkdirs(schemaFileLocation.getParent) + private val stateFileManager = new StateSchemaFileManager(providerId.storeId, hadoopConf) def check(keySchema: StructType, valueSchema: StructType): Unit = { - if (fm.exists(schemaFileLocation)) { + if (stateFileManager.fileExist()) { logDebug(s"Schema file for provider $providerId exists. Comparing with provided schema.") - val (storedKeySchema, storedValueSchema) = readSchemaFile() + val (storedKeySchema, storedValueSchema) = stateFileManager.readSchema() if (storedKeySchema.equals(keySchema) && storedValueSchema.equals(valueSchema)) { // schema is exactly same } else if (!schemasCompatible(storedKeySchema, keySchema) || @@ -64,55 +58,10 @@ class StateSchemaCompatibilityChecker( } else { // schema doesn't exist, create one now logDebug(s"Schema file for provider $providerId doesn't exist. Creating one.") - createSchemaFile(keySchema, valueSchema) + stateFileManager.writeSchema(keySchema, valueSchema) } } private def schemasCompatible(storedSchema: StructType, schema: StructType): Boolean = DataType.equalsIgnoreNameAndCompatibleNullability(storedSchema, schema) - - private def readSchemaFile(): (StructType, StructType) = { - val inStream = fm.open(schemaFileLocation) - try { - val versionStr = inStream.readUTF() - // Currently we only support version 1, which we can simplify the version validation and - // the parse logic. 
- val version = MetadataVersionUtil.validateVersion(versionStr, - StateSchemaCompatibilityChecker.VERSION) - require(version == 1) - - val keySchemaStr = inStream.readUTF() - val valueSchemaStr = inStream.readUTF() - - (StructType.fromString(keySchemaStr), StructType.fromString(valueSchemaStr)) - } catch { - case e: Throwable => - logError(s"Fail to read schema file from $schemaFileLocation", e) - throw e - } finally { - inStream.close() - } - } - - private def createSchemaFile(keySchema: StructType, valueSchema: StructType): Unit = { - val outStream = fm.createAtomic(schemaFileLocation, overwriteIfPossible = false) - try { - outStream.writeUTF(s"v${StateSchemaCompatibilityChecker.VERSION}") - outStream.writeUTF(keySchema.json) - outStream.writeUTF(valueSchema.json) - outStream.close() - } catch { - case e: Throwable => - logError(s"Fail to write schema file to $schemaFileLocation", e) - outStream.cancel() - throw e - } - } - - private def schemaFile(storeCpLocation: Path): Path = - new Path(new Path(storeCpLocation, "_metadata"), "schema") -} - -object StateSchemaCompatibilityChecker { - val VERSION = 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaFileManager.scala new file mode 100644 index 0000000000000..712fc24b5af51 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaFileManager.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming.state + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil} +import org.apache.spark.sql.types.StructType + +class StateSchemaFileManager(storeId: StateStoreId, hadoopConf: Configuration) extends Logging { + private val storeCpLocation = storeId.storeCheckpointLocation() + private val fm = CheckpointFileManager.create(storeCpLocation, hadoopConf) + private val schemaFileLocation = schemaFile(storeCpLocation) + + def fileExist(): Boolean = fm.exists(schemaFileLocation) + + def readSchema(): (StructType, StructType) = { + val inStream = fm.open(schemaFileLocation) + try { + val versionStr = inStream.readUTF() + // Currently we only support version 1, which we can simplify the version validation and + // the parse logic. 
+ val version = MetadataVersionUtil.validateVersion(versionStr, + StateSchemaFileManager.VERSION) + require(version == 1) + + val keySchemaStr = inStream.readUTF() + val valueSchemaStr = inStream.readUTF() + + (StructType.fromString(keySchemaStr), StructType.fromString(valueSchemaStr)) + } catch { + case e: Throwable => + logError(s"Fail to read schema file from $schemaFileLocation", e) + throw e + } finally { + inStream.close() + } + } + + def writeSchema(keySchema: StructType, valueSchema: StructType): Unit = { + if (!fm.exists(schemaFileLocation.getParent)) { + fm.mkdirs(schemaFileLocation.getParent) + } + + val outStream = fm.createAtomic(schemaFileLocation, overwriteIfPossible = false) + try { + outStream.writeUTF(s"v${StateSchemaFileManager.VERSION}") + outStream.writeUTF(keySchema.json) + outStream.writeUTF(valueSchema.json) + outStream.close() + } catch { + case e: Throwable => + logError(s"Fail to write schema file to $schemaFileLocation", e) + outStream.cancel() + throw e + } + } + + private def schemaFile(storeCpLocation: Path): Path = + new Path(new Path(storeCpLocation, "_metadata"), "schema") +} + +object StateSchemaFileManager { + val VERSION = 1 +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateDataSourceV2ReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateDataSourceV2ReadSuite.scala index e06fa9ed810cc..fe2e61d2deb77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateDataSourceV2ReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateDataSourceV2ReadSuite.scala @@ -36,74 +36,74 @@ class StateDataSourceV2ReadSuite StateStore.stop() } - test("reading state from simple aggregation - state format version 1") { - withSQLConf(Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "1"): _*) { - withTempDir { tempDir => - runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + test("simple aggregation, state ver 1, infer schema = false") { + testStreamingAggregation(1, inferSchema = false) + } - val stateSchema = getSchemaForLargeDataStreamingAggregationQuery(1) + test("simple aggregation, state ver 1, infer schema = true") { + testStreamingAggregation(1, inferSchema = true) + } - val operatorId = 0 - val batchId = 1 + test("simple aggregation, state ver 2, infer schema = false") { + testStreamingAggregation(2, inferSchema = false) + } - val stateReadDf = spark.read - .format("state") - .schema(stateSchema) - .option(StateDataSourceV2.PARAM_CHECKPOINT_LOCATION, - new File(tempDir, "state").getAbsolutePath) - .option(StateDataSourceV2.PARAM_VERSION, batchId + 1) - .option(StateDataSourceV2.PARAM_OPERATOR_ID, operatorId) - .load() + test("simple aggregation, state ver 2, infer schema = true") { + testStreamingAggregation(2, inferSchema = true) + } - logInfo(s"Schema: ${stateReadDf.schema.treeString}") + test("composite key aggregation, state ver 1, infer schema = false") { + testStreamingAggregationWithCompositeKey(1, inferSchema = false) + } - checkAnswer( - stateReadDf - .selectExpr("key.groupKey AS key_groupKey", "value.groupKey AS value_groupKey", - "value.cnt AS value_cnt", "value.sum AS value_sum", "value.max AS value_max", - "value.min AS value_min"), - Seq( - Row(0, 0, 4, 60, 30, 0), // 0, 10, 20, 30 - Row(1, 1, 4, 64, 31, 1), // 1, 11, 21, 31 - Row(2, 2, 4, 68, 32, 2), // 2, 12, 22, 32 - Row(3, 3, 4, 72, 33, 3), // 3, 13, 23, 33 - Row(4, 4, 4, 76, 34, 4), // 4, 14, 24, 34 - Row(5, 5, 4, 80, 35, 5), // 5, 15, 25, 35 - Row(6, 6, 4, 
84, 36, 6), // 6, 16, 26, 36 - Row(7, 7, 4, 88, 37, 7), // 7, 17, 27, 37 - Row(8, 8, 4, 92, 38, 8), // 8, 18, 28, 38 - Row(9, 9, 4, 96, 39, 9) // 9, 19, 29, 39 - ) - ) - } - } + test("composite key aggregation, state ver 1, infer schema = true") { + testStreamingAggregationWithCompositeKey(1, inferSchema = true) + } + + test("composite key aggregation, state ver 2, infer schema = false") { + testStreamingAggregationWithCompositeKey(2, inferSchema = false) + } + + test("composite key aggregation, ver 2, infer schema = true") { + testStreamingAggregationWithCompositeKey(2, inferSchema = true) } - test("reading state from simple aggregation - state format version 2") { - withSQLConf(Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2"): _*) { + private def testStreamingAggregation(stateVersion: Int, inferSchema: Boolean): Unit = { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString) { withTempDir { tempDir => runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) - val stateSchema = getSchemaForLargeDataStreamingAggregationQuery(2) - val operatorId = 0 val batchId = 1 - val stateReadDf = spark.read + val stateReader = spark.read .format("state") - .schema(stateSchema) .option(StateDataSourceV2.PARAM_CHECKPOINT_LOCATION, new File(tempDir, "state").getAbsolutePath) .option(StateDataSourceV2.PARAM_VERSION, batchId + 1) .option(StateDataSourceV2.PARAM_OPERATOR_ID, operatorId) - .load() + + val stateReadDf = if (inferSchema) { + stateReader.load() + } else { + val stateSchema = getSchemaForLargeDataStreamingAggregationQuery(stateVersion) + stateReader.schema(stateSchema).load() + } logInfo(s"Schema: ${stateReadDf.schema.treeString}") - checkAnswer( + val resultDf = if (inferSchema) { + stateReadDf + .selectExpr("key.groupKey AS key_groupKey", "value.count AS value_cnt", + "value.sum AS value_sum", "value.max AS value_max", "value.min AS value_min") + } else { stateReadDf .selectExpr("key.groupKey AS key_groupKey", "value.cnt AS value_cnt", - "value.sum AS value_sum", "value.max AS value_max", "value.min AS value_min"), + "value.sum AS value_sum", "value.max AS value_max", "value.min AS value_min") + } + + checkAnswer( + resultDf, Seq( Row(0, 4, 60, 30, 0), // 0, 10, 20, 30 Row(1, 4, 64, 31, 1), // 1, 11, 21, 31 @@ -121,32 +121,46 @@ class StateDataSourceV2ReadSuite } } - test("reading state from simple aggregation - composite key") { - withSQLConf(Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2"): _*) { + private def testStreamingAggregationWithCompositeKey( + stateVersion: Int, + inferSchema: Boolean): Unit = { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString) { withTempDir { tempDir => runCompositeKeyStreamingAggregationQuery(tempDir.getAbsolutePath) - val stateSchema = getSchemaForCompositeKeyStreamingAggregationQuery(2) - val operatorId = 0 val batchId = 1 - val stateReadDf = spark.read + val stateReader = spark.read .format("state") - .schema(stateSchema) .option(StateDataSourceV2.PARAM_CHECKPOINT_LOCATION, new File(tempDir, "state").getAbsolutePath) .option(StateDataSourceV2.PARAM_VERSION, batchId + 1) .option(StateDataSourceV2.PARAM_OPERATOR_ID, operatorId) - .load() + + val stateReadDf = if (inferSchema) { + stateReader.load() + } else { + val stateSchema = getSchemaForCompositeKeyStreamingAggregationQuery(stateVersion) + stateReader.schema(stateSchema).load() + } logInfo(s"Schema: ${stateReadDf.schema.treeString}") - checkAnswer( + val resultDf = if (inferSchema) { 
+ stateReadDf + .selectExpr("key.groupKey AS key_groupKey", "key.fruit AS key_fruit", + "value.count AS value_cnt", "value.sum AS value_sum", "value.max AS value_max", + "value.min AS value_min") + } else { stateReadDf .selectExpr("key.groupKey AS key_groupKey", "key.fruit AS key_fruit", "value.cnt AS value_cnt", "value.sum AS value_sum", "value.max AS value_max", - "value.min AS value_min"), + "value.min AS value_min") + } + + checkAnswer( + resultDf, Seq( Row(0, "Apple", 2, 6, 6, 0), Row(1, "Banana", 2, 8, 7, 1), @@ -159,4 +173,7 @@ class StateDataSourceV2ReadSuite } } } + + // FIXME: add flatMapGroupsWithState test cases + } From 99b00dbf371cb1b0e568561a6c03a6668a106e24 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Mon, 14 Dec 2020 13:15:12 +0900 Subject: [PATCH 3/5] add test with flatMapGroupWithState, change some params as optional --- .../v2/state/StateDataSourceV2.scala | 43 ++++++-- .../v2/state/StatePartitionReader.scala | 2 +- .../datasources/v2/state/StateTable.scala | 10 +- .../v2/state/StateDataSourceV2ReadSuite.scala | 103 ++++++++++++++---- .../sources/v2/state/StateStoreTestBase.scala | 15 ++- 5 files changed, 134 insertions(+), 39 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceV2.scala index b73289d207093..082724a4ebcd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceV2.scala @@ -18,19 +18,22 @@ package org.apache.spark.sql.execution.datasources.v2.state import java.util +import org.apache.hadoop.fs.Path + import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.connector.catalog.{Table, TableProvider} import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.execution.streaming.CommitLog import org.apache.spark.sql.execution.streaming.state.{StateSchemaFileManager, StateStore, StateStoreId} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap class StateDataSourceV2 extends TableProvider with DataSourceRegister { import StateDataSourceV2._ - lazy val session = SparkSession.active + lazy val session: SparkSession = SparkSession.active override def shortName(): String = "state" @@ -42,18 +45,19 @@ class StateDataSourceV2 extends TableProvider with DataSourceRegister { throw new AnalysisException(s"'$PARAM_CHECKPOINT_LOCATION' must be specified.") }.get - val version = Option(properties.get(PARAM_VERSION)).map(_.toInt).orElse { - throw new AnalysisException(s"'$PARAM_VERSION' must be specified.") + val version = Option(properties.get(PARAM_VERSION)).map(_.toLong).orElse { + Some(getLastCommittedBatch(checkpointLocation)) }.get - val operatorId = Option(properties.get(PARAM_OPERATOR_ID)).map(_.toInt).orElse { - throw new AnalysisException(s"'$PARAM_OPERATOR_ID' must be specified.") - }.get + val operatorId = Option(properties.get(PARAM_OPERATOR_ID)).map(_.toInt) + .orElse(Some(0)).get val storeName = Option(properties.get(PARAM_STORE_NAME)) .orElse(Some(StateStoreId.DEFAULT_STORE_NAME)).get - new StateTable(session, schema, checkpointLocation, version, operatorId, storeName) + val 
stateCheckpointLocation = new Path(checkpointLocation, "state") + new StateTable(session, schema, stateCheckpointLocation.toString, version, operatorId, + storeName) } override def inferSchema(options: CaseInsensitiveStringMap): StructType = { @@ -61,15 +65,16 @@ class StateDataSourceV2 extends TableProvider with DataSourceRegister { throw new AnalysisException(s"'$PARAM_CHECKPOINT_LOCATION' must be specified.") }.get - val operatorId = Option(options.get(PARAM_OPERATOR_ID)).map(_.toInt).orElse { - throw new AnalysisException(s"'$PARAM_OPERATOR_ID' must be specified.") - }.get + val operatorId = Option(options.get(PARAM_OPERATOR_ID)).map(_.toInt) + .orElse(Some(0)).get val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA val storeName = Option(options.get(PARAM_STORE_NAME)) .orElse(Some(StateStoreId.DEFAULT_STORE_NAME)).get - val storeId = new StateStoreId(checkpointLocation, operatorId, partitionId, storeName) + val stateCheckpointLocation = new Path(checkpointLocation, "state") + val storeId = new StateStoreId(stateCheckpointLocation.toString, operatorId, partitionId, + storeName) val manager = new StateSchemaFileManager(storeId, session.sessionState.newHadoopConf()) if (manager.fileExist()) { val (keySchema, valueSchema) = manager.readSchema() @@ -82,6 +87,20 @@ class StateDataSourceV2 extends TableProvider with DataSourceRegister { } } + private def getLastCommittedBatch(checkpointLocation: String): Long = { + val commitLog = new CommitLog(session, new Path(checkpointLocation, "commits").toString) + val lastCommittedBatchId = commitLog.getLatest() match { + case Some((lastId, _)) => lastId + case None => -1 + } + + if (lastCommittedBatchId < 0) { + throw new AnalysisException("No committed batch found.") + } + + lastCommittedBatchId.toLong + 1 + } + override def supportsExternalMetadata(): Boolean = true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 91fa102387c00..9c20d8e0ca2ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -37,7 +37,7 @@ class StatePartitionReader( partition.operatorId, partition.partition, partition.storeName) val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId) - val store = StateStore.get(stateStoreProviderId, keySchema, valueSchema, + val store = StateStore.getReadOnly(stateStoreProviderId, keySchema, valueSchema, indexOrdinal = None, version = partition.version, storeConf = storeConf, hadoopConf = hadoopConf.value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala index 79688b1fe685f..97534625f7763 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala @@ -29,8 +29,8 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap class StateTable( session: SparkSession, override val schema: StructType, - checkpointLocation: String, - version: Int, + stateCheckpointLocation: String, + version: Long, operatorId: Int, storeName: String) extends Table with SupportsRead { @@ -43,15 
+43,15 @@ class StateTable( } override def name(): String = - s"state-table-cp-$checkpointLocation-ver-$version-operator-$operatorId-store-$storeName" + s"state-table-cp-$stateCheckpointLocation-ver-$version-operator-$operatorId-store-$storeName" override def capabilities(): util.Set[TableCapability] = CAPABILITY override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = - new StateScanBuilder(session, schema, checkpointLocation, version, operatorId, storeName) + new StateScanBuilder(session, schema, stateCheckpointLocation, version, operatorId, storeName) override def properties(): util.Map[String, String] = Map( - "checkpointLocation" -> checkpointLocation, + "stateCheckpointLocation" -> stateCheckpointLocation, "version" -> version.toString, "operatorId" -> operatorId.toString, "storeName" -> storeName).asJava diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateDataSourceV2ReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateDataSourceV2ReadSuite.scala index fe2e61d2deb77..d4bdc61cfd950 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateDataSourceV2ReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateDataSourceV2ReadSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.sql.sources.v2.state -import java.io.File - import org.scalatest.{Assertions, BeforeAndAfterAll} -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Encoders, Row} import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.internal.SQLConf @@ -74,12 +72,12 @@ class StateDataSourceV2ReadSuite runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) val operatorId = 0 - val batchId = 1 + val batchId = 2 val stateReader = spark.read .format("state") - .option(StateDataSourceV2.PARAM_CHECKPOINT_LOCATION, - new File(tempDir, "state").getAbsolutePath) + .option(StateDataSourceV2.PARAM_CHECKPOINT_LOCATION, tempDir.getAbsolutePath) + // explicitly specifying version and operator ID to test out the functionality .option(StateDataSourceV2.PARAM_VERSION, batchId + 1) .option(StateDataSourceV2.PARAM_OPERATOR_ID, operatorId) @@ -105,9 +103,9 @@ class StateDataSourceV2ReadSuite checkAnswer( resultDf, Seq( - Row(0, 4, 60, 30, 0), // 0, 10, 20, 30 - Row(1, 4, 64, 31, 1), // 1, 11, 21, 31 - Row(2, 4, 68, 32, 2), // 2, 12, 22, 32 + Row(0, 5, 60, 30, 0), // 0, 10, 20, 30 + Row(1, 5, 65, 31, 1), // 1, 11, 21, 31 + Row(2, 5, 70, 32, 2), // 2, 12, 22, 32 Row(3, 4, 72, 33, 3), // 3, 13, 23, 33 Row(4, 4, 76, 34, 4), // 4, 14, 24, 34 Row(5, 4, 80, 35, 5), // 5, 15, 25, 35 @@ -128,15 +126,10 @@ class StateDataSourceV2ReadSuite withTempDir { tempDir => runCompositeKeyStreamingAggregationQuery(tempDir.getAbsolutePath) - val operatorId = 0 - val batchId = 1 - val stateReader = spark.read .format("state") - .option(StateDataSourceV2.PARAM_CHECKPOINT_LOCATION, - new File(tempDir, "state").getAbsolutePath) - .option(StateDataSourceV2.PARAM_VERSION, batchId + 1) - .option(StateDataSourceV2.PARAM_OPERATOR_ID, operatorId) + .option(StateDataSourceV2.PARAM_CHECKPOINT_LOCATION, tempDir.getAbsolutePath) + // skip version and operator ID to test out functionalities val stateReadDf = if (inferSchema) { stateReader.load() @@ -163,9 +156,9 @@ class StateDataSourceV2ReadSuite resultDf, Seq( Row(0, "Apple", 2, 6, 6, 0), - Row(1, "Banana", 2, 8, 7, 1), - Row(0, "Strawberry", 2, 10, 8, 2), - Row(1, 
"Apple", 2, 12, 9, 3), + Row(1, "Banana", 3, 9, 7, 1), + Row(0, "Strawberry", 3, 12, 8, 2), + Row(1, "Apple", 3, 15, 9, 3), Row(0, "Banana", 2, 14, 10, 4), Row(1, "Strawberry", 1, 5, 5, 5) ) @@ -174,6 +167,76 @@ class StateDataSourceV2ReadSuite } } - // FIXME: add flatMapGroupsWithState test cases + test("flatMapGroupsWithState, state ver 1, infer schema = false") { + testFlatMapGroupsWithState(1, inferSchema = false) + } + + test("flatMapGroupsWithState, state ver 1, infer schema = true") { + testFlatMapGroupsWithState(1, inferSchema = true) + } + + test("flatMapGroupsWithState, state ver 2, infer schema = false") { + testFlatMapGroupsWithState(2, inferSchema = false) + } + + test("flatMapGroupsWithState, state ver 2, infer schema = true") { + testFlatMapGroupsWithState(2, inferSchema = true) + } + + private def testFlatMapGroupsWithState(stateVersion: Int, inferSchema: Boolean): Unit = { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> stateVersion.toString) { + withTempDir { tempDir => + runFlatMapGroupsWithStateQuery(tempDir.getAbsolutePath) + + val stateReader = spark.read + .format("state") + .option(StateDataSourceV2.PARAM_CHECKPOINT_LOCATION, tempDir.getAbsolutePath) + val stateReadDf = if (inferSchema) { + stateReader.load() + } else { + val stateSchema = getSchemaForFlatMapGroupsWithStateQuery(stateVersion) + stateReader.schema(stateSchema).load() + } + + val resultDf = if (stateVersion == 1) { + stateReadDf + .selectExpr("key.value AS key_value", "value.numEvents AS value_numEvents", + "value.startTimestampMs AS value_startTimestampMs", + "value.endTimestampMs AS value_endTimestampMs", + "value.timeoutTimestamp AS value_timeoutTimestamp") + } else { // stateVersion == 2 + stateReadDf + .selectExpr("key.value AS key_value", "value.groupState.numEvents AS value_numEvents", + "value.groupState.startTimestampMs AS value_startTimestampMs", + "value.groupState.endTimestampMs AS value_endTimestampMs", + "value.timeoutTimestamp AS value_timeoutTimestamp") + } + + checkAnswer( + resultDf, + Seq( + Row("hello", 4, 1000, 4000, 12000), + Row("world", 2, 1000, 3000, 12000), + Row("scala", 2, 2000, 4000, 12000) + ) + ) + + // try to read the value via case class provided in actual query + implicit val encoder = Encoders.product[SessionInfo] + val df = if (stateVersion == 1) { + stateReadDf.selectExpr("value.*").drop("timeoutTimestamp").as[SessionInfo] + } else { // state version == 2 + stateReadDf.selectExpr("value.groupState.*").as[SessionInfo] + } + + val expected = Array( + SessionInfo(4, 1000, 4000), + SessionInfo(2, 1000, 3000), + SessionInfo(2, 2000, 4000) + ) + assert(df.collect().toSet === expected.toSet) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateStoreTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateStoreTestBase.scala index 232f3e06d118e..ddc10bff8ddbf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateStoreTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateStoreTestBase.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.sources.v2.state import java.io.File import java.sql.Timestamp -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.{Dataset, Encoders} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ @@ -293,6 +293,19 @@ trait StateStoreTestBase extends StreamTest { 
sessionUpdates.map(si => (si.id, si.numEvents, si.durationMs, si.expired)) } + + protected def getSchemaForFlatMapGroupsWithStateQuery(stateVersion: Int): StructType = { + val keySchema = new StructType().add("value", StringType, nullable = true) + val valueSchema = if (stateVersion == 1) { + Encoders.product[SessionInfo].schema.add("timeoutTimestamp", IntegerType, nullable = false) + } else { // stateVersion == 2 + new StructType() + .add("groupState", Encoders.product[SessionInfo].schema) + .add("timeoutTimestamp", LongType, nullable = false) + } + + new StructType().add("key", keySchema).add("value", valueSchema) + } } case class Event(sessionId: String, timestamp: Timestamp) From 029c6d015f0f78971ddc8fad88249a5c984dd555 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Mon, 14 Dec 2020 13:48:44 +0900 Subject: [PATCH 4/5] Remove CheckpointUtil given it's only helpful for the writer part --- .../datasources/v2/state/CheckpointUtil.scala | 109 ------------------ 1 file changed, 109 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/CheckpointUtil.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/CheckpointUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/CheckpointUtil.scala deleted file mode 100644 index 8c3256cac0d6c..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/CheckpointUtil.scala +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2.state - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileUtil, Path} - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata} - -/** - * Providing features to deal with checkpoint, like creating savepoint. - */ -object CheckpointUtil { - - /** - * Create savepoint from existing checkpoint. - * OffsetLog and CommitLog will be purged based on newLastBatchId. - * Use `additionalMetadataConf` to modify metadata configuration: you may want to modify it - * when rescaling state, or migrate state format version. - * e.g.
when rescaling, pass Map(SQLConf.SHUFFLE_PARTITIONS.key -> newShufflePartitions.toString) - * - * @param sparkSession spark session - * @param checkpointRoot the root path of existing checkpoint - * @param newCheckpointRoot the root path of new savepoint - target directory should be empty - * @param newLastBatchId the new last batch ID - it needs to be one of committed batch ID - * @param additionalMetadataConf the configuration to add to existing metadata configuration - * @param excludeState whether to exclude state directory - */ - def createSavePoint( - sparkSession: SparkSession, - checkpointRoot: String, - newCheckpointRoot: String, - newLastBatchId: Long, - additionalMetadataConf: Map[String, String], - excludeState: Boolean = false): Unit = { - val hadoopConf = sparkSession.sessionState.newHadoopConf() - - val src = new Path(resolve(hadoopConf, checkpointRoot)) - val srcFs = src.getFileSystem(hadoopConf) - val dst = new Path(resolve(hadoopConf, newCheckpointRoot)) - val dstFs = dst.getFileSystem(hadoopConf) - - if (dstFs.listFiles(dst, false).hasNext) { - throw new IllegalArgumentException("Destination directory should be empty.") - } - - dstFs.mkdirs(dst) - - // copy content of src directory to dst directory - srcFs.listStatus(src).foreach { fs => - val path = fs.getPath - val fileName = path.getName - if (fileName == "state" && excludeState) { - // pass - } else { - FileUtil.copy(srcFs, path, dstFs, new Path(dst, fileName), - false, false, hadoopConf) - } - } - - val offsetLog = new OffsetSeqLog(sparkSession, new Path(dst, "offsets").toString) - val logForBatch = offsetLog.get(newLastBatchId) match { - case Some(log) => log - case None => throw new IllegalStateException("offset log for batch should be exist") - } - - val newMetadata = logForBatch.metadata match { - case Some(md) => - val newMap = md.conf ++ additionalMetadataConf - Some(md.copy(conf = newMap)) - case None => - Some(OffsetSeqMetadata(conf = additionalMetadataConf)) - } - - val newLogForBatch = logForBatch.copy(metadata = newMetadata) - - // we will restart from last batch + 1: overwrite the last batch with new configuration - offsetLog.purgeAfter(newLastBatchId - 1) - offsetLog.add(newLastBatchId, newLogForBatch) - - val commitLog = new CommitLog(sparkSession, new Path(dst, "commits").toString) - commitLog.purgeAfter(newLastBatchId) - - // state doesn't expose purge mechanism as its interface - // assuming state would work with overwriting batch files when it replays previous batch - } - - private def resolve(hadoopConf: Configuration, cpLocation: String): String = { - val checkpointPath = new Path(cpLocation) - val fs = checkpointPath.getFileSystem(hadoopConf) - checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString - } -} From a495f6d56411f2f3bb1e271babe9efad008b3959 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Mon, 14 Dec 2020 13:53:36 +0900 Subject: [PATCH 5/5] Remove schema extractor as SPARK-27237 enables reading the schema without analyzing the query --- .../v2/state/StateSchemaExtractor.scala | 87 -------------- .../v2/state/StateSchemaExtractorSuite.scala | 109 ------------------ 2 files changed, 196 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateSchemaExtractor.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateSchemaExtractorSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateSchemaExtractor.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateSchemaExtractor.scala deleted file mode 100644 index b470b377b4192..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateSchemaExtractor.scala +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v2.state - -import java.util.UUID - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, SparkSession} -import org.apache.spark.sql.execution.datasources.v2.state.StateSchemaExtractor.{StateKind, StateSchemaInfo} -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.StructType - -/** - * This class enables extracting state schema and its format version via analyzing - * the streaming query. The query should have its state operators but it should exclude sink(s). - * - * Note that it only returns which can be extracted by this class, so number of state - * in given query may not be same as returned number of schema information. - */ -class StateSchemaExtractor(spark: SparkSession) extends Logging { - - def extract(query: DataFrame): Seq[StateSchemaInfo] = { - require(query.isStreaming, "Given query is not a streaming query!") - - val queryExecution = new IncrementalExecution(spark, query.logicalPlan, - OutputMode.Update(), "", UUID.randomUUID(), UUID.randomUUID(), - 0, OffsetSeqMetadata()) - - // TODO: handle Streaming Join (if possible), etc. 
- queryExecution.executedPlan.collect { - case store: StateStoreSaveExec => - val stateFormatVersion = store.stateFormatVersion - val keySchema = store.keyExpressions.toStructType - val valueSchema = store.stateManager.getStateValueSchema - store.stateInfo match { - case Some(stInfo) => - val operatorId = stInfo.operatorId - StateSchemaInfo(operatorId, StateKind.StreamingAggregation, - stateFormatVersion, keySchema, valueSchema) - - case None => throw new IllegalStateException("State information not set!") - } - - case store: FlatMapGroupsWithStateExec => - val stateFormatVersion = store.stateFormatVersion - val keySchema = store.groupingAttributes.toStructType - val valueSchema = store.stateManager.stateSchema - store.stateInfo match { - case Some(stInfo) => - val operatorId = stInfo.operatorId - StateSchemaInfo(operatorId, StateKind.FlatMapGroupsWithState, - stateFormatVersion, keySchema, valueSchema) - - case None => throw new IllegalStateException("State information not set!") - } - } - } - -} - -object StateSchemaExtractor { - object StateKind extends Enumeration { - val StreamingAggregation, StreamingJoin, FlatMapGroupsWithState = Value - } - - case class StateSchemaInfo( - opId: Long, - stateKind: StateKind.Value, - formatVersion: Int, - keySchema: StructType, - valueSchema: StructType) -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateSchemaExtractorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateSchemaExtractorSuite.scala deleted file mode 100644 index fe28eb121509a..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateSchemaExtractorSuite.scala +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.sources.v2.state - -import org.scalatest.{Assertions, BeforeAndAfterAll} - -import org.apache.spark.sql.Encoders -import org.apache.spark.sql.execution.datasources.v2.state.{SchemaUtil, StateSchemaExtractor} -import org.apache.spark.sql.execution.datasources.v2.state.StateSchemaExtractor.StateKind -import org.apache.spark.sql.execution.streaming.state.StateStore -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} - -class StateSchemaExtractorSuite - extends StateStoreTestBase - with BeforeAndAfterAll - with Assertions { - - override def afterAll(): Unit = { - super.afterAll() - StateStore.stop() - } - - Seq(1, 2).foreach { ver => - test(s"extract schema from streaming aggregation query - state format v$ver") { - withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> ver.toString) { - val aggregated = getCompositeKeyStreamingAggregationQuery - - val stateSchema = getSchemaForCompositeKeyStreamingAggregationQuery(ver) - val expectedKeySchema = SchemaUtil.getSchemaAsDataType(stateSchema, "key") - .asInstanceOf[StructType] - val expectedValueSchema = SchemaUtil.getSchemaAsDataType(stateSchema, "value") - .asInstanceOf[StructType] - - val schemaInfos = new StateSchemaExtractor(spark).extract(aggregated.toDF()) - assert(schemaInfos.length === 1) - val schemaInfo = schemaInfos.head - assert(schemaInfo.opId === 0) - assert(schemaInfo.formatVersion === ver) - assert(schemaInfo.stateKind === StateKind.StreamingAggregation) - - assert(compareSchemaWithoutName(schemaInfo.keySchema, expectedKeySchema), - s"Even without column names, ${schemaInfo.keySchema} did not equal $expectedKeySchema") - assert(compareSchemaWithoutName(schemaInfo.valueSchema, expectedValueSchema), - s"Even without column names, ${schemaInfo.valueSchema} did not equal " + - s"$expectedValueSchema") - } - } - } - - Seq(1, 2).foreach { ver => - test(s"extract schema from flatMapGroupsWithState query - state format v$ver") { - withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> ver.toString) { - // This is borrowed from StateStoreTest, runFlatMapGroupsWithStateQuery - val aggregated = getFlatMapGroupsWithStateQuery - - val expectedKeySchema = new StructType().add("value", StringType, nullable = true) - - val expectedValueSchema = if (ver == 1) { - Encoders.product[SessionInfo].schema - .add("timeoutTimestamp", IntegerType, nullable = false) - } else { - // ver == 2 - new StructType() - .add("groupState", Encoders.product[SessionInfo].schema) - .add("timeoutTimestamp", LongType, nullable = false) - } - - val schemaInfos = new StateSchemaExtractor(spark).extract(aggregated.toDF()) - assert(schemaInfos.length === 1) - val schemaInfo = schemaInfos.head - assert(schemaInfo.opId === 0) - assert(schemaInfo.stateKind === StateKind.FlatMapGroupsWithState) - assert(schemaInfo.formatVersion === ver) - - assert(compareSchemaWithoutName(schemaInfo.keySchema, expectedKeySchema), - s"Even without column names, ${schemaInfo.keySchema} did not equal $expectedKeySchema") - assert(compareSchemaWithoutName(schemaInfo.valueSchema, expectedValueSchema), - s"Even without column names, ${schemaInfo.valueSchema} did not equal " + - s"$expectedValueSchema") - } - } - } - - private def compareSchemaWithoutName(s1: StructType, s2: StructType): Boolean = { - if (s1.length != s2.length) { - false - } else { - s1.zip(s2).forall { case (column1, column2) => - column1.dataType == column2.dataType && column1.nullable == 
column2.nullable - } - } - } -}
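
Reviewer note (not part of the patch series above): a minimal usage sketch of the reader this series adds, assuming all five patches are applied, the query has committed at least one batch, and the state schema file (SPARK-27237) exists in the checkpoint. The object name, app name, and checkpoint path are placeholders; the read pattern and option constants follow StateDataSourceV2 and StateDataSourceV2ReadSuite.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceV2

object StateReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("state-read-sketch")
      .master("local[*]")
      .getOrCreate()

    // Checkpoint root of the streaming query (the directory that contains the
    // "commits", "offsets" and "state" subdirectories). Placeholder path.
    val checkpointRoot = "/tmp/streaming-query-checkpoint"

    // With only the checkpoint location set, the source resolves the version from
    // the commit log, defaults the operator ID to 0, and infers the key/value
    // schema from the schema file written by StateSchemaFileManager.
    val latestState = spark.read
      .format("state")
      .option(StateDataSourceV2.PARAM_CHECKPOINT_LOCATION, checkpointRoot)
      .load()

    // State rows are exposed as nested "key" and "value" structs whose fields
    // depend on the stateful operator that produced the state.
    latestState.selectExpr("key.*", "value.*").show(truncate = false)

    // Version and operator ID can still be pinned explicitly; as in the tests,
    // version N reads the state as of batch N - 1.
    val pinnedState = spark.read
      .format("state")
      .option(StateDataSourceV2.PARAM_CHECKPOINT_LOCATION, checkpointRoot)
      .option(StateDataSourceV2.PARAM_VERSION, 2L)
      .option(StateDataSourceV2.PARAM_OPERATOR_ID, 0L)
      .load()
    pinnedState.printSchema()

    spark.stop()
  }
}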