
Commit c6d4ff5

Create StreamingDataWriterFactory for epoch ID.
1 parent b0f422c commit c6d4ff5

10 files changed: +93 -28 lines


external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
 import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamWriter}
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -63,7 +63,7 @@ class KafkaStreamWriter(
  */
 case class KafkaStreamWriterFactory(
     topic: Option[String], producerParams: Map[String, String], schema: StructType)
-  extends DataWriterFactory[InternalRow] {
+  extends StreamingDataWriterFactory[InternalRow] {
 
   override def createDataWriter(
       partitionId: Int,
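
With KafkaStreamWriterFactory now extending StreamingDataWriterFactory, the epoch id reaches the Kafka sink through the streaming factory rather than the batch-oriented one. A minimal sketch of the call the streaming write path now makes on executors; the topic, producer config, and epoch value here are hypothetical, only the constructor parameter types come from the diff above:

import org.apache.spark.sql.kafka010.KafkaStreamWriterFactory
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Hypothetical factory instance for illustration.
val factory = KafkaStreamWriterFactory(
  topic = Some("events"),
  producerParams = Map("bootstrap.servers" -> "host:9092"),
  schema = StructType(StructField("value", StringType) :: Nil))

// partitionId = 0, attemptNumber = 0, epochId = 42: the epoch is now an explicit argument.
val writer = factory.createDataWriter(0, 0, 42L)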

sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java

Lines changed: 1 addition & 4 deletions
@@ -48,9 +48,6 @@ public interface DataWriterFactory<T> extends Serializable {
    *                      same task id but different attempt number, which means there are multiple
    *                      tasks with the same task id running at the same time. Implementations can
    *                      use this attempt number to distinguish writers of different task attempts.
-   * @param epochId A monotonically increasing id for streaming queries that are split in to
-   *                discrete periods of execution. For non-streaming queries,
-   *                this ID will always be 0.
    */
-  DataWriter<T> createDataWriter(int partitionId, int attemptNumber, long epochId);
+  DataWriter<T> createDataWriter(int partitionId, int attemptNumber);
 }
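
After this change the batch-side factory no longer sees an epoch at all, so a plain (non-streaming) implementation overrides only the two-argument method. A minimal sketch under that assumption; NoopWriterFactory is a hypothetical name, not part of this commit:

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}

case class NoopWriterFactory() extends DataWriterFactory[Row] {
  // Batch factories now implement only (partitionId, attemptNumber); the epoch parameter
  // moved to StreamingDataWriterFactory.
  override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] =
    new DataWriter[Row] {
      override def write(record: Row): Unit = ()  // discard every row
      override def commit(): WriterCommitMessage = new WriterCommitMessage {}
      override def abort(): Unit = ()
    }
}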

sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java

Lines changed: 12 additions & 0 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.sources.v2.writer.streaming;
 
 import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.Row;
 import org.apache.spark.sql.sources.v2.writer.DataSourceWriter;
 import org.apache.spark.sql.sources.v2.writer.DataWriter;
 import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage;
@@ -27,6 +28,9 @@
  *
  * Streaming queries are divided into intervals of data called epochs, with a monotonically
  * increasing numeric ID. This writer handles commits and aborts for each successive epoch.
+ *
+ * Note that StreamWriter implementations should provide instances of
+ * {@link StreamingDataWriterFactory}.
  */
 @InterfaceStability.Evolving
 public interface StreamWriter extends DataSourceWriter {
@@ -59,6 +63,14 @@ public interface StreamWriter extends DataSourceWriter {
    */
   void abort(long epochId, WriterCommitMessage[] messages);
 
+  /**
+   * Creates a writer factory which will be serialized and sent to executors.
+   *
+   * If this method fails (by throwing an exception), the query will fail and no Spark job will be
+   * submitted.
+   */
+  StreamingDataWriterFactory<Row> createWriterFactory();
+
   default void commit(WriterCommitMessage[] messages) {
     throw new UnsupportedOperationException(
         "Commit without epoch should not be called with StreamWriter");
sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+package org.apache.spark.sql.sources.v2.writer.streaming;
+
+import org.apache.spark.sql.sources.v2.writer.DataWriter;
+import org.apache.spark.sql.sources.v2.writer.DataWriterFactory;
+
+public interface StreamingDataWriterFactory<T> extends DataWriterFactory<T> {
+  /**
+   * Returns a data writer to do the actual writing work.
+   *
+   * If this method fails (by throwing an exception), the action would fail and no Spark job was
+   * submitted.
+   *
+   * @param partitionId A unique id of the RDD partition that the returned writer will process.
+   *                    Usually Spark processes many RDD partitions at the same time,
+   *                    implementations should use the partition id to distinguish writers for
+   *                    different partitions.
+   * @param attemptNumber Spark may launch multiple tasks with the same task id. For example, a task
+   *                      failed, Spark launches a new task with the same task id but different
+   *                      attempt number. Or a task is too slow, Spark launches new tasks with the
+   *                      same task id but different attempt number, which means there are multiple
+   *                      tasks with the same task id running at the same time. Implementations can
+   *                      use this attempt number to distinguish writers of different task attempts.
+   * @param epochId A monotonically increasing id for streaming queries that are split into
+   *                discrete periods of execution. For non-streaming queries,
+   *                this ID will always be 0.
+   */
+  DataWriter<T> createDataWriter(int partitionId, int attemptNumber, long epochId);
+
+  @Override default DataWriter<T> createDataWriter(int partitionId, int attemptNumber) {
+    throw new IllegalStateException("Streaming data writer factory cannot create data writers without epoch.");
+  }
+}
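
The default two-argument method above makes the epoch mandatory for streaming factories: an epoch-unaware call site fails fast instead of silently writing as epoch 0. A small sketch using PackedRowWriterFactory, which is converted to this interface later in the commit; the argument values are arbitrary:

import org.apache.spark.sql.execution.streaming.sources.PackedRowWriterFactory

// Epoch-aware entry point used by the streaming write path (partition 0, attempt 0, epoch 7).
val writer = PackedRowWriterFactory.createDataWriter(0, 0, 7L)

// The inherited batch signature falls through to the default above and throws
// IllegalStateException("Streaming data writer factory cannot create data writers without epoch."):
// PackedRowWriterFactory.createDataWriter(0, 0)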

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala

Lines changed: 29 additions & 6 deletions
@@ -31,8 +31,9 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.{MicroBatchExecution, StreamExecution}
 import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions}
+import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter
 import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamWriter}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
@@ -54,7 +55,12 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
   override protected def doExecute(): RDD[InternalRow] = {
     val writeTask = writer match {
       case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
-      case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
+      case w: MicroBatchWriter =>
+        new StreamingInternalRowDataWriterFactory(w.createWriterFactory(), query.schema)
+      case w: StreamWriter =>
+        new StreamingInternalRowDataWriterFactory(w.createWriterFactory(), query.schema)
+      case _ =>
+        new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
     }
 
     val useCommitCoordinator = writer.useCommitCoordinator
@@ -75,7 +81,8 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
           .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions))
 
         (context: TaskContext, iter: Iterator[InternalRow]) =>
-          DataWritingSparkTask.runContinuous(writeTask, context, iter)
+          DataWritingSparkTask.runContinuous(
+            writeTask.asInstanceOf[StreamingDataWriterFactory[InternalRow]], context, iter)
       case _ =>
         (context: TaskContext, iter: Iterator[InternalRow]) =>
           DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator)
@@ -132,8 +139,13 @@ object DataWritingSparkTask extends Logging {
     val stageId = context.stageId()
     val partId = context.partitionId()
     val attemptId = context.attemptNumber()
-    val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0")
-    val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong)
+    val dataWriter = writeTask match {
+      case w: StreamingDataWriterFactory[InternalRow] =>
+        val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).get
+        w.createDataWriter(partId, attemptId, epochId.toLong)
+
+      case w => w.createDataWriter(partId, attemptId)
+    }
 
     // write the data and commit this writer.
     Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
@@ -170,7 +182,7 @@
   }
 
   def runContinuous(
-      writeTask: DataWriterFactory[InternalRow],
+      writeTask: StreamingDataWriterFactory[InternalRow],
       context: TaskContext,
       iter: Iterator[InternalRow]): WriterCommitMessage = {
     val epochCoordinator = EpochCoordinatorRef.get(
@@ -217,6 +229,17 @@ class InternalRowDataWriterFactory(
     rowWriterFactory: DataWriterFactory[Row],
     schema: StructType) extends DataWriterFactory[InternalRow] {
 
+  override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = {
+    new InternalRowDataWriter(
+      rowWriterFactory.createDataWriter(partitionId, attemptNumber),
+      RowEncoder.apply(schema).resolveAndBind())
+  }
+}
+
+class StreamingInternalRowDataWriterFactory(
+    rowWriterFactory: StreamingDataWriterFactory[Row],
+    schema: StructType) extends StreamingDataWriterFactory[InternalRow] {
+
   override def createDataWriter(
       partitionId: Int,
       attemptNumber: Int,
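
The match in DataWritingSparkTask.run above is why the old getOrElse("0") fallback could be dropped: only streaming factories read the epoch, and for them the batch id property is expected to have been set on the driver, so .get should succeed. A sketch of that contract; the helper below is illustrative, only MicroBatchExecution.BATCH_ID_KEY and the executor-side lookup come from this diff:

import org.apache.spark.SparkContext
import org.apache.spark.sql.execution.streaming.MicroBatchExecution

// Driver side (illustrative helper): tag the jobs of the current micro-batch with its id.
def tagBatch(sc: SparkContext, currentBatchId: Long): Unit =
  sc.setLocalProperty(MicroBatchExecution.BATCH_ID_KEY, currentBatchId.toString)

// Executor side (mirrors the hunk above): streaming factories require the property to be
// present, which is why the code uses .get rather than a "0" fallback.
// val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).get.toLong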

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala

Lines changed: 3 additions & 3 deletions
@@ -22,8 +22,8 @@ import scala.collection.JavaConverters._
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage}
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamWriter}
 import org.apache.spark.sql.types.StructType
 
 /** Common methods used to create writes for the the console sink */
@@ -39,7 +39,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions)
   assert(SparkSession.getActiveSession.isDefined)
   protected val spark = SparkSession.getActiveSession.get
 
-  def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory
+  def createWriterFactory(): StreamingDataWriterFactory[Row] = PackedRowWriterFactory
 
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
     // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala

Lines changed: 8 additions & 2 deletions
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.sources
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamWriter}
 
 /**
  * A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements
@@ -34,7 +34,13 @@ class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWr
 
   override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages)
 
-  override def createWriterFactory(): DataWriterFactory[Row] = writer.createWriterFactory()
+  override def createWriterFactory(): StreamingDataWriterFactory[Row] = {
+    writer.createWriterFactory() match {
+      case s: StreamingDataWriterFactory[Row] => s
+      case _ =>
+        throw new IllegalStateException("StreamWriter did not give a StreamingDataWriterFactory")
+    }
+  }
 }
 
 class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter)
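
A sketch of the adapter in use, reusing the hypothetical NoopStreamWriter from the StreamWriter example earlier: MicroBatchWriter pins a StreamWriter to a single batch id, and its createWriterFactory now guarantees the factory it returns is epoch-aware (the defensive match above throws IllegalStateException otherwise):

import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter

// Hypothetical usage; the batch id value is arbitrary.
val streamWriter = new NoopStreamWriter
val microBatchWriter = new MicroBatchWriter(batchId = 5L, writer = streamWriter)

// Returns the sink's StreamingDataWriterFactory[Row]; commits and aborts issued through
// microBatchWriter are forwarded to streamWriter for batch 5.
val factory = microBatchWriter.createWriterFactory()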

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala

Lines changed: 2 additions & 1 deletion
@@ -22,6 +22,7 @@ import scala.collection.mutable
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory
 
 /**
  * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery
@@ -30,7 +31,7 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat
 * Note that, because it sends all rows to the driver, this factory will generally be unsuitable
 * for production-quality sinks. It's intended for use in tests.
 */
-case object PackedRowWriterFactory extends DataWriterFactory[Row] {
+case object PackedRowWriterFactory extends StreamingDataWriterFactory[Row] {
   override def createDataWriter(
       partitionId: Int,
       attemptNumber: Int,

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala

Lines changed: 2 additions & 2 deletions
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Comp
 import org.apache.spark.sql.execution.streaming.Sink
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
 import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamWriter}
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 
@@ -146,7 +146,7 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode)
   }
 }
 
-case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] {
+case class MemoryWriterFactory(outputMode: OutputMode) extends StreamingDataWriterFactory[Row] {
   override def createDataWriter(
       partitionId: Int,
       attemptNumber: Int,

sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala

Lines changed: 2 additions & 8 deletions
@@ -207,10 +207,7 @@ private[v2] object SimpleCounter {
 class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration)
   extends DataWriterFactory[Row] {
 
-  override def createDataWriter(
-      partitionId: Int,
-      attemptNumber: Int,
-      epochId: Long): DataWriter[Row] = {
+  override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = {
     val jobPath = new Path(new Path(path, "_temporary"), jobId)
     val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber")
     val fs = filePath.getFileSystem(conf.value)
@@ -243,10 +240,7 @@ class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] {
 class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration)
   extends DataWriterFactory[InternalRow] {
 
-  override def createDataWriter(
-      partitionId: Int,
-      attemptNumber: Int,
-      epochId: Long): DataWriter[InternalRow] = {
+  override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = {
     val jobPath = new Path(new Path(path, "_temporary"), jobId)
     val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber")
     val fs = filePath.getFileSystem(conf.value)
