ForeachSink.scala
@@ -47,22 +47,22 @@ class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable
     // method supporting incremental planning. But in the long run, we should generally make newly
     // created Datasets use `IncrementalExecution` where necessary (which is SPARK-16264 tries to
     // resolve).
-
+    val incrementalExecution = data.queryExecution.asInstanceOf[IncrementalExecution]
     val datasetWithIncrementalExecution =
-      new Dataset(data.sparkSession, data.logicalPlan, implicitly[Encoder[T]]) {
+      new Dataset(data.sparkSession, incrementalExecution, implicitly[Encoder[T]]) {
         override lazy val rdd: RDD[T] = {
           val objectType = exprEnc.deserializer.dataType
           val deserialized = CatalystSerde.deserialize[T](logicalPlan)
 
           // was originally: sparkSession.sessionState.executePlan(deserialized) ...
-          val incrementalExecution = new IncrementalExecution(
+          val newIncrementalExecution = new IncrementalExecution(
             this.sparkSession,
             deserialized,
-            data.queryExecution.asInstanceOf[IncrementalExecution].outputMode,
-            data.queryExecution.asInstanceOf[IncrementalExecution].checkpointLocation,
-            data.queryExecution.asInstanceOf[IncrementalExecution].currentBatchId,
-            data.queryExecution.asInstanceOf[IncrementalExecution].currentEventTimeWatermark)
-          incrementalExecution.toRdd.mapPartitions { rows =>
+            incrementalExecution.outputMode,
+            incrementalExecution.checkpointLocation,
+            incrementalExecution.currentBatchId,
+            incrementalExecution.currentEventTimeWatermark)
+          newIncrementalExecution.toRdd.mapPartitions { rows =>
             rows.map(_.get(0, objectType))
           }.asInstanceOf[RDD[T]]
         }
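
For context, the hunk above reuses the IncrementalExecution already attached to the sink's input DataFrame instead of re-deriving execution from `data.logicalPlan`, evidently so that a watermarked aggregation can be written through a `ForeachWriter` (the new test below exercises exactly that). A rough end-to-end sketch of that user-facing path follows; the socket source, host/port, object name, and the print-only writer are illustrative assumptions, not part of this patch.

import org.apache.spark.sql.{ForeachWriter, SparkSession}
import org.apache.spark.sql.functions.{count, window}
import org.apache.spark.sql.streaming.OutputMode

object ForeachWithWatermarkSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("foreach-watermark-sketch").getOrCreate()
    import spark.implicits._

    // Hypothetical socket source; any streaming source with an event-time column works here.
    val counts = spark.readStream
      .format("socket").option("host", "localhost").option("port", "9999").load()
      .select($"value".cast("int").cast("timestamp").as("eventTime"))
      .withWatermark("eventTime", "10 seconds")
      .groupBy(window($"eventTime", "5 seconds"))
      .agg(count("*").as("count"))
      .select($"count".as[Long])

    // A trivial writer that prints each per-window count on the executor.
    val printer = new ForeachWriter[Long] {
      override def open(partitionId: Long, version: Long): Boolean = true
      override def process(value: Long): Unit = println(s"window count = $value")
      override def close(errorOrNull: Throwable): Unit = ()
    }

    val query = counts.writeStream
      .outputMode(OutputMode.Complete)
      .foreach(printer)
      .start()
    query.awaitTermination()
  }
}

The sketch mirrors the shape of the new "foreach with watermark" test: an event-time column, a watermark, a windowed count, and OutputMode.Complete written through foreach.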
ForeachSinkSuite.scala
@@ -25,6 +25,7 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.ForeachWriter
+import org.apache.spark.sql.functions.{count, window}
 import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest}
 import org.apache.spark.sql.test.SharedSQLContext
 
@@ -169,6 +170,40 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter
       assert(errorEvent.error.get.getMessage === "error")
     }
   }
+
+  test("foreach with watermark") {
+    val inputData = MemoryStream[Int]
+
+    val windowedAggregation = inputData.toDF()
+      .withColumn("eventTime", $"value".cast("timestamp"))
+      .withWatermark("eventTime", "10 seconds")
+      .groupBy(window($"eventTime", "5 seconds") as 'window)
+      .agg(count("*") as 'count)
+      .select($"count".as[Long])
+      .map(_.toInt)
+      .repartition(1)
+
+    val query = windowedAggregation
+      .writeStream
+      .outputMode(OutputMode.Complete)
+      .foreach(new TestForeachWriter())
+      .start()
+    try {
+      inputData.addData(10, 11, 12)
+      query.processAllAvailable()
+
+      val allEvents = ForeachSinkSuite.allEvents()
+      assert(allEvents.size === 1)
+      val expectedEvents = Seq(
+        ForeachSinkSuite.Open(partition = 0, version = 0),
+        ForeachSinkSuite.Process(value = 3),
+        ForeachSinkSuite.Close(None)
+      )
+      assert(allEvents === Seq(expectedEvents))
+    } finally {
+      query.stop()
+    }
+  }
 }
 
 /** A global object to collect events in the executor */
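
The new test relies on TestForeachWriter and the event-recording ForeachSinkSuite companion object referenced above, both defined elsewhere in this file and not shown in this excerpt. As a rough illustration of that pattern only (the names and details below are assumptions, not the suite's actual code), an event-collecting writer could look like this:

import java.util.concurrent.ConcurrentLinkedQueue
import scala.collection.JavaConverters._

import org.apache.spark.sql.ForeachWriter

// Illustrative sketch only: records Open/Process/Close events so a test can
// assert on what the sink delivered. Not the suite's actual TestForeachWriter.
object RecordedEvents {
  sealed trait Event
  case class Open(partition: Long, version: Long) extends Event
  case class Process(value: Int) extends Event
  case class Close(error: Option[Throwable]) extends Event

  // In local-mode tests the driver and executors share one JVM, so a
  // process-global, thread-safe queue is enough to observe writer activity.
  private val queue = new ConcurrentLinkedQueue[Seq[Event]]()
  def add(events: Seq[Event]): Unit = queue.add(events)
  def allEvents(): Seq[Seq[Event]] = queue.asScala.toSeq
  def clear(): Unit = queue.clear()
}

class RecordingForeachWriter extends ForeachWriter[Int] {
  private var events = Vector.empty[RecordedEvents.Event]

  override def open(partitionId: Long, version: Long): Boolean = {
    events :+= RecordedEvents.Open(partitionId, version)
    true
  }

  override def process(value: Int): Unit = {
    events :+= RecordedEvents.Process(value)
  }

  override def close(errorOrNull: Throwable): Unit = {
    events :+= RecordedEvents.Close(Option(errorOrNull))
    RecordedEvents.add(events)
  }
}

In the test above, the asserted Open / Process(3) / Close sequence corresponds to the single repartition(1) partition carrying the one windowed count (3) produced from inputs 10, 11, 12.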