@@ -17,7 +17,7 @@
 
 package org.apache.spark.streaming
 
-import java.io.{ObjectOutputStream, ByteArrayOutputStream, ByteArrayInputStream, File}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, ObjectOutputStream}
 
 import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
 import scala.reflect.ClassTag
@@ -34,9 +34,30 @@ import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils}
-import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.dstream._
 import org.apache.spark.streaming.scheduler._
-import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils}
+import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, Utils}
+
+/**
+ * An input stream that records how many times restore() is invoked.
+ */
+private[streaming]
+class CheckpointInputDStream(ssc_ : StreamingContext) extends InputDStream[Int](ssc_) {
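+  // Expose a custom checkpoint data object so the test below can count restore() invocations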
+  protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData
+  override def start(): Unit = { }
+  override def stop(): Unit = { }
+  override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.makeRDD(Seq(1)))
+  private[streaming]
+  class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) {
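+    // @transient: the counter is not serialized into the checkpoint, so it starts at 0 again
+    // after the DStream graph is deserialized during recovery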
+    @transient
+    var restoredTimes = 0
+    override def restore() {
+      restoredTimes += 1
+      super.restore()
+    }
+  }
+}
 
 /**
  * A trait that can be mixed in to get methods for testing DStream operations under
@@ -110,7 +131,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite =>
     new StreamingContext(SparkContext.getOrCreate(conf), batchDuration)
   }
 
-  private def generateOutput[V: ClassTag](
+  protected def generateOutput[V: ClassTag](
       ssc: StreamingContext,
       targetBatchTime: Time,
       checkpointDir: String,
@@ -715,6 +736,33 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester {
     }
   }
 
+  test("DStreamCheckpointData.restore invoking times") {
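+    // Phase 1: run the job with checkpointing enabled; restore() should never fire here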
+    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+      ssc.checkpoint(checkpointDir)
+      val inputDStream = new CheckpointInputDStream(ssc)
+      val checkpointData = inputDStream.checkpointData
+      val mappedDStream = inputDStream.map(_ + 100)
+      val outputStream = new TestOutputStreamWithPartitions(mappedDStream)
+      outputStream.register()
+      // register two more output operations
+      mappedDStream.foreachRDD(rdd => rdd.count())
+      mappedDStream.foreachRDD(rdd => rdd.count())
+      assert(checkpointData.restoredTimes === 0)
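+      // generateOutput (from DStreamCheckpointTester above) advances the job to the
+      // target batch time, checkpointing along the way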
+      val batchDurationMillis = ssc.progressListener.batchDuration
+      generateOutput(ssc, Time(batchDurationMillis * 3), checkpointDir, stopSparkContext = true)
+      assert(checkpointData.restoredTimes === 0)
+    }
+    logInfo("*********** RESTARTING ************")
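+    // Phase 2: recreate the context from the checkpoint directory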
+    withStreamingContext(new StreamingContext(checkpointDir)) { ssc =>
+      val checkpointData =
+        ssc.graph.getInputStreams().head.asInstanceOf[CheckpointInputDStream].checkpointData
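+      // Recovery from the checkpoint should have invoked restore() exactly once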
+      assert(checkpointData.restoredTimes === 1)
+      ssc.start()
+      ssc.stop()
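+      // Starting and stopping the recovered job must not trigger any further restore() calls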
+      assert(checkpointData.restoredTimes === 1)
+    }
+  }
+
   // This tests whether spark can deserialize array object
   // refer to SPARK-5569
   test("recovery from checkpoint contains array object") {