@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamWriter}
import org.apache.spark.sql.types.StructType

/**
@@ -63,7 +63,7 @@ class KafkaStreamWriter(
*/
case class KafkaStreamWriterFactory(
topic: Option[String], producerParams: Map[String, String], schema: StructType)
extends DataWriterFactory[InternalRow] {
extends StreamingDataWriterFactory[InternalRow] {

override def createDataWriter(
partitionId: Int,

@@ -22,7 +22,7 @@
import org.apache.spark.annotation.InterfaceStability;

/**
* A data writer returned by {@link DataWriterFactory#createDataWriter(int, int, long)} and is
* A data writer returned by {@link DataWriterFactory#createDataWriter(int, int)} and is
* responsible for writing data for an input RDD partition.
*
* One Spark task has one exclusive data writer, so there is no thread-safe concern.
@@ -39,7 +39,7 @@
* {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data
* writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an
* exception will be sent to the driver side, and Spark may retry this writing task a few times.
* In each retry, {@link DataWriterFactory#createDataWriter(int, int, long)} will receive a
* In each retry, {@link DataWriterFactory#createDataWriter(int, int)} will receive a
* different `attemptNumber`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])}
* when the configured number of retries is exhausted.
*
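For context, the retry contract this javadoc describes plays out per task attempt roughly as in the sketch below. This is only an illustration of the calling pattern, not the real implementation (that lives in DataWritingSparkTask.run further down in this diff, where commit coordination and error handling are more involved); WriteTaskSketch and writePartition are illustrative names.

    import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage}

    // Illustrative only (not part of this PR): how one task attempt drives a DataWriter.
    object WriteTaskSketch {
      def writePartition[T](
          factory: DataWriterFactory[T],
          partitionId: Int,
          attemptNumber: Int,
          iter: Iterator[T]): WriterCommitMessage = {
        val writer = factory.createDataWriter(partitionId, attemptNumber)
        try {
          while (iter.hasNext) {
            writer.write(iter.next())
          }
          // The commit message travels back to the driver, which calls
          // DataSourceWriter.commit once every partition has reported success.
          writer.commit()
        } catch {
          case t: Throwable =>
            // On failure the writer is aborted; Spark may retry the task with the
            // same partitionId but a different attemptNumber.
            writer.abort()
            throw t
        }
      }
    }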

@@ -48,9 +48,6 @@ public interface DataWriterFactory<T> extends Serializable {
* same task id but different attempt number, which means there are multiple
* tasks with the same task id running at the same time. Implementations can
* use this attempt number to distinguish writers of different task attempts.
* @param epochId A monotonically increasing id for streaming queries that are split in to
* discrete periods of execution. For non-streaming queries,
* this ID will always be 0.
*/
DataWriter<T> createDataWriter(int partitionId, int attemptNumber, long epochId);
DataWriter<T> createDataWriter(int partitionId, int attemptNumber);
}

@@ -18,6 +18,7 @@
package org.apache.spark.sql.sources.v2.writer.streaming;

import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.sources.v2.writer.DataSourceWriter;
import org.apache.spark.sql.sources.v2.writer.DataWriter;
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage;
@@ -27,6 +28,9 @@
*
* Streaming queries are divided into intervals of data called epochs, with a monotonically
* increasing numeric ID. This writer handles commits and aborts for each successive epoch.
*
* Note that StreamWriter implementations should provide instances of
* {@link StreamingDataWriterFactory}.

Contributor: What about adding createStreamWriterFactory that returns the streaming interface? That would make it easier for implementations and prevent throwing cast exceptions because a StreamingDataWriterFactory is expected.

Contributor Author: That wouldn't be compatible with SupportsWriteInternalRow. We could add a StreamingSupportsWriteInternalRow, but that seems much more confusing both for Spark developers and for data source implementers.

Contributor: What do you think about removing the SupportsWriteInternalRow and always using InternalRow? For the read side, I think using Row and UnsafeRow is a problem: https://issues.apache.org/jira/browse/SPARK-23325

I don't see the value of using Row instead of InternalRow for readers, so maybe we should just simplify on both the read and write paths.

Contributor Author: I'm broadly supportive. I'll detail my thoughts in the jira.

*/
@InterfaceStability.Evolving
public interface StreamWriter extends DataSourceWriter {
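To make the note above concrete — StreamWriter implementations are expected to hand back a StreamingDataWriterFactory so the epoch id can reach each writer — a minimal sink might be wired as sketched below. NoopStreamWriter and NoopWriterFactory are hypothetical names used only for illustration; they are not part of this PR.

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
    import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamWriter}

    // Hypothetical sink that discards every row; only the factory wiring matters here.
    class NoopStreamWriter extends StreamWriter {
      // The declared return type can stay DataWriterFactory[Row], but the returned
      // instance is a StreamingDataWriterFactory so each writer can receive the epoch id.
      override def createWriterFactory(): DataWriterFactory[Row] = NoopWriterFactory

      override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
      override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
    }

    case object NoopWriterFactory extends StreamingDataWriterFactory[Row] {
      override def createDataWriter(
          partitionId: Int,
          attemptNumber: Int,
          epochId: Long): DataWriter[Row] = new DataWriter[Row] {
        override def write(record: Row): Unit = {}
        override def commit(): WriterCommitMessage = new WriterCommitMessage {}
        override def abort(): Unit = {}
      }
    }

With a factory like this in place, the microbatch and continuous paths later in this diff can match on StreamingDataWriterFactory and pass the epoch through createDataWriter(partitionId, attemptNumber, epochId).
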
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.sources.v2.writer.streaming;

import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.sources.v2.writer.DataWriter;
import org.apache.spark.sql.sources.v2.writer.DataWriterFactory;

@InterfaceStability.Evolving
public interface StreamingDataWriterFactory<T> extends DataWriterFactory<T> {
/**
* Returns a data writer to do the actual writing work.
*
* If this method fails (by throwing an exception), the action would fail and no Spark job was
* submitted.
*
* @param partitionId A unique id of the RDD partition that the returned writer will process.
* Usually Spark processes many RDD partitions at the same time,
* implementations should use the partition id to distinguish writers for
* different partitions.
* @param attemptNumber Spark may launch multiple tasks with the same task id. For example, a task
* failed, Spark launches a new task wth the same task id but different
* attempt number. Or a task is too slow, Spark launches new tasks wth the
* same task id but different attempt number, which means there are multiple
* tasks with the same task id running at the same time. Implementations can
* use this attempt number to distinguish writers of different task attempts.
* @param epochId A monotonically increasing id for streaming queries that are split in to
* discrete periods of execution. For non-streaming queries,
* this ID will always be 0.
*/
DataWriter<T> createDataWriter(int partitionId, int attemptNumber, long epochId);

@Override default DataWriter<T> createDataWriter(int partitionId, int attemptNumber) {
throw new IllegalStateException("Streaming data writer factory cannot create data writers without epoch.");

Contributor: Why extend DataWriterFactory if this method is going to throw an exception? Why not make them independent interfaces?

Contributor Author: If there's no common interface, DataSourceRDD would need to take a java.util.List[Any] instead of java.util.List[DataWriterFactory[T]]. This kind of pattern is present in a lot of DataSourceV2 interfaces, and I think it's endemic to the general design.

Contributor Author: I suppose we could have it take a (partition, attempt number, epoch) => DataWriter lambda instead of Any if we really don't want to extend DataWriterFactory.

Contributor Author: Sorry, wrong side of the query. I meant DataWritingSparkTask.run().

}
}
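For reference, the alternative raised in the thread above — handing the write task a (partition, attempt number, epoch) => DataWriter function rather than sharing the DataWriterFactory hierarchy — would look roughly like the sketch below. This is only the shape of the idea under discussion, not something this PR implements, and the names are illustrative.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.sources.v2.writer.DataWriter

    // Sketch of the discussed alternative: the streaming write path takes a plain function,
    // so no shared interface (and no throwing default method) would be needed.
    object StreamingWriteSketch {
      type CreateStreamingWriter = (Int, Int, Long) => DataWriter[InternalRow]

      // A streaming task would be given the function plus the epoch it is writing for,
      // and would never see the two-argument batch entry point at all.
      def openWriter(
          create: CreateStreamingWriter,
          partitionId: Int,
          attemptNumber: Int,
          epochId: Long): DataWriter[InternalRow] = {
        create(partitionId, attemptNumber, epochId)
      }
    }

The trade-off, as the thread notes, is that batch and streaming factories would then share no common type for code such as DataWritingSparkTask to accept.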

@@ -31,8 +31,9 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming.{MicroBatchExecution, StreamExecution}
import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions}
import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamWriter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

@@ -54,7 +55,14 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
override protected def doExecute(): RDD[InternalRow] = {
val writeTask = writer match {
case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
case w: MicroBatchWriter =>
new StreamingInternalRowDataWriterFactory(w.createWriterFactory(), query.schema)
case w: StreamWriter =>
new StreamingInternalRowDataWriterFactory(
w.createWriterFactory().asInstanceOf[StreamingDataWriterFactory[Row]],

Contributor: This will cause a cast exception, right? I think it is better to use a separate create method.

query.schema)
case _ =>
new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
}

val useCommitCoordinator = writer.useCommitCoordinator
@@ -75,7 +83,8 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
.askSync[Unit](SetWriterPartitions(rdd.getNumPartitions))

(context: TaskContext, iter: Iterator[InternalRow]) =>
DataWritingSparkTask.runContinuous(writeTask, context, iter)
DataWritingSparkTask.runContinuous(
writeTask.asInstanceOf[StreamingDataWriterFactory[InternalRow]], context, iter)
case _ =>
(context: TaskContext, iter: Iterator[InternalRow]) =>
DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator)
@@ -132,8 +141,13 @@ object DataWritingSparkTask extends Logging {
val stageId = context.stageId()
val partId = context.partitionId()
val attemptId = context.attemptNumber()
val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0")
val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong)
val dataWriter = writeTask match {
case w: StreamingDataWriterFactory[InternalRow] =>
val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).get
w.createDataWriter(partId, attemptId, epochId.toLong)

case w => w.createDataWriter(partId, attemptId)
}

// write the data and commit this writer.
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
@@ -170,7 +184,7 @@ }
}

def runContinuous(
writeTask: DataWriterFactory[InternalRow],
writeTask: StreamingDataWriterFactory[InternalRow],
context: TaskContext,
iter: Iterator[InternalRow]): WriterCommitMessage = {
val epochCoordinator = EpochCoordinatorRef.get(
@@ -217,6 +231,17 @@ class InternalRowDataWriterFactory(
rowWriterFactory: DataWriterFactory[Row],
schema: StructType) extends DataWriterFactory[InternalRow] {

override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = {
new InternalRowDataWriter(
rowWriterFactory.createDataWriter(partitionId, attemptNumber),
RowEncoder.apply(schema).resolveAndBind())
}
}

class StreamingInternalRowDataWriterFactory(
rowWriterFactory: StreamingDataWriterFactory[Row],
schema: StructType) extends StreamingDataWriterFactory[InternalRow] {

override def createDataWriter(
partitionId: Int,
attemptNumber: Int,

@@ -22,8 +22,8 @@ import scala.collection.JavaConverters._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamWriter}
import org.apache.spark.sql.types.StructType

/** Common methods used to create writes for the the console sink */
@@ -39,7 +39,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions)
assert(SparkSession.getActiveSession.isDefined)
protected val spark = SparkSession.getActiveSession.get

def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory
def createWriterFactory(): StreamingDataWriterFactory[Row] = PackedRowWriterFactory

override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
// We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2

@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.sources
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamWriter}

/**
* A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements
@@ -34,7 +34,13 @@ class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWr

override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages)

override def createWriterFactory(): DataWriterFactory[Row] = writer.createWriterFactory()
override def createWriterFactory(): StreamingDataWriterFactory[Row] = {
writer.createWriterFactory() match {
case s: StreamingDataWriterFactory[Row] => s
case _ =>
throw new IllegalStateException("StreamWriter did not give a StreamingDataWriterFactory")
}
}
}

class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter)

@@ -22,6 +22,7 @@ import scala.collection.mutable
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory

/**
* A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery
@@ -30,7 +31,7 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat
* Note that, because it sends all rows to the driver, this factory will generally be unsuitable
* for production-quality sinks. It's intended for use in tests.
*/
case object PackedRowWriterFactory extends DataWriterFactory[Row] {
case object PackedRowWriterFactory extends StreamingDataWriterFactory[Row] {
override def createDataWriter(
partitionId: Int,
attemptNumber: Int,

@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Comp
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamWriter}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

@@ -146,7 +146,7 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode)
}
}

case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] {
case class MemoryWriterFactory(outputMode: OutputMode) extends StreamingDataWriterFactory[Row] {
override def createDataWriter(
partitionId: Int,
attemptNumber: Int,

@@ -207,10 +207,7 @@ private[v2] object SimpleCounter {
class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration)
extends DataWriterFactory[Row] {

override def createDataWriter(
partitionId: Int,
attemptNumber: Int,
epochId: Long): DataWriter[Row] = {
override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = {
val jobPath = new Path(new Path(path, "_temporary"), jobId)
val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber")
val fs = filePath.getFileSystem(conf.value)
@@ -243,10 +240,7 @@ class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] {
class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration)
extends DataWriterFactory[InternalRow] {

override def createDataWriter(
partitionId: Int,
attemptNumber: Int,
epochId: Long): DataWriter[InternalRow] = {
override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = {
val jobPath = new Path(new Path(path, "_temporary"), jobId)
val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber")
val fs = filePath.getFileSystem(conf.value)