diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
index 616fc72320ca..6a78b9e2bddd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.write._
+import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
 import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}
 import org.apache.spark.sql.types.{DataType, DateType, StructType, TimestampType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -145,6 +146,7 @@ class InMemoryTable(
   override def capabilities: util.Set[TableCapability] = Set(
     TableCapability.BATCH_READ,
     TableCapability.BATCH_WRITE,
+    TableCapability.STREAMING_WRITE,
     TableCapability.OVERWRITE_BY_FILTER,
     TableCapability.OVERWRITE_DYNAMIC,
     TableCapability.TRUNCATE).asJava
@@ -169,26 +171,35 @@ class InMemoryTable(
     new WriteBuilder with SupportsTruncate with SupportsOverwrite with SupportsDynamicOverwrite {
       private var writer: BatchWrite = Append
+      private var streamingWriter: StreamingWrite = StreamingAppend
 
       override def truncate(): WriteBuilder = {
         assert(writer == Append)
         writer = TruncateAndAppend
+        streamingWriter = StreamingTruncateAndAppend
         this
       }
 
       override def overwrite(filters: Array[Filter]): WriteBuilder = {
         assert(writer == Append)
         writer = new Overwrite(filters)
+        streamingWriter = new StreamingNotSupportedOperation(s"overwrite ($filters)")
         this
       }
 
       override def overwriteDynamicPartitions(): WriteBuilder = {
         assert(writer == Append)
         writer = DynamicOverwrite
+        streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions")
         this
       }
 
       override def buildForBatch(): BatchWrite = writer
+
+      override def buildForStreaming(): StreamingWrite = streamingWriter match {
+        case exc: StreamingNotSupportedOperation => exc.throwsException()
+        case s => s
+      }
     }
   }
@@ -231,6 +242,45 @@ class InMemoryTable(
     }
   }
 
+  private abstract class TestStreamingWrite extends StreamingWrite {
+    def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = {
+      BufferedRowsWriterFactory
+    }
+
+    def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+  }
+
+  private class StreamingNotSupportedOperation(operation: String) extends TestStreamingWrite {
+    override def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory =
+      throwsException()
+
+    override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit =
+      throwsException()
+
+    override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit =
+      throwsException()
+
+    def throwsException[T](): T = throw new IllegalStateException("The operation " +
+      s"${operation} isn't supported for streaming query.")
+  }
+
+  private object StreamingAppend extends TestStreamingWrite {
+    override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+      dataMap.synchronized {
+        withData(messages.map(_.asInstanceOf[BufferedRows]))
+      }
+    }
+  }
+
+  private object StreamingTruncateAndAppend extends TestStreamingWrite {
+    override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+      dataMap.synchronized {
+        dataMap.clear
+        withData(messages.map(_.asInstanceOf[BufferedRows]))
+      }
+    }
+  }
+
   override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized {
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
     dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters)
@@ -310,10 +360,17 @@ private class BufferedRowsReader(partition: BufferedRows) extends PartitionReade
   override def close(): Unit = {}
 }
 
-private object BufferedRowsWriterFactory extends DataWriterFactory {
+private object BufferedRowsWriterFactory extends DataWriterFactory with StreamingDataWriterFactory {
   override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
     new BufferWriter
   }
+
+  override def createWriter(
+      partitionId: Int,
+      taskId: Long,
+      epochId: Long): DataWriter[InternalRow] = {
+    new BufferWriter
+  }
 }
 
 private class BufferWriter extends DataWriter[InternalRow] {
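// A minimal, self-contained sketch of the commit semantics the test sinks above model:
// StreamingAppend adds each committed epoch's rows to the table, while
// StreamingTruncateAndAppend clears the existing contents first. This does not use Spark's
// real StreamingWrite interfaces; every name below is illustrative only.
object EpochCommitSemanticsSketch {
  import scala.collection.mutable

  sealed trait EpochCommit {
    def commit(epochId: Long, rows: Seq[String], table: mutable.Buffer[String]): Unit
  }

  // Mirrors StreamingAppend: each committed epoch appends its rows.
  object AppendLike extends EpochCommit {
    def commit(epochId: Long, rows: Seq[String], table: mutable.Buffer[String]): Unit =
      table.synchronized { table ++= rows }
  }

  // Mirrors StreamingTruncateAndAppend: truncate first, then append the epoch's rows.
  object TruncateAndAppendLike extends EpochCommit {
    def commit(epochId: Long, rows: Seq[String], table: mutable.Buffer[String]): Unit =
      table.synchronized {
        table.clear()
        table ++= rows
      }
  }

  def main(args: Array[String]): Unit = {
    val table = mutable.Buffer.empty[String]
    AppendLike.commit(0L, Seq("a", "b"), table)
    AppendLike.commit(1L, Seq("c"), table)
    assert(table == Seq("a", "b", "c"))
    TruncateAndAppendLike.commit(2L, Seq("d"), table)
    assert(table == Seq("d"))
  }
}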
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index dda6dec9c4eb..239b4fc2de37 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -27,7 +27,7 @@ import org.apache.spark.api.java.function.VoidFunction2
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.connector.catalog.{SupportsWrite, TableProvider}
+import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableProvider}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
@@ -45,6 +45,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
  */
 @Evolving
 final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
+  import DataStreamWriter._
 
   private val df = ds.toDF()
 
@@ -294,60 +295,75 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
   @throws[TimeoutException]
   def start(): StreamingQuery = startInternal(None)
 
+  /**
+   * Starts the execution of the streaming query, which will continually output results to the given
+   * table as new data arrives. The returned [[StreamingQuery]] object can be used to interact with
+   * the stream.
+   *
+   * @since 3.1.0
+   */
+  @throws[TimeoutException]
+  def saveAsTable(tableName: String): StreamingQuery = {
+    this.source = SOURCE_NAME_TABLE
+    this.tableName = tableName
+    startInternal(None)
+  }
+
   private def startInternal(path: Option[String]): StreamingQuery = {
     if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
       throw new AnalysisException("Hive data source can only be used with tables, you can not " +
         "write files of Hive data source directly.")
     }
 
-    if (source == "memory") {
-      assertNotPartitioned("memory")
+    if (source == SOURCE_NAME_TABLE) {
+      assertNotPartitioned(SOURCE_NAME_TABLE)
+
+      import df.sparkSession.sessionState.analyzer.CatalogAndIdentifier
+
+      import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+      val originalMultipartIdentifier = df.sparkSession.sessionState.sqlParser
+        .parseMultipartIdentifier(tableName)
+      val CatalogAndIdentifier(catalog, identifier) = originalMultipartIdentifier
+
+      // Currently we don't create a logical streaming writer node in logical plan, so cannot rely
+      // on analyzer to resolve it. Directly lookup only for temp view to provide clearer message.
+      // TODO (SPARK-27484): we should add the writing node before the plan is analyzed.
+      if (df.sparkSession.sessionState.catalog.isTempView(originalMultipartIdentifier)) {
+        throw new AnalysisException(s"Temporary view $tableName doesn't support streaming write")
+      }
+
+      val tableInstance = catalog.asTableCatalog.loadTable(identifier)
+
+      import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
+      val sink = tableInstance match {
+        case t: SupportsWrite if t.supports(STREAMING_WRITE) => t
+        case t => throw new AnalysisException(s"Table $tableName doesn't support streaming " +
+          s"write - $t")
+      }
+
+      startQuery(sink, extraOptions)
+    } else if (source == SOURCE_NAME_MEMORY) {
+      assertNotPartitioned(SOURCE_NAME_MEMORY)
       if (extraOptions.get("queryName").isEmpty) {
         throw new AnalysisException("queryName must be specified for memory sink")
       }
       val sink = new MemorySink()
       val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink, df.schema.toAttributes))
-      val chkpointLoc = extraOptions.get("checkpointLocation")
       val recoverFromChkpoint = outputMode == OutputMode.Complete()
-      val query = df.sparkSession.sessionState.streamingQueryManager.startQuery(
-        extraOptions.get("queryName"),
-        chkpointLoc,
-        df,
-        extraOptions.toMap,
-        sink,
-        outputMode,
-        useTempCheckpointLocation = true,
-        recoverFromCheckpointLocation = recoverFromChkpoint,
-        trigger = trigger)
+      val query = startQuery(sink, extraOptions, recoverFromCheckpoint = recoverFromChkpoint)
       resultDf.createOrReplaceTempView(query.name)
       query
-    } else if (source == "foreach") {
-      assertNotPartitioned("foreach")
+    } else if (source == SOURCE_NAME_FOREACH) {
+      assertNotPartitioned(SOURCE_NAME_FOREACH)
       val sink = ForeachWriterTable[T](foreachWriter, ds.exprEnc)
-      df.sparkSession.sessionState.streamingQueryManager.startQuery(
-        extraOptions.get("queryName"),
-        extraOptions.get("checkpointLocation"),
-        df,
-        extraOptions.toMap,
-        sink,
-        outputMode,
-        useTempCheckpointLocation = true,
-        trigger = trigger)
-    } else if (source == "foreachBatch") {
-      assertNotPartitioned("foreachBatch")
+      startQuery(sink, extraOptions)
+    } else if (source == SOURCE_NAME_FOREACH_BATCH) {
+      assertNotPartitioned(SOURCE_NAME_FOREACH_BATCH)
       if (trigger.isInstanceOf[ContinuousTrigger]) {
-        throw new AnalysisException("'foreachBatch' is not supported with continuous trigger")
+        throw new AnalysisException(s"'$source' is not supported with continuous trigger")
       }
       val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc)
-      df.sparkSession.sessionState.streamingQueryManager.startQuery(
-        extraOptions.get("queryName"),
-        extraOptions.get("checkpointLocation"),
-        df,
-        extraOptions.toMap,
-        sink,
-        outputMode,
-        useTempCheckpointLocation = true,
-        trigger = trigger)
+      startQuery(sink, extraOptions)
     } else {
       val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
       val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
@@ -380,19 +396,28 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
         createV1Sink(optionsWithPath)
       }
 
-      df.sparkSession.sessionState.streamingQueryManager.startQuery(
-        extraOptions.get("queryName"),
-        extraOptions.get("checkpointLocation"),
-        df,
-        optionsWithPath.originalMap,
-        sink,
-        outputMode,
-        useTempCheckpointLocation = source == "console" || source == "noop",
-        recoverFromCheckpointLocation = true,
-        trigger = trigger)
+      startQuery(sink, optionsWithPath)
     }
   }
 
+  private def startQuery(
+      sink: Table,
+      newOptions: CaseInsensitiveMap[String],
+      recoverFromCheckpoint: Boolean = true): StreamingQuery = {
+    val useTempCheckpointLocation = SOURCES_ALLOW_ONE_TIME_QUERY.contains(source)
+
+    df.sparkSession.sessionState.streamingQueryManager.startQuery(
+      newOptions.get("queryName"),
+      newOptions.get("checkpointLocation"),
+      df,
+      newOptions.originalMap,
+      sink,
+      outputMode,
+      useTempCheckpointLocation = useTempCheckpointLocation,
+      recoverFromCheckpointLocation = recoverFromCheckpoint,
+      trigger = trigger)
+  }
+
   private def createV1Sink(optionsWithPath: CaseInsensitiveMap[String]): Sink = {
     val ds = DataSource(
       df.sparkSession,
@@ -409,7 +434,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
    * @since 2.0.0
    */
   def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = {
-    this.source = "foreach"
+    this.source = SOURCE_NAME_FOREACH
     this.foreachWriter = if (writer != null) {
       ds.sparkSession.sparkContext.clean(writer)
     } else {
@@ -433,7 +458,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
    */
   @Evolving
   def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = {
-    this.source = "foreachBatch"
+    this.source = SOURCE_NAME_FOREACH_BATCH
     if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null")
     this.foreachBatchWriter = function
     this
@@ -485,6 +510,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
 
   private var source: String = df.sparkSession.sessionState.conf.defaultDataSourceName
 
+  private var tableName: String = null
+
   private var outputMode: OutputMode = OutputMode.Append
 
   private var trigger: Trigger = Trigger.ProcessingTime(0L)
@@ -497,3 +524,16 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
 
   private var partitioningColumns: Option[Seq[String]] = None
 }
+
+object DataStreamWriter {
+  val SOURCE_NAME_MEMORY = "memory"
+  val SOURCE_NAME_FOREACH = "foreach"
+  val SOURCE_NAME_FOREACH_BATCH = "foreachBatch"
+  val SOURCE_NAME_CONSOLE = "console"
+  val SOURCE_NAME_TABLE = "table"
+  val SOURCE_NAME_NOOP = "noop"
+
+  // these writer sources are also used for one-time query, hence allow temp checkpoint location
+  val SOURCES_ALLOW_ONE_TIME_QUERY = Seq(SOURCE_NAME_MEMORY, SOURCE_NAME_FOREACH,
+    SOURCE_NAME_FOREACH_BATCH, SOURCE_NAME_CONSOLE, SOURCE_NAME_NOOP)
+}
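// A hedged usage sketch for the saveAsTable API added above. It assumes a running SparkSession,
// a rate source as input, and a catalog table "mycat.ns.events" backed by a connector that
// reports the STREAMING_WRITE capability; the catalog, namespace, table name, and checkpoint
// path are illustrative, not part of the patch.
import org.apache.spark.sql.SparkSession

object SaveAsTableUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("saveAsTable-sketch").getOrCreate()

    val query = spark.readStream
      .format("rate")
      .load()
      .writeStream
      // "table" is not in SOURCES_ALLOW_ONE_TIME_QUERY, so a checkpoint location must be
      // supplied explicitly (or via the checkpoint location configuration).
      .option("checkpointLocation", "/tmp/checkpoints/events")
      .saveAsTable("mycat.ns.events")

    query.awaitTermination()
  }
}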
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
index 788452dace84..062b1060bc60 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.streaming.test
 
+import java.io.File
 import java.util
 
 import scala.collection.JavaConverters._
 
@@ -25,10 +26,10 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
 import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
-import org.apache.spark.sql.connector.{FakeV2Provider, InMemoryTableCatalog}
+import org.apache.spark.sql.connector.{FakeV2Provider, InMemoryTableCatalog, InMemoryTableSessionCatalog}
 import org.apache.spark.sql.connector.catalog.{Identifier, SupportsRead, Table, TableCapability, V2TableWithV1Fallback}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.read.ScanBuilder
@@ -51,9 +52,10 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter {
   after {
     spark.sessionState.catalogManager.reset()
     spark.sessionState.conf.clear()
+    sqlContext.streams.active.foreach(_.stop())
   }
 
-  test("table API with file source") {
+  test("read: table API with file source") {
     Seq("parquet", "").foreach { source =>
       withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> source) {
         withTempDir { tempDir =>
@@ -72,13 +74,13 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter {
     }
   }
 
-  test("read non-exist table") {
+  test("read: read non-exist table") {
     intercept[AnalysisException] {
       spark.readStream.table("non_exist_table")
     }.message.contains("Table not found")
   }
 
-  test("stream table API with temp view") {
+  test("read: stream table API with temp view") {
     val tblName = "my_table"
     val stream = MemoryStream[Int]
     withTable(tblName) {
@@ -93,7 +95,7 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter {
     }
   }
 
-  test("stream table API with non-streaming temp view") {
+  test("read: stream table API with non-streaming temp view") {
     val tblName = "my_table"
     withTable(tblName) {
       spark.range(3).createOrReplaceTempView(tblName)
@@ -103,7 +105,7 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter {
     }
   }
 
-  test("read table without streaming capability support") {
+  test("read: read table without streaming capability support") {
     val tableIdentifer = "testcat.table_name"
 
     spark.sql(s"CREATE TABLE $tableIdentifer (id bigint, data string) USING foo")
@@ -113,7 +115,7 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter {
     }.message.contains("does not support either micro-batch or continuous scan")
   }
 
-  test("read table with custom catalog") {
+  test("read: read table with custom catalog") {
     val tblName = "teststream.table_name"
     withTable(tblName) {
       spark.sql(s"CREATE TABLE $tblName (data int) USING foo")
@@ -131,7 +133,7 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter {
     }
   }
 
-  test("read table with custom catalog & namespace") {
+  test("read: read table with custom catalog & namespace") {
     spark.sql("CREATE NAMESPACE teststream.ns")
 
     val tblName = "teststream.ns.table_name"
@@ -151,7 +153,7 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter {
     }
   }
 
-  test("fallback to V1 relation") {
+  test("read: fallback to V1 relation") {
     val tblName = DataStreamTableAPISuite.V1FallbackTestTableName
     spark.conf.set(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION.key,
       classOf[InMemoryStreamTableCatalog].getName)
@@ -169,6 +171,146 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter {
       }
     }
   }
+
+  test("write: write to table with custom catalog & no namespace") {
+    val tableIdentifier = "testcat.table_name"
+
+    spark.sql(s"CREATE TABLE $tableIdentifier (id bigint, data string) USING foo")
+    checkAnswer(spark.table(tableIdentifier), Seq.empty)
+
+    runTestWithStreamAppend(tableIdentifier)
+  }
+
+  test("write: write to table with custom catalog & namespace") {
+    spark.sql("CREATE NAMESPACE testcat.ns")
+
+    val tableIdentifier = "testcat.ns.table_name"
+
+    spark.sql(s"CREATE TABLE $tableIdentifier (id bigint, data string) USING foo")
+    checkAnswer(spark.table(tableIdentifier), Seq.empty)
+
+    runTestWithStreamAppend(tableIdentifier)
+  }
+
+  test("write: write to table with default session catalog") {
+    val v2Source = classOf[FakeV2Provider].getName
+    spark.conf.set(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION.key,
+      classOf[InMemoryTableSessionCatalog].getName)
+
+    spark.sql("CREATE NAMESPACE ns")
+
+    val tableIdentifier = "ns.table_name"
+    spark.sql(s"CREATE TABLE $tableIdentifier (id bigint, data string) USING $v2Source")
+    checkAnswer(spark.table(tableIdentifier), Seq.empty)
+
+    runTestWithStreamAppend(tableIdentifier)
+  }
+
+  test("write: write to non-exist table with custom catalog") {
+    val tableIdentifier = "testcat.nonexisttable"
+    spark.sql("CREATE NAMESPACE testcat.ns")
+
+    withTempDir { checkpointDir =>
+      val exc = intercept[NoSuchTableException] {
+        runStreamQueryAppendMode(tableIdentifier, checkpointDir, Seq.empty, Seq.empty)
+      }
+      assert(exc.getMessage.contains("nonexisttable"))
+    }
+  }
+
+  test("write: write to file provider based table isn't allowed yet") {
+    val tableIdentifier = "table_name"
+
+    spark.sql(s"CREATE TABLE $tableIdentifier (id bigint, data string) USING parquet")
+    checkAnswer(spark.table(tableIdentifier), Seq.empty)
+
+    withTempDir { checkpointDir =>
+      val exc = intercept[AnalysisException] {
+        runStreamQueryAppendMode(tableIdentifier, checkpointDir, Seq.empty, Seq.empty)
+      }
+      assert(exc.getMessage.contains("doesn't support streaming write"))
+    }
+  }
+
+  test("write: write to temporary view isn't allowed yet") {
+    val tableIdentifier = "testcat.table_name"
+    val tempViewIdentifier = "temp_view"
+
+    spark.sql(s"CREATE TABLE $tableIdentifier (id bigint, data string) USING foo")
+    checkAnswer(spark.table(tableIdentifier), Seq.empty)
+
+    spark.table(tableIdentifier).createOrReplaceTempView(tempViewIdentifier)
+
+    withTempDir { checkpointDir =>
+      val exc = intercept[AnalysisException] {
+        runStreamQueryAppendMode(tempViewIdentifier, checkpointDir, Seq.empty, Seq.empty)
+      }
+      assert(exc.getMessage.contains("doesn't support streaming write"))
+    }
+  }
+
+  test("write: write to view shouldn't be allowed") {
+    val tableIdentifier = "testcat.table_name"
+    val viewIdentifier = "table_view"
+
+    spark.sql(s"CREATE TABLE $tableIdentifier (id bigint, data string) USING foo")
+    checkAnswer(spark.table(tableIdentifier), Seq.empty)
+
+    spark.sql(s"CREATE VIEW $viewIdentifier AS SELECT id, data FROM $tableIdentifier")
+
+    withTempDir { checkpointDir =>
+      val exc = intercept[AnalysisException] {
+        runStreamQueryAppendMode(viewIdentifier, checkpointDir, Seq.empty, Seq.empty)
+      }
+      assert(exc.getMessage.contains("doesn't support streaming write"))
+    }
+  }
+
+  private def runTestWithStreamAppend(tableIdentifier: String) = {
+    withTempDir { checkpointDir =>
+      val input1 = Seq((1L, "a"), (2L, "b"), (3L, "c"))
+      verifyStreamAppend(tableIdentifier, checkpointDir, Seq.empty, input1, input1)
+
+      val input2 = Seq((4L, "d"), (5L, "e"), (6L, "f"))
+      verifyStreamAppend(tableIdentifier, checkpointDir, Seq(input1), input2, input1 ++ input2)
+    }
+  }
+
+  private def runStreamQueryAppendMode(
+      tableIdentifier: String,
+      checkpointDir: File,
+      prevInputs: Seq[Seq[(Long, String)]],
+      newInputs: Seq[(Long, String)]): Unit = {
+    val inputData = MemoryStream[(Long, String)]
+    val inputDF = inputData.toDF().toDF("id", "data")
+
+    prevInputs.foreach { inputsPerBatch =>
+      inputData.addData(inputsPerBatch: _*)
+    }
+
+    val query = inputDF
+      .writeStream
+      .option("checkpointLocation", checkpointDir.getAbsolutePath)
+      .saveAsTable(tableIdentifier)
+
+    inputData.addData(newInputs: _*)
+
+    query.processAllAvailable()
+    query.stop()
+  }
+
+  private def verifyStreamAppend(
+      tableIdentifier: String,
+      checkpointDir: File,
+      prevInputs: Seq[Seq[(Long, String)]],
+      newInputs: Seq[(Long, String)],
+      expectedOutputs: Seq[(Long, String)]): Unit = {
+    runStreamQueryAppendMode(tableIdentifier, checkpointDir, prevInputs, newInputs)
+    checkAnswer(
+      spark.table(tableIdentifier),
+      expectedOutputs.map { case (id, data) => Row(id, data) }
+    )
+  }
 }
 
 object DataStreamTableAPISuite {
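// Separate illustrative sketch, not part of the diff above: the new "write:" tests resolve
// tables through a catalog named "testcat" (and, in one case, the v2 session catalog). A likely
// setup, assumed here rather than shown in this excerpt, registers the in-memory catalog
// implementation via a spark.sql.catalog.<name> configuration key before each test.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.InMemoryTableCatalog

object TestCatalogSetupSketch {
  def registerTestCatalog(spark: SparkSession): Unit = {
    // Binds the identifier prefix "testcat" to the in-memory table catalog used by the suite.
    spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
  }
}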