# [SPARK-23099][SS] Migrate foreach sink to DataSourceV2 #20552
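For context, the sink being migrated backs `DataStreamWriter.foreach`: the user supplies a `ForeachWriter[T]`, and the sink drives its `open`/`process`/`close` contract for every partition of every epoch. A minimal usage sketch (illustrative only, not taken from this PR; assumes a `spark-shell`-style session named `spark` and the built-in rate source):

```scala
import org.apache.spark.sql.ForeachWriter
import spark.implicits._

// Micro-batch query that hands every record to a user-supplied ForeachWriter.
val query = spark.readStream
  .format("rate")                  // built-in source producing (timestamp, value) rows
  .option("rowsPerSecond", "1")
  .load()
  .select($"value").as[Long]
  .writeStream
  .foreach(new ForeachWriter[Long] {
    def open(partitionId: Long, epochId: Long): Boolean = true  // return false to skip this partition/epoch
    def process(value: Long): Unit = println(value)             // called once per record
    def close(errorOrNull: Throwable): Unit = ()                // always called when the epoch ends
  })
  .start()

query.awaitTermination(5000)  // run briefly
query.stop()
```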
Changes from 2 commits
The first file replaces the old `ForeachSink` (a `Sink` whose `addBatch` drove the user's `ForeachWriter`) with DataSourceV2 classes: `ForeachWriterProvider`, `ForeachInternalWriter`, `ForeachWriterFactory`, and `ForeachDataWriter`. Removed and added lines appear interleaved below, as in the original diff view.

@@ -17,52 +17,119 @@

package org.apache.spark.sql.execution.streaming

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

import org.apache.spark.TaskContext
import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

/**
 * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
 * [[ForeachWriter]].
 *
 * @param writer The [[ForeachWriter]] to process all data.
 * @tparam T The expected type of the sink.
 */
class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {

  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    // This logic should've been as simple as:
    // ```
    // data.as[T].foreachPartition { iter => ... }
    // ```
    //
    // Unfortunately, doing that would just break the incremental planing. The reason is,
    // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will
    // create a new plan. Because StreamExecution uses the existing plan to collect metrics and
    // update watermark, we should never create a new plan. Otherwise, metrics and watermark are
    // updated in the new plan, and StreamExecution cannot retrieval them.
    //
    // Hence, we need to manually convert internal rows to objects using encoder.

case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport {
  override def createStreamWriter(
      queryId: String,
      schema: StructType,
      mode: OutputMode,
      options: DataSourceOptions): StreamWriter = {
    val encoder = encoderFor[T].resolveAndBind(
      data.logicalPlan.output,
      data.sparkSession.sessionState.analyzer)
    data.queryExecution.toRdd.foreachPartition { iter =>
      if (writer.open(TaskContext.getPartitionId(), batchId)) {
        try {
          while (iter.hasNext) {
            writer.process(encoder.fromRow(iter.next()))
          }
        } catch {
          case e: Throwable =>
            writer.close(e)
            throw e
        }
        writer.close(null)
      } else {
        writer.close(null)
      schema.toAttributes,
      SparkSession.getActiveSession.get.sessionState.analyzer)
    ForeachInternalWriter(writer, encoder)
  }
}

case class ForeachInternalWriter[T: Encoder](
    writer: ForeachWriter[T], encoder: ExpressionEncoder[T])
  extends StreamWriter with SupportsWriteInternalRow {
  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

  override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = {
    ForeachWriterFactory(writer, encoder)
  }
}

case class ForeachWriterFactory[T: Encoder](writer: ForeachWriter[T], encoder: ExpressionEncoder[T])
  extends DataWriterFactory[InternalRow] {
  override def createDataWriter(partitionId: Int, attemptNumber: Int): ForeachDataWriter[T] = {
    new ForeachDataWriter(writer, encoder, partitionId)
  }
}

class ForeachDataWriter[T : Encoder](
Contributor: Add docs describing the implementation of this DataWriter, especially the lifecycle of the ForeachWriter (that should go here rather than in inline comments). A possible sketch of such a doc follows this file's diff.
    private var writer: ForeachWriter[T], encoder: ExpressionEncoder[T], partitionId: Int)
  extends DataWriter[InternalRow] {
  private val initialEpochId: Long = {
    // Start with the microbatch ID. If it's not there, we're in continuous execution,
    // so get the start epoch.
    // This ID will be incremented as commits happen.
    TaskContext.get().getLocalProperty(MicroBatchExecution.BATCH_ID_KEY) match {
      case null => TaskContext.get().getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
      case batch => batch.toLong
    }
  }
  private var currentEpochId = initialEpochId

  // The lifecycle of the ForeachWriter is incompatible with the lifecycle of DataSourceV2 writers.
  // Unfortunately, we cannot migrate ForeachWriter, as its implementations live in user code. So
  // we need a small state machine to shim between them.
  // * CLOSED means close() has been called.
  // * OPENED means open() has been called and returned true.
  // * OPENED_SKIP_PROCESSING means open() has been called and returned false, so records in the
  //   current epoch are skipped.
  private object WriterState extends Enumeration {
    type WriterState = Value
    val CLOSED, OPENED, OPENED_SKIP_PROCESSING = Value
  }
  import WriterState._

  private var state = CLOSED

  private def openAndSetState(epochId: Long) = {
    // Create a new writer by roundtripping through the serialization for compatibility.
    // In the old API, a writer instantiation would never get reused.
    val byteStream = new ByteArrayOutputStream()
    val objectStream = new ObjectOutputStream(byteStream)
    objectStream.writeObject(writer)
    writer = new ObjectInputStream(new ByteArrayInputStream(byteStream.toByteArray)).readObject()
      .asInstanceOf[ForeachWriter[T]]

    writer.open(partitionId, epochId) match {
      case true => state = OPENED
      case false => state = OPENED_SKIP_PROCESSING
    }
  }

  openAndSetState(initialEpochId)

  override def write(record: InternalRow): Unit = {
    try {
      state match {
        case OPENED => writer.process(encoder.fromRow(record))
        case OPENED_SKIP_PROCESSING => ()
        case CLOSED =>
          // First record of a new epoch, so we need to open a new writer for it.
          currentEpochId += 1
          openAndSetState(currentEpochId)
          writer.process(encoder.fromRow(record))
      }
    } catch {
      case t: Throwable =>
        writer.close(t)
        throw t
    }
  }

  override def toString(): String = "ForeachSink"
  override def commit(): WriterCommitMessage = {
    writer.close(null)
    ForeachWriterCommitMessage
  }

  override def abort(): Unit = {}
}

case object ForeachWriterCommitMessage extends WriterCommitMessage
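Responding to the review comment above, one possible shape for a class-level doc on `ForeachDataWriter` (a sketch summarizing the code above, not the wording that went into the PR):

```scala
/**
 * A [[DataWriter]] that forwards every row it receives to a user-provided [[ForeachWriter]].
 *
 * Lifecycle of the wrapped ForeachWriter, as implemented above:
 *  - One ForeachDataWriter is created per partition task. Its starting epoch is the microbatch ID,
 *    or the continuous-processing start epoch if no microbatch ID is set.
 *  - Each epoch gets a fresh copy of the user's writer, produced by serializing and deserializing
 *    it, because writer instances were never reused under the old ForeachSink.
 *  - open(partitionId, epochId) is called on that copy; if it returns false, records for the
 *    epoch are skipped (OPENED_SKIP_PROCESSING).
 *  - write() calls process() for each record; any failure calls close(error) and rethrows.
 *  - commit() calls close(null); the first record after a commit opens a writer for the next epoch.
 */
```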
The second file updates the foreach sink test suite:

@@ -141,7 +141,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf

        query.processAllAvailable()
      }
      assert(e.getCause.isInstanceOf[SparkException])
      assert(e.getCause.getCause.getMessage === "error")
      assert(e.getCause.getCause.getCause.getMessage === "error")
      assert(query.isActive === false)

      val allEvents = ForeachSinkSuite.allEvents()
@@ -255,6 +255,32 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf

      query.stop()
    }
  }
Contributor: I think there should be a test with continuous processing + foreach.

Contributor (author): Good instinct, it didn't quite work. Added the test.
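For illustration, a continuous-processing + foreach query could be wired up roughly like this (a sketch using the rate source and `Trigger.Continuous`; it is not the test that was added, and it assumes a `spark-shell`-style session named `spark`):

```scala
import org.apache.spark.sql.{ForeachWriter, Row}
import org.apache.spark.sql.streaming.Trigger

// Continuous processing needs a continuous-capable source; the built-in rate source qualifies.
val query = spark.readStream
  .format("rate")
  .option("rowsPerSecond", "5")
  .load()
  .writeStream
  .trigger(Trigger.Continuous("1 second"))   // commit an epoch roughly every second
  .foreach(new ForeachWriter[Row] {
    def open(partitionId: Long, epochId: Long): Boolean = true
    def process(value: Row): Unit = println(s"epoch writer saw: $value")
    def close(errorOrNull: Throwable): Unit = ()
  })
  .start()

query.awaitTermination(10000)  // run briefly
query.stop()
```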
  testQuietly("foreach does not reuse writers") {
    withTempDir { checkpointDir =>
      val input = MemoryStream[Int]
      val query = input.toDS().repartition(1).writeStream
        .option("checkpointLocation", checkpointDir.getCanonicalPath)
        .foreach(new TestForeachWriter() {
          override def process(value: Int): Unit = {
            super.process(this.hashCode())
          }
        }).start()
      input.addData(0)
      query.processAllAvailable()
      input.addData(0)
      query.processAllAvailable()

      val allEvents = ForeachSinkSuite.allEvents()
      assert(allEvents.size === 2)
      assert(allEvents(0)(1).isInstanceOf[ForeachSinkSuite.Process[Int]])
      val firstWriterId = allEvents(0)(1).asInstanceOf[ForeachSinkSuite.Process[Int]].value
      assert(allEvents(1)(1).isInstanceOf[ForeachSinkSuite.Process[Int]])
      assert(
        allEvents(1)(1).asInstanceOf[ForeachSinkSuite.Process[Int]].value != firstWriterId,
        "writer was reused!")
    }
  }
}

/** A global object to collect events in the executor */
Reviewer: nit: This is really a small class. Maybe inline this rather than defining a confusingly named ...InternalWriter.
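The suggestion is presumably to fold `ForeachInternalWriter` into `ForeachWriterProvider.createStreamWriter` as an anonymous writer. A rough sketch of that shape, reusing only the pieces already shown in the diff (not code from this PR):

```scala
import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport {
  override def createStreamWriter(
      queryId: String,
      schema: StructType,
      mode: OutputMode,
      options: DataSourceOptions): StreamWriter = {
    val encoder = encoderFor[T].resolveAndBind(
      schema.toAttributes,
      SparkSession.getActiveSession.get.sessionState.analyzer)
    // Inline the small StreamWriter instead of naming a separate ForeachInternalWriter class.
    new StreamWriter with SupportsWriteInternalRow {
      override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
      override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
      override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] =
        ForeachWriterFactory(writer, encoder)
    }
  }
}
```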