
Commit 9177ea3

jhu-chang authored and zsxwing committed
[SPARK-11749][STREAMING] Duplicate creation of the RDD in file stream when recovering from checkpoint data
Add a transient flag `DStream.restoredFromCheckpointData` to control restore processing in DStream and avoid duplicate work: `DStream.restoreCheckpointData` checks the flag first, and the restore runs only when it is `false`.

Author: jhu-chang <[email protected]>

Closes #9765 from jhu-chang/SPARK-11749.

(cherry picked from commit f4346f6)

Signed-off-by: Shixiong Zhu <[email protected]>
1 parent 4df1dd4 · commit 9177ea3
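Why the flag is `@transient`: Java serialization skips transient fields, so a DStream deserialized from checkpoint data starts with `restoredFromCheckpointData = false`, and the first `restoreCheckpointData()` after recovery still runs; within one in-memory instance, later calls become no-ops. Below is a minimal standalone sketch of that pattern, assuming nothing from Spark (the `Node` and `GuardDemo` names are illustrative):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// A serializable node whose restore() runs at most once per in-memory instance.
class Node(val name: String) extends Serializable {
  // Skipped by serialization: a deserialized copy starts with restored = false.
  @transient private var restored = false

  def restore(): Unit = {
    if (!restored) {              // the commit's guard, in miniature
      println(s"restoring $name")
      restored = true
    }
  }
}

object GuardDemo extends App {
  val n = new Node("input")
  n.restore()                     // prints "restoring input"
  n.restore()                     // no-op: flag already set

  // Round-trip through Java serialization, as checkpoint recovery does.
  val buf = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(buf)
  oos.writeObject(n)
  oos.close()
  val copy = new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
    .readObject().asInstanceOf[Node]
  copy.restore()                  // runs again: the transient flag was reset
}
```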

2 files changed: +62 -9 lines

streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala

Lines changed: 10 additions & 5 deletions
@@ -97,6 +97,8 @@ abstract class DStream[T: ClassTag] (
   private[streaming] val mustCheckpoint = false
   private[streaming] var checkpointDuration: Duration = null
   private[streaming] val checkpointData = new DStreamCheckpointData(this)
+  @transient
+  private var restoredFromCheckpointData = false
 
   // Reference to whole DStream graph
   private[streaming] var graph: DStreamGraph = null
@@ -507,11 +509,14 @@ abstract class DStream[T: ClassTag] (
    * override the updateCheckpointData() method would also need to override this method.
    */
   private[streaming] def restoreCheckpointData() {
-    // Create RDDs from the checkpoint data
-    logInfo("Restoring checkpoint data")
-    checkpointData.restore()
-    dependencies.foreach(_.restoreCheckpointData())
-    logInfo("Restored checkpoint data")
+    if (!restoredFromCheckpointData) {
+      // Create RDDs from the checkpoint data
+      logInfo("Restoring checkpoint data")
+      checkpointData.restore()
+      dependencies.foreach(_.restoreCheckpointData())
+      restoredFromCheckpointData = true
+      logInfo("Restored checkpoint data")
+    }
   }
 
   @throws(classOf[IOException])
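The guard matters because `restoreCheckpointData()` recurses into `dependencies`, and a streaming graph can reach the same upstream DStream through several output streams; each path then re-ran the restore, which appears to be how the duplicate file-stream RDDs of SPARK-11749 arose. A sketch of that diamond shape, again with illustrative names (`StreamNode`, `SharedDepDemo`), not Spark code:

```scala
// Two "output" nodes share one upstream dependency; the guard keeps the
// shared node from being restored once per path.
class StreamNode(val name: String, deps: Seq[StreamNode]) {
  private var restored = false
  var restoreCount = 0                   // observable for the demo

  def restore(): Unit = {
    if (!restored) {                     // remove this check and the count below becomes 2
      restoreCount += 1
      deps.foreach(_.restore())
      restored = true
    }
  }
}

object SharedDepDemo extends App {
  val input = new StreamNode("input", Nil)
  val out1 = new StreamNode("out1", Seq(input))
  val out2 = new StreamNode("out2", Seq(input))
  Seq(out1, out2).foreach(_.restore())   // restore every output stream, as recovery does
  assert(input.restoreCount == 1)        // restored once despite two incoming paths
}
```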

streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala

Lines changed: 52 additions & 4 deletions
@@ -17,7 +17,7 @@
 
 package org.apache.spark.streaming
 
-import java.io.{ObjectOutputStream, ByteArrayOutputStream, ByteArrayInputStream, File}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, ObjectOutputStream}
 
 import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
 import scala.reflect.ClassTag
@@ -34,9 +34,30 @@ import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils}
-import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.dstream._
 import org.apache.spark.streaming.scheduler._
-import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils}
+import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, Utils}
+
+/**
+ * An input stream that records how many times restore() is invoked.
+ */
+private[streaming]
+class CheckpointInputDStream(ssc_ : StreamingContext) extends InputDStream[Int](ssc_) {
+  protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData
+  override def start(): Unit = { }
+  override def stop(): Unit = { }
+  override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.makeRDD(Seq(1)))
+  private[streaming]
+  class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) {
+    @transient
+    var restoredTimes = 0
+    override def restore() {
+      restoredTimes += 1
+      super.restore()
+    }
+  }
+}
 
 /**
  * A trait of that can be mixed in to get methods for testing DStream operations under
@@ -110,7 +131,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite =>
     new StreamingContext(SparkContext.getOrCreate(conf), batchDuration)
   }
 
-  private def generateOutput[V: ClassTag](
+  protected def generateOutput[V: ClassTag](
       ssc: StreamingContext,
       targetBatchTime: Time,
      checkpointDir: String,
@@ -715,6 +736,33 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester {
     }
   }
 
+  test("DStreamCheckpointData.restore invoking times") {
+    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+      ssc.checkpoint(checkpointDir)
+      val inputDStream = new CheckpointInputDStream(ssc)
+      val checkpointData = inputDStream.checkpointData
+      val mappedDStream = inputDStream.map(_ + 100)
+      val outputStream = new TestOutputStreamWithPartitions(mappedDStream)
+      outputStream.register()
+      // register two more output operations on the same stream
+      mappedDStream.foreachRDD(rdd => rdd.count())
+      mappedDStream.foreachRDD(rdd => rdd.count())
+      assert(checkpointData.restoredTimes === 0)
+      val batchDurationMillis = ssc.progressListener.batchDuration
+      generateOutput(ssc, Time(batchDurationMillis * 3), checkpointDir, stopSparkContext = true)
+      assert(checkpointData.restoredTimes === 0)
+    }
+    logInfo("*********** RESTARTING ************")
+    withStreamingContext(new StreamingContext(checkpointDir)) { ssc =>
+      val checkpointData =
+        ssc.graph.getInputStreams().head.asInstanceOf[CheckpointInputDStream].checkpointData
+      assert(checkpointData.restoredTimes === 1)
+      ssc.start()
+      ssc.stop()
+      assert(checkpointData.restoredTimes === 1)
+    }
+  }
+
   // This tests whether spark can deserialize array object
   // refer to SPARK-5569
   test("recovery from checkpoint contains array object") {
