@@ -17,52 +17,119 @@

package org.apache.spark.sql.execution.streaming

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

import org.apache.spark.TaskContext
import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

/**
* A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
* [[ForeachWriter]].
*
* @param writer The [[ForeachWriter]] to process all data.
* @tparam T The expected type of the sink.
*/
class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {

override def addBatch(batchId: Long, data: DataFrame): Unit = {
// This logic should've been as simple as:
// ```
// data.as[T].foreachPartition { iter => ... }
// ```
//
// Unfortunately, doing that would just break the incremental planning. The reason is,
// `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will
// create a new plan. Because StreamExecution uses the existing plan to collect metrics and
// update watermark, we should never create a new plan. Otherwise, metrics and watermark are
// updated in the new plan, and StreamExecution cannot retrieve them.
//
// Hence, we need to manually convert internal rows to objects using encoder.

case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport {
override def createStreamWriter(
queryId: String,
schema: StructType,
mode: OutputMode,
options: DataSourceOptions): StreamWriter = {
val encoder = encoderFor[T].resolveAndBind(
data.logicalPlan.output,
data.sparkSession.sessionState.analyzer)
data.queryExecution.toRdd.foreachPartition { iter =>
if (writer.open(TaskContext.getPartitionId(), batchId)) {
try {
while (iter.hasNext) {
writer.process(encoder.fromRow(iter.next()))
}
} catch {
case e: Throwable =>
writer.close(e)
throw e
}
writer.close(null)
} else {
writer.close(null)
schema.toAttributes,
SparkSession.getActiveSession.get.sessionState.analyzer)
ForeachInternalWriter(writer, encoder)
}
}

case class ForeachInternalWriter[T: Encoder](
Contributor: nit: This is really a small class. Maybe inline this rather than define a confusing name... InternalWriter
writer: ForeachWriter[T], encoder: ExpressionEncoder[T])
Contributor: nit: params on different lines
extends StreamWriter with SupportsWriteInternalRow {
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = {
ForeachWriterFactory(writer, encoder)
}
}

case class ForeachWriterFactory[T: Encoder](writer: ForeachWriter[T], encoder: ExpressionEncoder[T])
Contributor: similarly... maybe inline this class as well. It's very small.

Contributor: actually... probably should not inline this. Its outer closure may not be serializable in that case.
extends DataWriterFactory[InternalRow] {
override def createDataWriter(partitionId: Int, attemptNumber: Int): ForeachDataWriter[T] = {
new ForeachDataWriter(writer, encoder, partitionId)
}
}

class ForeachDataWriter[T : Encoder](
Contributor: add docs describing the implementation of this DataWriter, especially the lifecycle of ForeachWriter (the docs should go here rather than in inline comments).
private var writer: ForeachWriter[T], encoder: ExpressionEncoder[T], partitionId: Int)
Contributor: params in separate lines.
extends DataWriter[InternalRow] {
private val initialEpochId: Long = {
// Start with the microbatch ID. If it's not there, we're in continuous execution,
// so get the start epoch.
// This ID will be incremented as commits happen.
TaskContext.get().getLocalProperty(MicroBatchExecution.BATCH_ID_KEY) match {
case null => TaskContext.get().getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
case batch => batch.toLong
}
}
private var currentEpochId = initialEpochId

// The lifecycle of the ForeachWriter is incompatible with the lifecycle of DataSourceV2 writers.
// Unfortunately, we cannot migrate ForeachWriter, as its implementations live in user code. So
// we need a small state machine to shim between them.
// * CLOSED means close() has been called.
// * OPENED means open() returned true, so records should be processed.
// * OPENED_SKIP_PROCESSING means open() returned false, so records should be skipped.
private object WriterState extends Enumeration {
type WriterState = Value
val CLOSED, OPENED, OPENED_SKIP_PROCESSING = Value
}
import WriterState._

private var state = CLOSED

private def openAndSetState(epochId: Long) = {
// Create a new writer by roundtripping through the serialization for compatibility.
// In the old API, a writer instantiation would never get reused.
val byteStream = new ByteArrayOutputStream()
Contributor: Why are you serializing and deserializing here? If you are re-serializing the ForeachWriter, doesn't this mean that you are going to retain state (of the non-transient fields) across them? Is that what you want?

It seems the best thing to do is to serialize the writer at the driver, send the bytes to the task, and then deserialize repeatedly. Then you only incur the cost of deserializing between epochs and you always start with a fresh copy of the ForeachWriter.

Contributor Author: You're right; this suggestion is what we really want.
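To make the suggestion concrete, here is a minimal, hedged sketch of the serialize-once-on-the-driver approach; `FreshCopyFactory` is a hypothetical name, not code from this PR:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Hypothetical helper (not part of this PR): serialize the original object once,
// on the driver, and ship only the bytes inside this factory's closure. Each call
// to newCopy() on the task side deserializes a pristine copy, so no writer state
// leaks across epochs and the per-epoch cost is deserialization only.
class FreshCopyFactory[T <: java.io.Serializable](original: T) extends Serializable {

  private val bytes: Array[Byte] = {
    val byteStream = new ByteArrayOutputStream()
    val objectStream = new ObjectOutputStream(byteStream)
    objectStream.writeObject(original)
    objectStream.close()
    byteStream.toByteArray
  }

  def newCopy(): T = {
    val objectStream = new ObjectInputStream(new ByteArrayInputStream(bytes))
    try objectStream.readObject().asInstanceOf[T] finally objectStream.close()
  }
}

// Usage sketch: build once on the driver, call newCopy() at the start of each epoch.
// val writerFactory = new FreshCopyFactory(foreachWriter)
// val freshWriter   = writerFactory.newCopy()
```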

val objectStream = new ObjectOutputStream(byteStream)
objectStream.writeObject(writer)
writer = new ObjectInputStream(new ByteArrayInputStream(byteStream.toByteArray)).readObject()
.asInstanceOf[ForeachWriter[T]]

writer.open(partitionId, epochId) match {
case true => state = OPENED
case false => state = OPENED_SKIP_PROCESSING
}
}

openAndSetState(initialEpochId)

override def write(record: InternalRow): Unit = {
try {
state match {
case OPENED => writer.process(encoder.fromRow(record))
case OPENED_SKIP_PROCESSING => ()
case CLOSED =>
// First record of a new epoch, so we need to open a new writer for it.
currentEpochId += 1
openAndSetState(currentEpochId)
writer.process(encoder.fromRow(record))
}
} catch {
case t: Throwable =>
writer.close(t)
throw t
}
}

override def toString(): String = "ForeachSink"
override def commit(): WriterCommitMessage = {
writer.close(null)
ForeachWriterCommitMessage
}

override def abort(): Unit = {}
}

case object ForeachWriterCommitMessage extends WriterCommitMessage
@@ -461,6 +461,9 @@ class MicroBatchExecution(
case _ => throw new IllegalArgumentException(s"unknown sink type for $sink")
}

sparkSession.sparkContext.setLocalProperty(
MicroBatchExecution.BATCH_ID_KEY, currentBatchId.toString)

reportTimeTaken("queryPlanning") {
lastExecution = new IncrementalExecution(
sparkSessionToRunBatch,
@@ -500,3 +503,7 @@
Optional.ofNullable(scalaOption.orNull)
}
}

object MicroBatchExecution {
val BATCH_ID_KEY = "sql.streaming.microbatch.batchId"
}
@@ -269,7 +269,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
query
} else if (source == "foreach") {
assertNotPartitioned("foreach")
val sink = new ForeachSink[T](foreachWriter)(ds.exprEnc)
val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc)
df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
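For context (not part of this diff), a hedged sketch of the user-facing path the `foreach` source serves; `df` here stands in for any streaming DataFrame:

```scala
import org.apache.spark.sql.{ForeachWriter, Row}

// Illustrative only: a user-supplied ForeachWriter attached via DataStreamWriter.foreach.
val query = df.writeStream
  .foreach(new ForeachWriter[Row] {
    // Called once per partition per epoch; returning false skips process() calls.
    override def open(partitionId: Long, epochId: Long): Boolean = true
    // Called for each record while the writer is open.
    override def process(row: Row): Unit = println(row)
    // Always called at the end, with the error if one occurred.
    override def close(errorOrNull: Throwable): Unit = ()
  })
  .start()
```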
@@ -141,7 +141,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
query.processAllAvailable()
}
assert(e.getCause.isInstanceOf[SparkException])
assert(e.getCause.getCause.getMessage === "error")
assert(e.getCause.getCause.getCause.getMessage === "error")
assert(query.isActive === false)

val allEvents = ForeachSinkSuite.allEvents()
@@ -255,6 +255,32 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
query.stop()
}
}

Contributor: I think there should be a test with continuous processing + foreach.

Contributor Author: Good instinct, it didn't quite work. Added the test.
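A rough sketch of what such a continuous-processing test might look like (this is not the exact test that was added); it assumes the rate source supports continuous mode, that the suite's test implicits provide the `Int` encoder, and that ScalaTest's `Eventually` is available:

```scala
import org.scalatest.concurrent.Eventually.{eventually, timeout}
import org.scalatest.time.{Seconds, Span}

import org.apache.spark.sql.streaming.Trigger

testQuietly("foreach with continuous trigger (sketch)") {
  val query = spark.readStream
    .format("rate")
    .option("rowsPerSecond", "5")
    .load()
    .selectExpr("CAST(value AS INT) AS value")
    .as[Int]
    .writeStream
    .trigger(Trigger.Continuous("1 second"))
    .foreach(new TestForeachWriter())
    .start()
  try {
    // Wait until at least one open/process/close cycle has been recorded.
    eventually(timeout(Span(60, Seconds))) {
      assert(ForeachSinkSuite.allEvents().nonEmpty)
    }
  } finally {
    query.stop()
  }
}
```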

testQuietly("foreach does not reuse writers") {
withTempDir { checkpointDir =>
val input = MemoryStream[Int]
val query = input.toDS().repartition(1).writeStream
.option("checkpointLocation", checkpointDir.getCanonicalPath)
.foreach(new TestForeachWriter() {
override def process(value: Int): Unit = {
super.process(this.hashCode())
}
}).start()
input.addData(0)
query.processAllAvailable()
input.addData(0)
query.processAllAvailable()

val allEvents = ForeachSinkSuite.allEvents()
assert(allEvents.size === 2)
assert(allEvents(0)(1).isInstanceOf[ForeachSinkSuite.Process[Int]])
val firstWriterId = allEvents(0)(1).asInstanceOf[ForeachSinkSuite.Process[Int]].value
assert(allEvents(1)(1).isInstanceOf[ForeachSinkSuite.Process[Int]])
assert(
allEvents(1)(1).asInstanceOf[ForeachSinkSuite.Process[Int]].value != firstWriterId,
"writer was reused!")
}
}
}

/** A global object to collect events in the executor */