@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.streaming
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization


/**
* An ordered collection of offsets, used to track the progress of processing data from one or more
* [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance
@@ -70,13 +69,16 @@ object OffsetSeq {
* bound the lateness of data that will be processed. Time unit: milliseconds
* @param batchTimestampMs: The current batch processing timestamp.
* Time unit: milliseconds
* @param conf: Additional confs to be persisted across batches, e.g. number of shuffle partitions.
*/
case class OffsetSeqMetadata(var batchWatermarkMs: Long = 0, var batchTimestampMs: Long = 0) {
case class OffsetSeqMetadata(
batchWatermarkMs: Long = 0,
batchTimestampMs: Long = 0,
conf: Map[String, String] = Map.empty) {
def json: String = Serialization.write(this)(OffsetSeqMetadata.format)
}

object OffsetSeqMetadata {
private implicit val format = Serialization.formats(NoTypeHints)
def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json)
}
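
To make the new format concrete, here is a minimal sketch (not part of the patch) of how the metadata round-trips through json4s with the added conf map; the default arguments are what let entries written without a "conf" field deserialize cleanly:

    import org.json4s.NoTypeHints
    import org.json4s.jackson.Serialization

    object OffsetSeqMetadataSketch {
      // Mirrors the case class above; the defaults make older JSON (no "conf") readable.
      case class OffsetSeqMetadata(
          batchWatermarkMs: Long = 0,
          batchTimestampMs: Long = 0,
          conf: Map[String, String] = Map.empty)

      private implicit val format = Serialization.formats(NoTypeHints)

      def main(args: Array[String]): Unit = {
        val md = OffsetSeqMetadata(0, 1489180207737L, Map("spark.sql.shuffle.partitions" -> "10"))
        // {"batchWatermarkMs":0,"batchTimestampMs":1489180207737,"conf":{"spark.sql.shuffle.partitions":"10"}}
        println(Serialization.write(md))

        // A Spark 2.1-era entry has no "conf" field; it reads back with an empty map.
        val old = Serialization.read[OffsetSeqMetadata](
          """{"batchWatermarkMs":0,"batchTimestampMs":1489180207737}""")
        assert(old.conf.isEmpty)
      }
    }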

@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Curre
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.StreamingExplainCommand
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming._
import org.apache.spark.util.{Clock, UninterruptibleThread, Utils}

@@ -117,7 +118,9 @@ class StreamExecution(
}

/** Metadata associated with the offset seq of a batch in the query. */
protected var offsetSeqMetadata = OffsetSeqMetadata()
protected var offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0,
conf = Map(SQLConf.SHUFFLE_PARTITIONS.key ->
sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS).toString))

override val id: UUID = UUID.fromString(streamMetadata.id)

@@ -256,6 +259,15 @@
updateStatusMessage("Initializing sources")
// force initialization of the logical plan so that the sources can be created
logicalPlan

// Isolated spark session to run the batches with.
val sparkSessionToRunBatches = sparkSession.cloneSession()
// Adaptive execution can change num shuffle partitions, disallow
sparkSessionToRunBatches.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")
offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0,
Review comment (Member): nit: remove line.

Review comment (Member): Yeah, this should be kept. It should use the conf in the cloned session.

conf = Map(SQLConf.SHUFFLE_PARTITIONS.key ->
sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)))

if (state.compareAndSet(INITIALIZING, ACTIVE)) {
// Unblock `awaitInitialization`
initializationLatch.countDown()
@@ -268,15 +280,15 @@
reportTimeTaken("triggerExecution") {
if (currentBatchId < 0) {
// We'll do this initialization only once
populateStartOffsets()
populateStartOffsets(sparkSessionToRunBatches)
logDebug(s"Stream running from $committedOffsets to $availableOffsets")
} else {
constructNextBatch()
}
if (dataAvailable) {
currentStatus = currentStatus.copy(isDataAvailable = true)
updateStatusMessage("Processing new data")
runBatch()
runBatch(sparkSessionToRunBatches)
}
}

@@ -381,13 +393,32 @@
* - committedOffsets
* - availableOffsets
*/
private def populateStartOffsets(): Unit = {
private def populateStartOffsets(sparkSessionToRunBatches: SparkSession): Unit = {
offsetLog.getLatest() match {
case Some((batchId, nextOffsets)) =>
logInfo(s"Resuming streaming query, starting with batch $batchId")
currentBatchId = batchId
availableOffsets = nextOffsets.toStreamProgress(sources)
offsetSeqMetadata = nextOffsets.metadata.getOrElse(OffsetSeqMetadata())

// update offset metadata
nextOffsets.metadata.foreach { metadata =>
val shufflePartitionsSparkSession: Int =
sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS)
val shufflePartitionsToUse = metadata.conf.getOrElse(SQLConf.SHUFFLE_PARTITIONS.key, {
// For backward compatibility, if # partitions was not recorded in the offset log,
// then ensure it is not missing. The new value is picked up from the conf.
logWarning("Number of shuffle partitions from previous run not found in checkpoint. "
+ s"Using the value from the conf, $shufflePartitionsSparkSession partitions.")
shufflePartitionsSparkSession
})
offsetSeqMetadata = OffsetSeqMetadata(
metadata.batchWatermarkMs, metadata.batchTimestampMs,
metadata.conf + (SQLConf.SHUFFLE_PARTITIONS.key -> shufflePartitionsToUse.toString))
// Update conf with correct number of shuffle partitions
sparkSessionToRunBatches.conf.set(
SQLConf.SHUFFLE_PARTITIONS.key, shufflePartitionsToUse.toString)
}

logDebug(s"Found possibly unprocessed offsets $availableOffsets " +
s"at batch timestamp ${offsetSeqMetadata.batchTimestampMs}")

Expand Down Expand Up @@ -444,25 +475,27 @@ class StreamExecution(
}
}
if (hasNewData) {
// Current batch timestamp in milliseconds
offsetSeqMetadata.batchTimestampMs = triggerClock.getTimeMillis()
var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs
// Update the eventTime watermark if we find one in the plan.
if (lastExecution != null) {
lastExecution.executedPlan.collect {
case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 =>
logDebug(s"Observed event time stats: ${e.eventTimeStats.value}")
e.eventTimeStats.value.max - e.delayMs
}.headOption.foreach { newWatermarkMs =>
if (newWatermarkMs > offsetSeqMetadata.batchWatermarkMs) {
if (newWatermarkMs > batchWatermarkMs) {
logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms")
offsetSeqMetadata.batchWatermarkMs = newWatermarkMs
batchWatermarkMs = newWatermarkMs
} else {
logDebug(
s"Event time didn't move: $newWatermarkMs < " +
s"${offsetSeqMetadata.batchWatermarkMs}")
s"$batchWatermarkMs")
}
}
}
offsetSeqMetadata = offsetSeqMetadata.copy(
batchWatermarkMs = batchWatermarkMs,
batchTimestampMs = triggerClock.getTimeMillis()) // Current batch timestamp in milliseconds

updateStatusMessage("Writing offsets to log")
reportTimeTaken("walCommit") {
@@ -505,8 +538,9 @@

/**
* Processes any data available between `availableOffsets` and `committedOffsets`.
* @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with.
*/
private def runBatch(): Unit = {
private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = {
// Request unprocessed data from all sources.
newData = reportTimeTaken("getBatch") {
availableOffsets.flatMap {
@@ -551,7 +585,7 @@

reportTimeTaken("queryPlanning") {
lastExecution = new IncrementalExecution(
sparkSession,
sparkSessionToRunBatch,
triggerLogicalPlan,
outputMode,
checkpointFile("state"),
@@ -561,7 +595,7 @@
}

val nextBatch =
new Dataset(sparkSession, lastExecution, RowEncoder(lastExecution.analyzed.schema))
new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema))

reportTimeTaken("addBatch") {
sink.addBatch(currentBatchId, nextBatch)
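
Pinning the number of shuffle partitions (and keeping adaptive execution off) matters because a stateful operator's state store is keyed by shuffle partition: state written for a key under one partition count would be looked up in a different partition if the count changed. A toy illustration of the failure mode, not Spark's actual Murmur3-based HashPartitioning:

    object WhyPartitionCountMustBeStable {
      // Any hash-modulo scheme has the same property; Spark's real partitioner differs in detail.
      def partitionFor(key: String, numPartitions: Int): Int =
        math.abs(key.hashCode % numPartitions)

      def main(args: Array[String]): Unit = {
        val key = "user-42"
        // The same key generally lands in different partitions for different counts,
        // so state written under 10 partitions would not be found under 15.
        println(s"10 partitions -> ${partitionFor(key, 10)}, 15 partitions -> ${partitionFor(key, 15)}")
      }
    }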
@@ -25,6 +25,7 @@ import scala.collection.mutable
import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.{Experimental, InterfaceStability}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
import org.apache.spark.sql.execution.streaming._
@@ -40,7 +41,7 @@ import org.apache.spark.util.{Clock, SystemClock, Utils}
*/
@Experimental
@InterfaceStability.Evolving
class StreamingQueryManager private[sql] (sparkSession: SparkSession) {
class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging {

private[sql] val stateStoreCoordinator =
StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env)
@@ -234,9 +235,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) {
}

if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) {
throw new AnalysisException(
s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} " +
"is not supported in streaming DataFrames/Datasets")
logWarning(s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} " +
"is not supported in streaming DataFrames/Datasets and will be disabled.")
}

new StreamingQueryWrapper(new StreamExecution(
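
The user-visible effect of this change is that enabling adaptive execution no longer fails query startup; it is only warned about and then disabled on the isolated session that runs the batches. A self-contained sketch of the new behavior (not part of the patch; the memory source and console sink are just for illustration):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.execution.streaming.MemoryStream

    object AdaptiveExecutionWarningSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]")
          .appName("adaptive-warning-sketch")
          .getOrCreate()
        import spark.implicits._
        implicit val sqlContext = spark.sqlContext

        // SQLConf.ADAPTIVE_EXECUTION_ENABLED.key
        spark.conf.set("spark.sql.adaptive.enabled", "true")

        val input = MemoryStream[Int]
        input.addData(1, 2, 3)

        // Before this patch, start() threw an AnalysisException when adaptive execution
        // was enabled; with this patch it only logs a warning, and the cloned session
        // that actually runs the batches has the flag turned off.
        val query = input.toDF().writeStream.format("console").start()
        query.processAllAvailable()
        query.stop()
        spark.stop()
      }
    }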
@@ -0,0 +1 @@
{"id":"dddc5e7f-1e71-454c-8362-de184444fb5a"}
@@ -0,0 +1,3 @@
v1
{"batchWatermarkMs":0,"batchTimestampMs":1489180207737}
0
@@ -0,0 +1,3 @@
v1
{"batchWatermarkMs":0,"batchTimestampMs":1489180209261}
2
Binary files not shown.
@@ -21,6 +21,7 @@ import java.io.File

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.stringToFile
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext

class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext {
@@ -29,12 +30,37 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext {
case class StringOffset(override val json: String) extends Offset

test("OffsetSeqMetadata - deserialization") {
assert(OffsetSeqMetadata(0, 0) === OffsetSeqMetadata("""{}"""))
assert(OffsetSeqMetadata(1, 0) === OffsetSeqMetadata("""{"batchWatermarkMs":1}"""))
assert(OffsetSeqMetadata(0, 2) === OffsetSeqMetadata("""{"batchTimestampMs":2}"""))
assert(
OffsetSeqMetadata(1, 2) ===
OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}"""))
val key = SQLConf.SHUFFLE_PARTITIONS.key

def getConfWith(shufflePartitions: Int): Map[String, String] = {
Map(key -> shufflePartitions.toString)
}

// None set
assert(OffsetSeqMetadata(0, 0, Map.empty) === OffsetSeqMetadata("""{}"""))

// One set
assert(OffsetSeqMetadata(1, 0, Map.empty) === OffsetSeqMetadata("""{"batchWatermarkMs":1}"""))
assert(OffsetSeqMetadata(0, 2, Map.empty) === OffsetSeqMetadata("""{"batchTimestampMs":2}"""))
assert(OffsetSeqMetadata(0, 0, getConfWith(shufflePartitions = 2)) ===
OffsetSeqMetadata(s"""{"conf": {"$key":2}}"""))

// Two set
assert(OffsetSeqMetadata(1, 2, Map.empty) ===
OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}"""))
assert(OffsetSeqMetadata(1, 0, getConfWith(shufflePartitions = 3)) ===
OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"conf": {"$key":3}}"""))
assert(OffsetSeqMetadata(0, 2, getConfWith(shufflePartitions = 3)) ===
OffsetSeqMetadata(s"""{"batchTimestampMs":2,"conf": {"$key":3}}"""))

// All set
assert(OffsetSeqMetadata(1, 2, getConfWith(shufflePartitions = 3)) ===
OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"batchTimestampMs":2,"conf": {"$key":3}}"""))
Review comment (Member): nit: could you add a test to verify that unknown fields don't break the serialization? Such as

    assert(OffsetSeqMetadata(1, 2, getConfWith(shufflePartitions = 3)) ===
      OffsetSeqMetadata(
        s"""{"batchWatermarkMs":1,"batchTimestampMs":2,"conf": {"$key":3}},"unknown":1"""))

Review comment (Contributor Author): Added.


// Drop unknown fields
assert(OffsetSeqMetadata(1, 2, getConfWith(shufflePartitions = 3)) ===
OffsetSeqMetadata(
s"""{"batchWatermarkMs":1,"batchTimestampMs":2,"conf": {"$key":3}},"unknown":1"""))
}

test("OffsetSeqLog - serialization - deserialization") {
@@ -17,17 +17,20 @@

package org.apache.spark.sql.streaming

import java.io.{InterruptedIOException, IOException}
import java.io.{File, InterruptedIOException, IOException}
import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit}

import scala.reflect.ClassTag
import scala.util.control.ControlThrowable

import org.apache.commons.io.FileUtils

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

@@ -389,6 +392,102 @@ class StreamSuite extends StreamTest {
query.stop()
assert(query.exception.isEmpty)
}

test("SPARK-19873: streaming aggregation with change in number of partitions") {
val inputData = MemoryStream[(Int, Int)]
val agg = inputData.toDS().groupBy("_1").count()

testStream(agg, OutputMode.Complete())(
AddData(inputData, (1, 0), (2, 0)),
StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "2")),
CheckAnswer((1, 1), (2, 1)),
StopStream,
AddData(inputData, (3, 0), (2, 0)),
StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "5")),
CheckAnswer((1, 1), (2, 2), (3, 1)),
StopStream,
AddData(inputData, (3, 0), (1, 0)),
StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "1")),
CheckAnswer((1, 2), (2, 2), (3, 2)))
}

test("recover from a Spark v2.1 checkpoint") {
var inputData: MemoryStream[Int] = null
var query: DataStreamWriter[Row] = null

def prepareMemoryStream(): Unit = {
inputData = MemoryStream[Int]
inputData.addData(1, 2, 3, 4)
inputData.addData(3, 4, 5, 6)
inputData.addData(5, 6, 7, 8)

query = inputData
.toDF()
.groupBy($"value")
.agg(count("*"))
.writeStream
.outputMode("complete")
.format("memory")
}

// Get an existing checkpoint generated by Spark v2.1.
// v2.1 does not record # shuffle partitions in the offset metadata.
val resourceUri =
this.getClass.getResource("/structured-streaming/checkpoint-version-2.1.0").toURI
val checkpointDir = new File(resourceUri)

// 1 - Test if recovery from the checkpoint is successful.
prepareMemoryStream()
withTempDir { dir =>
// Copy the checkpoint to a temp dir to prevent changes to the original.
// Not doing this will lead to the test passing on the first run, but fail subsequent runs.
FileUtils.copyDirectory(checkpointDir, dir)

// Checkpoint data was generated by a query with 10 shuffle partitions.
// In order to test reading from the checkpoint, the checkpoint must have two or more batches,
// since the last batch may be rerun.
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
var streamingQuery: StreamingQuery = null
try {
streamingQuery =
query.queryName("counts").option("checkpointLocation", dir.getCanonicalPath).start()
streamingQuery.processAllAvailable()
inputData.addData(9)
streamingQuery.processAllAvailable()

QueryTest.checkAnswer(spark.table("counts").toDF(),
Row("1", 1) :: Row("2", 1) :: Row("3", 2) :: Row("4", 2) ::
Row("5", 2) :: Row("6", 2) :: Row("7", 1) :: Row("8", 1) :: Row("9", 1) :: Nil)
} finally {
if (streamingQuery ne null) {
streamingQuery.stop()
}
}
}
}

// 2 - Check recovery with wrong num shuffle partitions
prepareMemoryStream()
withTempDir { dir =>
FileUtils.copyDirectory(checkpointDir, dir)

// Since the number of partitions is greater than 10, should throw exception.
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "15") {
var streamingQuery: StreamingQuery = null
try {
intercept[StreamingQueryException] {
Review comment (Contributor): what is the error message?

streamingQuery =
query.queryName("badQuery").option("checkpointLocation", dir.getCanonicalPath).start()
streamingQuery.processAllAvailable()
}
} finally {
if (streamingQuery ne null) {
streamingQuery.stop()
}
}
}
}
}
}

abstract class FakeSource extends StreamSourceProvider {