diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
index ae5b5c52d514..cfc8f9ce2838 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
 import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamWriter}
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -63,7 +63,7 @@ class KafkaStreamWriter(
  */
 case class KafkaStreamWriterFactory(
     topic: Option[String], producerParams: Map[String, String], schema: StructType)
-  extends DataWriterFactory[InternalRow] {
+  extends StreamingDataWriterFactory[InternalRow] {
 
   override def createDataWriter(
       partitionId: Int,
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
index 39bf45829886..82b7d162808f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
@@ -22,7 +22,7 @@
 import org.apache.spark.annotation.InterfaceStability;
 
 /**
- * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int, long)} and is
+ * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int)} and is
  * responsible for writing data for an input RDD partition.
 *
 * One Spark task has one exclusive data writer, so there is no thread-safe concern.
@@ -39,7 +39,7 @@
 * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data
 * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an
 * exception will be sent to the driver side, and Spark may retry this writing task a few times.
- * In each retry, {@link DataWriterFactory#createDataWriter(int, int, long)} will receive a
+ * In each retry, {@link DataWriterFactory#createDataWriter(int, int)} will receive a
 * different `attemptNumber`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])}
 * when the configured number of retries is exhausted.
 *
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
index c2c2ab73257e..ea95442511ce 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
@@ -48,9 +48,6 @@ public interface DataWriterFactory<T> extends Serializable {
   *                      same task id but different attempt number, which means there are multiple
   *                      tasks with the same task id running at the same time. Implementations can
   *                      use this attempt number to distinguish writers of different task attempts.
-   * @param epochId A monotonically increasing id for streaming queries that are split in to
-   *                discrete periods of execution. For non-streaming queries,
-   *                this ID will always be 0.
   */
-  DataWriter<T> createDataWriter(int partitionId, int attemptNumber, long epochId);
+  DataWriter<T> createDataWriter(int partitionId, int attemptNumber);
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java
index a316b2a4c1d8..81167c9237e8 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.sources.v2.writer.streaming;
 
 import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.Row;
 import org.apache.spark.sql.sources.v2.writer.DataSourceWriter;
 import org.apache.spark.sql.sources.v2.writer.DataWriter;
 import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage;
@@ -27,6 +28,9 @@
 *
 * Streaming queries are divided into intervals of data called epochs, with a monotonically
 * increasing numeric ID. This writer handles commits and aborts for each successive epoch.
+ *
+ * Note that StreamWriter implementations should provide instances of
+ * {@link StreamingDataWriterFactory}.
 */
 @InterfaceStability.Evolving
 public interface StreamWriter extends DataSourceWriter {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java
new file mode 100644
index 000000000000..849887a49496
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.writer.streaming;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.sources.v2.writer.DataWriter;
+import org.apache.spark.sql.sources.v2.writer.DataWriterFactory;
+
+@InterfaceStability.Evolving
+public interface StreamingDataWriterFactory<T> extends DataWriterFactory<T> {
+  /**
+   * Returns a data writer to do the actual writing work.
+   *
+   * If this method fails (by throwing an exception), the action will fail and no Spark job will
+   * be submitted.
+   *
+   * @param partitionId A unique id of the RDD partition that the returned writer will process.
+   *                    Usually Spark processes many RDD partitions at the same time,
+   *                    implementations should use the partition id to distinguish writers for
+   *                    different partitions.
+   * @param attemptNumber Spark may launch multiple tasks with the same task id. For example, a task
+   *                      failed, Spark launches a new task with the same task id but different
+   *                      attempt number. Or a task is too slow, Spark launches new tasks with the
+   *                      same task id but different attempt number, which means there are multiple
+   *                      tasks with the same task id running at the same time. Implementations can
+   *                      use this attempt number to distinguish writers of different task attempts.
+   * @param epochId A monotonically increasing id for streaming queries that are split into
+   *                discrete periods of execution. For non-streaming queries,
+   *                this ID will always be 0.
+   */
+  DataWriter<T> createDataWriter(int partitionId, int attemptNumber, long epochId);
+
+  @Override default DataWriter<T> createDataWriter(int partitionId, int attemptNumber) {
+    throw new IllegalStateException("Streaming data writer factory cannot create data writers without an epoch.");
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
index e80b44c1cdc6..b3f73a0f7b32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
@@ -31,8 +31,9 @@
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.{MicroBatchExecution, StreamExecution}
 import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions}
+import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter
 import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamWriter}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
@@ -54,7 +55,14 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
   override protected def doExecute(): RDD[InternalRow] = {
     val writeTask = writer match {
       case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
-      case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
+      case w: MicroBatchWriter =>
+        new StreamingInternalRowDataWriterFactory(w.createWriterFactory(), query.schema)
+      case w: StreamWriter =>
+        new StreamingInternalRowDataWriterFactory(
+          w.createWriterFactory().asInstanceOf[StreamingDataWriterFactory[Row]],
+          query.schema)
+      case _ =>
+        new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
     }
 
     val useCommitCoordinator = writer.useCommitCoordinator
@@ -75,7 +83,8 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
           .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions))
 
         (context: TaskContext, iter: Iterator[InternalRow]) =>
-          DataWritingSparkTask.runContinuous(writeTask, context, iter)
+          DataWritingSparkTask.runContinuous(
+            writeTask.asInstanceOf[StreamingDataWriterFactory[InternalRow]], context, iter)
       case _ =>
         (context: TaskContext, iter: Iterator[InternalRow]) =>
           DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator)
@@ -132,8 +141,13 @@ object DataWritingSparkTask extends Logging {
     val stageId = context.stageId()
     val partId = context.partitionId()
     val attemptId = context.attemptNumber()
-    val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0")
-    val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong)
+    val dataWriter = writeTask match {
+      case w: StreamingDataWriterFactory[InternalRow] =>
+        val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).get
+        w.createDataWriter(partId, attemptId, epochId.toLong)
+
+      case w => w.createDataWriter(partId, attemptId)
+    }
 
     // write the data and commit this writer.
     Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
@@ -170,7 +184,7 @@
   }
 
   def runContinuous(
-      writeTask: DataWriterFactory[InternalRow],
+      writeTask: StreamingDataWriterFactory[InternalRow],
       context: TaskContext,
       iter: Iterator[InternalRow]): WriterCommitMessage = {
     val epochCoordinator = EpochCoordinatorRef.get(
@@ -217,6 +231,17 @@ class InternalRowDataWriterFactory(
     rowWriterFactory: DataWriterFactory[Row],
     schema: StructType) extends DataWriterFactory[InternalRow] {
 
+  override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = {
+    new InternalRowDataWriter(
+      rowWriterFactory.createDataWriter(partitionId, attemptNumber),
+      RowEncoder.apply(schema).resolveAndBind())
+  }
+}
+
+class StreamingInternalRowDataWriterFactory(
+    rowWriterFactory: StreamingDataWriterFactory[Row],
+    schema: StructType) extends StreamingDataWriterFactory[InternalRow] {
+
   override def createDataWriter(
       partitionId: Int,
       attemptNumber: Int,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
index d276403190b3..03b4cb06425c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
@@ -22,8 +22,8 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage}
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamWriter}
 import org.apache.spark.sql.types.StructType
 
 /** Common methods used to create writes for the the console sink */
@@ -39,7 +39,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions)
   assert(SparkSession.getActiveSession.isDefined)
   protected val spark = SparkSession.getActiveSession.get
 
-  def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory
+  def createWriterFactory(): StreamingDataWriterFactory[Row] = PackedRowWriterFactory
 
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
     // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
index 56f7ff25cbed..e2ae79c78ba8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.sources
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamWriter}
 
 /**
@@ -34,7 +34,13 @@ class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWr
 
   override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages)
 
-  override def createWriterFactory(): DataWriterFactory[Row] = writer.createWriterFactory()
+  override def createWriterFactory(): StreamingDataWriterFactory[Row] = {
+    writer.createWriterFactory() match {
+      case s: StreamingDataWriterFactory[Row] => s
+      case _ =>
+        throw new IllegalStateException("StreamWriter did not return a StreamingDataWriterFactory")
+    }
+  }
 }
 
 class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
index e07355aa37db..c7fd323b5907 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory
 
 /**
 * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery
 * to a [[DataSourceWriter]] on the driver.
@@ -30,7 +31,7 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat
 * Note that, because it sends all rows to the driver, this factory will generally be unsuitable
 * for production-quality sinks. It's intended for use in tests.
 */
-case object PackedRowWriterFactory extends DataWriterFactory[Row] {
+case object PackedRowWriterFactory extends StreamingDataWriterFactory[Row] {
   override def createDataWriter(
       partitionId: Int,
       attemptNumber: Int,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
index 5f58246083bb..3872f3a9f384 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Comp
 import org.apache.spark.sql.execution.streaming.Sink
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
 import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamWriter}
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 
@@ -146,7 +146,7 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode)
   }
 }
 
-case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] {
+case class MemoryWriterFactory(outputMode: OutputMode) extends StreamingDataWriterFactory[Row] {
   override def createDataWriter(
       partitionId: Int,
       attemptNumber: Int,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
index a5007fa32135..36dd2a350a05 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
@@ -207,10 +207,7 @@ private[v2] object SimpleCounter {
 
 class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration)
   extends DataWriterFactory[Row] {
 
-  override def createDataWriter(
-      partitionId: Int,
-      attemptNumber: Int,
-      epochId: Long): DataWriter[Row] = {
+  override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = {
     val jobPath = new Path(new Path(path, "_temporary"), jobId)
     val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber")
     val fs = filePath.getFileSystem(conf.value)
@@ -243,10 +240,7 @@ class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] {
 
 class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration)
   extends DataWriterFactory[InternalRow] {
 
-  override def createDataWriter(
-      partitionId: Int,
-      attemptNumber: Int,
-      epochId: Long): DataWriter[InternalRow] = {
+  override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = {
     val jobPath = new Path(new Path(path, "_temporary"), jobId)
     val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber")
     val fs = filePath.getFileSystem(conf.value)
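
Illustrative sketch, not part of the patch: a minimal sink-side implementation of the new StreamingDataWriterFactory interface, assuming the API exactly as added above. The names EpochTaggedWriterFactory and EpochTaggedCommitMessage are hypothetical; the writer body is a stand-in for real I/O.

// Hypothetical sketch: a StreamingDataWriterFactory[Row] whose commit messages
// carry the epoch they were written for.
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory

case class EpochTaggedCommitMessage(epochId: Long, rowCount: Long) extends WriterCommitMessage

case object EpochTaggedWriterFactory extends StreamingDataWriterFactory[Row] {
  // Streaming execution calls the three-argument overload; the inherited two-argument
  // createDataWriter keeps its default implementation, which throws.
  override def createDataWriter(
      partitionId: Int,
      attemptNumber: Int,
      epochId: Long): DataWriter[Row] = new DataWriter[Row] {
    private var count = 0L
    override def write(row: Row): Unit = count += 1    // stand-in for real output
    override def commit(): WriterCommitMessage = EpochTaggedCommitMessage(epochId, count)
    override def abort(): Unit = ()                    // nothing durable was written
  }
}

Leaving the two-argument createDataWriter at its throwing default is the point of the new interface: a streaming factory can never be asked for a writer without an epoch.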
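
Illustrative sketch, not part of the patch: a StreamWriter that satisfies the new documentation note by handing out the streaming factory above. EpochTaggedStreamWriter is a hypothetical name; MicroBatchWriter, as changed above, rejects any StreamWriter whose factory is not a StreamingDataWriterFactory.

// Hypothetical sketch: a StreamWriter that returns a StreamingDataWriterFactory.
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamWriter}

class EpochTaggedStreamWriter extends StreamWriter {
  // Narrowing the return type, as ConsoleWriter does in this patch, documents the contract.
  override def createWriterFactory(): StreamingDataWriterFactory[Row] = EpochTaggedWriterFactory

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    // e.g. atomically publish everything reported for this epoch
  }

  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    // e.g. discard partial output for this epoch
  }
}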
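
Illustrative usage sketch, not part of the patch: how the MicroBatchWriter wrapper changed above would consume such a StreamWriter for one batch. The example object and values are hypothetical; the batch id doubles as the epoch id passed to createDataWriter.

// Hypothetical usage: wiring a StreamWriter through MicroBatchWriter for a single batch.
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter

object MicroBatchWiringExample {
  def main(args: Array[String]): Unit = {
    val streamWriter = new EpochTaggedStreamWriter
    val batchWriter = new MicroBatchWriter(5L, streamWriter)

    // Succeeds only because EpochTaggedStreamWriter returns a StreamingDataWriterFactory;
    // otherwise MicroBatchWriter.createWriterFactory() throws IllegalStateException.
    val factory = batchWriter.createWriterFactory()
    val dataWriter = factory.createDataWriter(0, 0, 5L)
    dataWriter.write(Row("hello"))
    println(dataWriter.commit())   // EpochTaggedCommitMessage(5,1)
  }
}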