diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index b8414b5d099c..832a497b2c40 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -2350,6 +2350,16 @@ object SparkContext extends Logging {
     }
   }
 
+  private[spark] def getActiveContext(): Option[SparkContext] = {
+    SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+      Option(activeContext.get())
+    }
+  }
+
+  private[spark] def stopActiveContext(): Unit = {
+    getActiveContext().foreach(_.stop())
+  }
+
   /**
    * Called at the end of the SparkContext constructor to ensure that no other SparkContext has
    * raced with this constructor and started.
diff --git a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
index cfedb5a042a3..469d206bd5c3 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
@@ -18,6 +18,7 @@
 package org.apache.spark.streaming;
 
 import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext$;
 import org.apache.spark.streaming.api.java.JavaStreamingContext;
 import org.junit.After;
 import org.junit.Before;
@@ -28,6 +29,7 @@ public abstract class LocalJavaStreamingContext {
 
     @Before
     public void setUp() {
+        SparkContext$.MODULE$.stopActiveContext();
        SparkConf conf = new SparkConf()
            .setMaster("local[2]")
            .setAppName("test")
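
The two helpers above give tests a lock-protected way to look up and tear down whatever SparkContext is currently active, instead of the old trick of calling `SparkContext.getOrCreate` with a throwaway conf just to stop it (see the FailureSuite hunk below). A minimal sketch of the intended call pattern; the suite name is hypothetical:

```scala
import org.apache.spark.{SparkContext, SparkFunSuite}

// Hypothetical suite, for illustration only.
class MyCleanSlateSuite extends SparkFunSuite {
  override def beforeAll(): Unit = {
    super.beforeAll()
    // Stop a SparkContext leaked by an earlier suite before this one starts.
    // SparkContext.stopActiveContext() is the equivalent one-liner.
    SparkContext.getActiveContext().foreach(_.stop())
  }
}
```

Both methods are `private[spark]`, so they are visible to the streaming test suites but not to user code. From Java the Scala companion object has to be reached through its `MODULE$` singleton field, which is why `LocalJavaStreamingContext` calls `SparkContext$.MODULE$.stopActiveContext()`.
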
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 4e702bbb9206..5a473aa3fcd3 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -620,7 +620,7 @@ class BasicOperationsSuite extends TestSuiteBase {
   }
 
   test("slice") {
-    withStreamingContext(new StreamingContext(conf, Seconds(1))) { ssc =>
+    withStreamingContext(Seconds(1)) { ssc =>
       val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
       val stream = new TestInputStream[Int](ssc, input, 2)
       stream.foreachRDD(_ => {}) // Dummy output stream
@@ -637,7 +637,7 @@ class BasicOperationsSuite extends TestSuiteBase {
     }
   }
 
   test("slice - has not been initialized") {
-    withStreamingContext(new StreamingContext(conf, Seconds(1))) { ssc =>
+    withStreamingContext(Seconds(1)) { ssc =>
       val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
       val stream = new TestInputStream[Int](ssc, input, 2)
       val thrown = intercept[SparkException] {
@@ -657,7 +657,7 @@ class BasicOperationsSuite extends TestSuiteBase {
         .window(Seconds(4), Seconds(2))
     }
 
-    val operatedStream = runCleanupTest(conf, operation _,
+    val operatedStream = runCleanupTest(operation _,
       numExpectedOutput = cleanupTestInput.size / 2, rememberDuration = Seconds(3))
     val windowedStream2 = operatedStream.asInstanceOf[WindowedDStream[_]]
     val windowedStream1 = windowedStream2.dependencies.head.asInstanceOf[WindowedDStream[_]]
@@ -694,7 +694,7 @@ class BasicOperationsSuite extends TestSuiteBase {
       Some(values.sum + state.getOrElse(0))
     }
     val stateStream = runCleanupTest(
-      conf, _.map(_ -> 1).updateStateByKey(updateFunc).checkpoint(Seconds(3)))
+      _.map(_ -> 1).updateStateByKey(updateFunc).checkpoint(Seconds(3)))
 
     assert(stateStream.rememberDuration === stateStream.checkpointDuration * 2)
     assert(stateStream.generatedRDDs.contains(Time(10000)))
@@ -705,7 +705,7 @@ class BasicOperationsSuite extends TestSuiteBase {
 
     // Actually receive data over through receiver to create BlockRDDs
     withTestServer(new TestServer()) { testServer =>
-      withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+      withStreamingContext { ssc =>
         testServer.start()
 
         val batchCounter = new BatchCounter(ssc)
@@ -781,7 +781,6 @@ class BasicOperationsSuite extends TestSuiteBase {
 
   /** Test cleanup of RDDs in DStream metadata */
   def runCleanupTest[T: ClassTag](
-      conf2: SparkConf,
       operation: DStream[Int] => DStream[T],
       numExpectedOutput: Int = cleanupTestInput.size,
       rememberDuration: Duration = null
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index a1e9d1e02380..1099750b5148 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -211,6 +211,8 @@ trait DStreamCheckpointTester { self: SparkFunSuite =>
 class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
   with ResetSystemProperties {
 
+  override val reuseContext: Boolean = false
+
   var ssc: StreamingContext = null
 
   override def batchDuration: Duration = Milliseconds(500)
@@ -238,8 +240,6 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
     assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second")
 
-    conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock")
-
     val stateStreamCheckpointInterval = Seconds(1)
     val fs = FileSystem.getLocal(new Configuration())
     // this ensure checkpointing occurs at least once
@@ -571,7 +571,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
   }
 
   test("recovery maintains rate controller") {
-    ssc = new StreamingContext(conf, batchDuration)
+    ssc = new StreamingContext(sc, batchDuration)
     ssc.checkpoint(checkpointDir)
 
     val dstream = new RateTestInputDStream(ssc) {
@@ -635,7 +635,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
     try {
       // This is a var because it's re-assigned when we restart from a checkpoint
       var clock: ManualClock = null
-      withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+      withStreamingContext(batchDuration) { ssc =>
         ssc.checkpoint(checkpointDir)
         clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
         val batchCounter = new BatchCounter(ssc)
@@ -760,7 +760,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
   }
 
   test("DStreamCheckpointData.restore invoking times") {
-    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+    withStreamingContext { ssc =>
       ssc.checkpoint(checkpointDir)
       val inputDStream = new CheckpointInputDStream(ssc)
       val checkpointData = inputDStream.checkpointData
@@ -822,7 +822,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
     val jobGenerator = mock(classOf[JobGenerator])
     val checkpointDir = Utils.createTempDir().toString
     val checkpointWriter =
-      new CheckpointWriter(jobGenerator, conf, checkpointDir, new Configuration())
+      new CheckpointWriter(jobGenerator, sc.conf, checkpointDir, new Configuration())
     val bytes1 = Array.fill[Byte](10)(1)
     new checkpointWriter.CheckpointWriteHandler(
       Time(2000), bytes1, clearCheckpointDataLater = false).run()
@@ -869,6 +869,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
     // Therefore SPARK-6847 introduces "spark.checkpoint.checkpointAllMarked" to force checkpointing
     // all marked RDDs in the DAG to resolve this issue. (For the previous example, it will break
     // connections between layer 2 and layer 3)
+    stopActiveContext()
     ssc = new StreamingContext(master, framework, batchDuration)
     val batchCounter = new BatchCounter(ssc)
     ssc.checkpoint(checkpointDir)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
index 2ab600ab817e..998725136798 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
@@ -19,9 +19,7 @@ package org.apache.spark.streaming
 
 import java.io.NotSerializableException
 
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark.{HashPartitioner, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark.{HashPartitioner, SparkException}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.streaming.dstream.DStream
 import org.apache.spark.util.ReturnStatementInClosureException
@@ -29,18 +27,17 @@ import org.apache.spark.util.ReturnStatementInClosureException
 /**
  * Test that closures passed to DStream operations are actually cleaned.
  */
-class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll {
-  private var ssc: StreamingContext = null
+class DStreamClosureSuite extends ReuseableSparkContext {
+  private var ssc: StreamingContext = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    val sc = new SparkContext("local", "test")
     ssc = new StreamingContext(sc, Seconds(1))
   }
 
   override def afterAll(): Unit = {
     try {
-      ssc.stop(stopSparkContext = true)
+      ssc.stop()
       ssc = null
     } finally {
       super.afterAll()
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala
index 94f1bcebc3a3..946e12b4ae2c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala
@@ -30,20 +30,23 @@ import org.apache.spark.util.ManualClock
 /**
  * Tests whether scope information is passed from DStream operations to RDDs correctly.
  */
-class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll {
-  private var ssc: StreamingContext = null
-  private val batchDuration: Duration = Seconds(1)
+class DStreamScopeSuite extends ReuseableSparkContext {
+  private var ssc: StreamingContext = _
+
+  // Configurations to add to a new or existing Spark context.
+  override def extraSparkConf: Map[String, String] = {
+    // Use a manual clock
+    super.extraSparkConf ++ Map("spark.streaming.clock" -> "org.apache.spark.util.ManualClock")
+  }
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    val conf = new SparkConf().setMaster("local").setAppName("test")
-    conf.set("spark.streaming.clock", classOf[ManualClock].getName())
-    ssc = new StreamingContext(new SparkContext(conf), batchDuration)
+    ssc = new StreamingContext(sc, Seconds(1))
   }
 
   override def afterAll(): Unit = {
     try {
-      ssc.stop(stopSparkContext = true)
+      ssc.stop()
     } finally {
       super.afterAll()
     }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
index 19ceb748e07f..1cc80fc5aca7 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
@@ -35,6 +35,11 @@ class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging {
   private val numBatches = 30
   private var directory: File = null
 
+  override protected def beforeAll(): Unit = {
+    super.beforeAll()
+    SparkContext.getActiveContext().foreach(_.stop())
+  }
+
   before {
     directory = Utils.createTempDir()
   }
@@ -46,7 +51,7 @@ class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging {
     StreamingContext.getActive().foreach { _.stop() }
 
     // Stop SparkContext if active
-    SparkContext.getOrCreate(new SparkConf().setMaster("local").setAppName("bla")).stop()
+    SparkContext.getActiveContext().foreach(_.stop())
   }
 
   test("multiple failures with map") {
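
The two suites above show both halves of the new setup: FailureSuite stops any leaked context up front, and DStreamScopeSuite contributes per-suite settings through `extraSparkConf` instead of mutating a private `SparkConf`. The override pattern, roughly (hypothetical suite name; `ReuseableSparkContext` is defined in TestSuiteBase.scala further down in this patch):

```scala
import org.apache.spark.streaming.ReuseableSparkContext

// Hypothetical suite, for illustration only.
class MyManualClockSuite extends ReuseableSparkContext {
  // Merge with super.extraSparkConf rather than replacing it, so the
  // spark.streaming.stopSparkContextByDefault entry from the base trait survives.
  override def extraSparkConf: Map[String, String] =
    super.extraSparkConf ++ Map("spark.streaming.clock" -> "org.apache.spark.util.ManualClock")
}
```
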
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index 9ecfa48091a0..dc5f6d0f42b7 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -49,7 +49,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
       testServer.start()
 
       // Set up the streaming context and input streams
-      withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+      withStreamingContext { ssc =>
         ssc.addStreamingListener(ssc.progressListener)
 
         val input = Seq(1, 2, 3, 4, 5)
@@ -112,7 +112,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
     withTestServer(new TestServer()) { testServer =>
       testServer.start()
 
-      withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+      withStreamingContext { ssc =>
         ssc.addStreamingListener(ssc.progressListener)
 
         val batchCounter = new BatchCounter(ssc)
@@ -149,7 +149,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
     assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000)
 
     // Set up the streaming context and input streams
-    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+    withStreamingContext(batchDuration) { ssc =>
       val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
       // This `setTime` call ensures that the clock is past the creation time of `existingFile`
      clock.setTime(existingFile.lastModified + batchDuration.milliseconds)
@@ -213,7 +213,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
     val pathWithWildCard = testDir.toString + "/*/"
 
     // Set up the streaming context and input streams
-    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+    withStreamingContext(batchDuration) { ssc =>
       val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
       clock.setTime(existingFile.lastModified + batchDuration.milliseconds)
       val batchCounter = new BatchCounter(ssc)
@@ -270,7 +270,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
     def output: Iterable[Long] = outputQueue.asScala.flatMap(x => x)
 
     // set up the network stream using the test receiver
-    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+    withStreamingContext { ssc =>
       val networkStream = ssc.receiverStream[Int](testReceiver)
       val countStream = networkStream.count
@@ -305,7 +305,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
     def output: Iterable[Seq[String]] = outputQueue.asScala.filter(_.nonEmpty)
 
     // Set up the streaming context and input streams
-    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+    withStreamingContext { ssc =>
       val queue = new mutable.Queue[RDD[String]]()
       val queueStream = ssc.queueStream(queue, oneAtATime = true)
       val outputStream = new TestOutputStream(queueStream, outputQueue)
@@ -350,7 +350,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
     val expectedOutput = Seq(Seq("1", "2", "3"), Seq("4", "5"))
 
     // Set up the streaming context and input streams
-    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+    withStreamingContext { ssc =>
       val queue = new mutable.Queue[RDD[String]]()
       val queueStream = ssc.queueStream(queue, oneAtATime = false)
       val outputStream = new TestOutputStream(queueStream, outputQueue)
@@ -396,7 +396,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
   }
 
   test("test track the number of input stream") {
-    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+    withStreamingContext { ssc =>
 
       class TestInputDStream extends InputDStream[String](ssc) {
         def start() {}
@@ -434,7 +434,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
     assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000)
 
     // Set up the streaming context and input streams
-    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+    withStreamingContext(batchDuration) { ssc =>
       val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
       // This `setTime` call ensures that the clock is past the creation time of `existingFile`
       clock.setTime(existingFile.lastModified + batchDuration.milliseconds)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala
index 3b662ec1833a..cd2452053055 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala
@@ -23,20 +23,21 @@ import java.util.concurrent.ConcurrentLinkedQueue
 
 import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
 
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
 import org.scalatest.PrivateMethodTester._
 
-import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.streaming.dstream.{DStream, InternalMapWithStateDStream, MapWithStateDStream, MapWithStateDStreamImpl}
 import org.apache.spark.util.{ManualClock, Utils}
 
-class MapWithStateSuite extends SparkFunSuite
-  with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter {
+class MapWithStateSuite extends ReuseableSparkContext with DStreamCheckpointTester {
 
-  private var sc: SparkContext = null
   protected var checkpointDir: File = null
   protected val batchDuration = Seconds(1)
 
+  override def extraSparkConf: Map[String, String] = {
+    // Use a manual clock
+    super.extraSparkConf ++ Map("spark.streaming.clock" -> "org.apache.spark.util.ManualClock")
+  }
+
   before {
     StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
     checkpointDir = Utils.createTempDir("checkpoint")
@@ -49,23 +50,6 @@ class MapWithStateSuite extends ReuseableSparkContext with DStreamCheckpointTester
     }
   }
 
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    val conf = new SparkConf().setMaster("local").setAppName("MapWithStateSuite")
-    conf.set("spark.streaming.clock", classOf[ManualClock].getName())
-    sc = new SparkContext(conf)
-  }
-
-  override def afterAll(): Unit = {
-    try {
-      if (sc != null) {
-        sc.stop()
-      }
-    } finally {
-      super.afterAll()
-    }
-  }
-
   test("state - get, exists, update, remove, ") {
     var state: StateImpl[Int] = null
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index f2241936000a..ed48d4998d55 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -75,6 +75,11 @@ class ReceivedBlockHandlerSuite
   var storageLevel: StorageLevel = null
   var tempDirectory: File = null
 
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    SparkContext.getActiveContext().foreach(_.stop())
+  }
+
   before {
     rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr)
     conf.set("spark.driver.port", rpcEnv.address.port.toString)
@@ -107,6 +112,8 @@ class ReceivedBlockHandlerSuite
     rpcEnv.awaitTermination()
     rpcEnv = null
 
+    sc.stop()
+
     Utils.deleteRecursively(tempDirectory)
   }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala
index 0349e11224cf..901ea886deac 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala
@@ -21,7 +21,7 @@ import scala.util.Random
 
 import org.scalatest.BeforeAndAfterAll
 
-import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.SparkEnv
 import org.apache.spark.rdd.BlockRDD
 import org.apache.spark.storage.{StorageLevel, StreamBlockId}
 import org.apache.spark.streaming.dstream.ReceiverInputDStream
@@ -32,6 +32,8 @@ import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLog
 
 class ReceiverInputDStreamSuite extends TestSuiteBase with BeforeAndAfterAll {
 
+  override def master: String = "local[4]"
+
   override def afterAll(): Unit = {
     try {
       StreamingContext.getActive().foreach(_.stop())
@@ -123,16 +125,18 @@ class ReceiverInputDStreamSuite extends TestSuiteBase with BeforeAndAfterAll {
   }
 
   private def runTest(enableWAL: Boolean, body: ReceiverInputDStream[_] => Unit): Unit = {
-    val conf = new SparkConf()
-    conf.setMaster("local[4]").setAppName("ReceiverInputDStreamSuite")
-    conf.set(WriteAheadLogUtils.RECEIVER_WAL_ENABLE_CONF_KEY, enableWAL.toString)
-    require(WriteAheadLogUtils.enableReceiverLog(conf) === enableWAL)
-    val ssc = new StreamingContext(conf, Seconds(1))
-    val receiverStream = new ReceiverInputDStream[Int](ssc) {
-      override def getReceiver(): Receiver[Int] = null
-    }
-    withStreamingContext(ssc) { ssc =>
-      body(receiverStream)
+    try {
+      sc.conf.set(WriteAheadLogUtils.RECEIVER_WAL_ENABLE_CONF_KEY, enableWAL.toString)
+      assert(WriteAheadLogUtils.enableReceiverLog(sc.conf) === enableWAL)
+      val ssc = new StreamingContext(sc, Seconds(1))
+      val receiverStream = new ReceiverInputDStream[Int](ssc) {
+        override def getReceiver(): Receiver[Int] = null
+      }
+      withStreamingContext(ssc) { ssc =>
+        body(receiverStream)
+      }
+    } finally {
+      sc.conf.remove(WriteAheadLogUtils.RECEIVER_WAL_ENABLE_CONF_KEY)
     }
   }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
index 1b1e21f6e5ba..da33fb8c5758 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
@@ -28,7 +28,7 @@ import org.scalatest.concurrent.Eventually._
 import org.scalatest.concurrent.Timeouts
 import org.scalatest.time.SpanSugar._
 
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.storage.StreamBlockId
 import org.apache.spark.streaming.receiver._
@@ -194,6 +194,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
    * WALs should be cleaned later.
    */
   test("write ahead log - generating and cleaning") {
+    SparkContext.getActiveContext().foreach(_.stop())
     val sparkConf = new SparkConf()
       .setMaster("local[4]")  // must be at least 3 as we are going to start 2 receivers
       .setAppName(framework)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 5645996de5a6..81bd3e52d1b3 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -53,6 +53,10 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
   var sc: SparkContext = null
   var ssc: StreamingContext = null
 
+  before {
+    SparkContext.getActiveContext().foreach(_.stop())
+  }
+
   after {
     if (ssc != null) {
       ssc.stop()
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
index 0f957a1b5570..e949a5651b79 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
@@ -38,6 +38,8 @@ import org.apache.spark.streaming.scheduler._
 
 class StreamingListenerSuite extends TestSuiteBase with Matchers {
 
+  override val reuseContext: Boolean = false
+
   val input = (1 to 4).map(Seq(_)).toSeq
   val operation = (d: DStream[Int]) => d.map(x => x)
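
StreamingListenerSuite above (like UISeleniumSuite and CheckpointSuite elsewhere in this patch) opts out of sharing: a suite whose tests cannot tolerate state left behind in a reused SparkContext sets `reuseContext` to false and pays for a fresh context instead. Sketched, with a hypothetical suite name:

```scala
// Hypothetical suite, for illustration only.
class MyIsolatedSuite extends TestSuiteBase {
  // With reuse disabled, ReuseableSparkContext stops the active context in
  // beforeAll/afterAll and every call to sc builds a brand-new SparkContext.
  override val reuseContext: Boolean = false
}
```
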
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index fa975a146216..59b2c6de9742 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -20,6 +20,7 @@ package org.apache.spark.streaming
 import java.io.{IOException, ObjectInputStream}
 import java.util.concurrent.ConcurrentLinkedQueue
 
+import scala.collection.mutable
 import scala.collection.JavaConverters._
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
@@ -29,7 +30,7 @@ import org.scalatest.concurrent.Eventually.timeout
 import org.scalatest.concurrent.PatienceConfiguration
 import org.scalatest.time.{Seconds => ScalaTestSeconds, Span}
 
-import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.streaming.dstream.{DStream, ForEachDStream, InputDStream}
@@ -207,18 +208,83 @@ class BatchCounter(ssc: StreamingContext) {
   }
 }
 
-/**
- * This is the base trait for Spark Streaming testsuites. This provides basic functionality
- * to run user-defined set of input on user-defined stream operations, and verify the output.
- */
-trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging {
-
+trait ReuseableSparkContext extends SparkFunSuite with BeforeAndAfter with Logging {
   // Name of the framework for Spark context
   def framework: String = this.getClass.getSimpleName
 
   // Master for Spark context
   def master: String = "local[2]"
 
+  // Configurations to add to a new or existing Spark context.
+  def extraSparkConf: Map[String, String] = {
+    Map("spark.streaming.stopSparkContextByDefault" -> (!reuseContext).toString)
+  }
+
+  // Flag to indicate that the test should try to reuse a previously created context.
+  val reuseContext: Boolean = true
+
+  // Spark context used during the tests. Note that this can be (re)used by several tests.
+  private var _sc: SparkContext = _
+
+  // Store the configuration of the SparkContext before testing.
+  private val confBeforeTesting: mutable.Buffer[(String, String)] = mutable.Buffer.empty
+
+  // Get the existing Spark context or create a new one.
+  def sc: SparkContext = {
+    // Drop the existing context if we do not reuse.
+    if (!reuseContext) {
+      stopActiveContext()
+    }
+
+    if (_sc == null || _sc.isStopped) {
+      val conf = new SparkConf().setMaster(master).setAppName(framework)
+      _sc = SparkContext.getOrCreate(conf)
+
+      // Configure the context and make sure we store the old keys.
+      extraSparkConf.foreach { case (k, v) =>
+        if (_sc.conf.contains(k) && reuseContext) {
+          confBeforeTesting += k -> _sc.conf.get(k)
+        }
+        _sc.conf.set(k, v)
+      }
+    }
+    _sc
+  }
+
+  override protected def beforeAll(): Unit = {
+    super.beforeAll()
+    // Drop the existing context if we do not reuse.
+    if (!reuseContext) {
+      stopActiveContext()
+    }
+  }
+
+  protected override def afterAll(): Unit = {
+    if (_sc != null) {
+      if (reuseContext) {
+        extraSparkConf.foreach(kv => _sc.conf.remove(kv._1))
+        _sc.conf.setAll(confBeforeTesting)
+      }
+      _sc = null
+    }
+    if (!reuseContext) {
+      stopActiveContext()
+    }
+    super.afterAll()
+  }
+
+  protected def stopActiveContext(): Unit = {
+    SparkContext.getActiveContext().foreach(_.stop())
+    _sc = null
+  }
+}
+
+/**
+ * This is the base trait for Spark Streaming testsuites. This provides basic functionality
+ * to run user-defined set of input on user-defined stream operations, and verify the output.
+ */
+trait TestSuiteBase extends ReuseableSparkContext with BeforeAndAfter {
   // Batch duration
   def batchDuration: Duration = Seconds(1)
@@ -235,37 +301,25 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging {
   // Maximum time to wait before the test times out
   def maxWaitTimeMillis: Int = 10000
 
-  // Whether to use manual clock or not
-  def useManualClock: Boolean = true
+  // Configurations to add to a new or existing Spark context.
+  override def extraSparkConf: Map[String, String] = {
+    // Use a manual clock
+    super.extraSparkConf ++ Map("spark.streaming.clock" -> "org.apache.spark.util.ManualClock")
+  }
 
   // Whether to actually wait in real time before changing manual clock
   def actuallyWait: Boolean = false
 
-  // A SparkConf to use in tests. Can be modified before calling setupStreams to configure things.
-  val conf = new SparkConf()
-    .setMaster(master)
-    .setAppName(framework)
-
   // Timeout for use in ScalaTest `eventually` blocks
   val eventuallyTimeout: PatienceConfiguration.Timeout = timeout(Span(10, ScalaTestSeconds))
 
   // Default before function for any streaming test suite. Override this
   // if you want to add your stuff to "before" (i.e., don't call before { } )
-  def beforeFunction() {
-    if (useManualClock) {
-      logInfo("Using manual clock")
-      conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock")
-    } else {
-      logInfo("Using real clock")
-      conf.set("spark.streaming.clock", "org.apache.spark.util.SystemClock")
-    }
-  }
+  def beforeFunction() { }
 
   // Default after function for any streaming test suite. Override this
   // if you want to add your stuff to "after" (i.e., don't call after { } )
-  def afterFunction() {
-    System.clearProperty("spark.streaming.clock")
-  }
+  def afterFunction() { }
 
   before(beforeFunction)
   after(afterFunction)
@@ -279,7 +333,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging {
       block(ssc)
     } finally {
       try {
-        ssc.stop(stopSparkContext = true)
+        ssc.stop()
       } catch {
         case e: Exception =>
           logError("Error stopping StreamingContext", e)
       }
     }
   }
 
+  def withStreamingContext[R](duration: Duration)(block: StreamingContext => R): R = {
+    withStreamingContext(new StreamingContext(sc, duration))(block)
+  }
+
+  /**
+   * Run a block of code with a StreamingContext and automatically stop the context when the
+   * block completes or when an exception is thrown.
+   */
+  def withStreamingContext[R](block: StreamingContext => R): R = {
+    withStreamingContext(new StreamingContext(sc, batchDuration))(block)
+  }
+
   /**
    * Run a block of code with the given TestServer and automatically
    * stop the server when the block completes or when an exception is thrown.
@@ -314,7 +380,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging {
       numPartitions: Int = numInputPartitions
     ): StreamingContext = {
     // Create StreamingContext
-    val ssc = new StreamingContext(conf, batchDuration)
+    val ssc = new StreamingContext(sc, batchDuration)
     if (checkpointDir != null) {
       ssc.checkpoint(checkpointDir)
     }
@@ -338,7 +404,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging {
       operation: (DStream[U], DStream[V]) => DStream[W]
     ): StreamingContext = {
     // Create StreamingContext
-    val ssc = new StreamingContext(conf, batchDuration)
+    val ssc = new StreamingContext(sc, batchDuration)
     if (checkpointDir != null) {
       ssc.checkpoint(checkpointDir)
     }
@@ -425,7 +491,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging {
       Thread.sleep(100) // Give some time for the forgetting old RDDs to complete
     } finally {
-      ssc.stop(stopSparkContext = true)
+      ssc.stop()
     }
     output.asScala.toSeq
   }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
index e7cec999c219..3732be12152a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
@@ -37,6 +37,8 @@ class UISeleniumSuite
 
   implicit var webDriver: WebDriver = _
 
+  override val reuseContext: Boolean = false
+
   override def beforeAll(): Unit = {
     super.beforeAll()
     webDriver = new HtmlUnitDriver {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
index c7d085ec0799..face58f80cc7 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
@@ -145,16 +145,16 @@ class WindowOperationsSuite extends TestSuiteBase {
   )
 
   test("window - persistence level") {
-    val input = Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5))
-    val ssc = new StreamingContext(conf, batchDuration)
-    val inputStream = new TestInputStream[Int](ssc, input, 1)
-    val windowStream1 = inputStream.window(batchDuration * 2)
-    assert(windowStream1.storageLevel === StorageLevel.NONE)
-    assert(inputStream.storageLevel === StorageLevel.MEMORY_ONLY_SER)
-    windowStream1.persist(StorageLevel.MEMORY_ONLY)
-    assert(windowStream1.storageLevel === StorageLevel.NONE)
-    assert(inputStream.storageLevel === StorageLevel.MEMORY_ONLY)
-    ssc.stop()
+    withStreamingContext { ssc =>
+      val input = Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5))
+      val inputStream = new TestInputStream[Int](ssc, input, 1)
+      val windowStream1 = inputStream.window(batchDuration * 2)
+      assert(windowStream1.storageLevel === StorageLevel.NONE)
+      assert(inputStream.storageLevel === StorageLevel.MEMORY_ONLY_SER)
+      windowStream1.persist(StorageLevel.MEMORY_ONLY)
+      assert(windowStream1.storageLevel === StorageLevel.NONE)
+      assert(inputStream.storageLevel === StorageLevel.MEMORY_ONLY)
+    }
   }
 
   // Testing naive reduceByKeyAndWindow (without invertible function)
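
The two `withStreamingContext` overloads introduced above are what most of the call-site churn in this patch converts to, as in the WindowOperationsSuite hunk just shown: tests borrow the shared `sc`, state only the batch duration (or inherit the suite default), and let the helper stop the StreamingContext while keeping the SparkContext alive. Roughly:

```scala
// Hypothetical suite showing both overloads, for illustration only.
class MyWindowSuite extends TestSuiteBase {
  test("explicit batch duration") {
    withStreamingContext(Seconds(1)) { ssc =>
      // ssc wraps the shared sc; it is stopped when this block exits,
      // but the underlying SparkContext survives for the next test.
    }
  }

  test("suite default batch duration") {
    withStreamingContext { ssc =>
      // Same, using the suite-level batchDuration.
    }
  }
}
```
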
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala
index e8c814ba7184..bc678fe0dbb6 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala
@@ -22,32 +22,24 @@ import java.io.File
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
-import org.scalatest.BeforeAndAfterAll
-
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.streaming.{State, Time}
+import org.apache.spark.streaming.{ReuseableSparkContext, State, Time}
 import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
 import org.apache.spark.util.Utils
 
-class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll {
+class MapWithStateRDDSuite extends ReuseableSparkContext with RDDCheckpointTester {
 
-  private var sc: SparkContext = null
   private var checkpointDir: File = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    sc = new SparkContext(
-      new SparkConf().setMaster("local").setAppName("MapWithStateRDDSuite"))
     checkpointDir = Utils.createTempDir()
     sc.setCheckpointDir(checkpointDir.toString)
   }
 
   override def afterAll(): Unit = {
     try {
-      if (sc != null) {
-        sc.stop()
-      }
       Utils.deleteRecursively(checkpointDir)
     } finally {
       super.afterAll()
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
index a37fac87300b..8ef1010ab8ea 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
@@ -58,6 +58,7 @@ class WriteAheadLogBackedBlockRDDSuite
 
   override def beforeAll(): Unit = {
     super.beforeAll()
+    SparkContext.getActiveContext().foreach(_.stop())
     sparkContext = new SparkContext(conf)
     blockManager = sparkContext.env.blockManager
     serializerManager = sparkContext.env.serializerManager
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala
index b49e5790711c..9a5cf2ecf45a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala
@@ -24,7 +24,7 @@ import org.scalatest.concurrent.Eventually.{eventually, timeout}
 import org.scalatest.mock.MockitoSugar
 import org.scalatest.time.SpanSugar._
 
-import org.apache.spark.{ExecutorAllocationClient, SparkConf, SparkFunSuite}
+import org.apache.spark.{ExecutorAllocationClient, SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.streaming.{DummyInputDStream, Seconds, StreamingContext}
 import org.apache.spark.util.{ManualClock, Utils}
@@ -41,6 +41,12 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite
   before {
     allocationClient = mock[ExecutorAllocationClient]
     clock = new ManualClock()
+    SparkContext.stopActiveContext()
+  }
+
+  protected override def afterAll(): Unit = {
+    SparkContext.stopActiveContext()
+    super.afterAll()
   }
 
   test("basic functionality") {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
index a7e365649d3e..dbcf8a57c690 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
@@ -17,27 +17,21 @@
 
 package org.apache.spark.streaming.scheduler
 
-import org.scalatest.BeforeAndAfter
+import org.apache.spark.streaming._
 
-import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.streaming.{Duration, StreamingContext, Time}
-
-class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter {
+class InputInfoTrackerSuite extends ReuseableSparkContext {
 
   private var ssc: StreamingContext = _
 
   before {
-    val conf = new SparkConf().setMaster("local[2]").setAppName("DirectStreamTacker")
-    if (ssc == null) {
-      ssc = new StreamingContext(conf, Duration(1000))
-    }
+    assert(ssc == null)
+    ssc = new StreamingContext(sc, Duration(1000))
   }
 
   after {
-    if (ssc != null) {
-      ssc.stop()
-      ssc = null
-    }
+    assert(ssc != null)
+    ssc.stop()
+    ssc = null
   }
 
   test("test report and get InputInfo from InputInfoTracker") {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala
index 5f7f7fa5e67f..71aedfa5383d 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala
@@ -30,6 +30,10 @@ import org.apache.spark.util.{ManualClock, Utils}
 
 class JobGeneratorSuite extends TestSuiteBase {
 
+  override def extraSparkConf: Map[String, String] = super.extraSparkConf ++ Map(
+    "spark.streaming.clock" -> "org.apache.spark.streaming.util.ManualClock",
+    "spark.streaming.receiver.writeAheadLog.rollingInterval" -> "1")
+
   // SPARK-6222 is a tricky regression bug which causes received block metadata
   // to be deleted before the corresponding batch has completed. This occurs when
   // the following conditions are met.
@@ -59,11 +63,8 @@ class JobGeneratorSuite extends TestSuiteBase {
   test("SPARK-6222: Do not clear received block data too soon") {
     import JobGeneratorSuite._
     val checkpointDir = Utils.createTempDir()
 
-    val testConf = conf
-    testConf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
-    testConf.set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1")
-    withStreamingContext(new StreamingContext(testConf, batchDuration)) { ssc =>
+    withStreamingContext { ssc =>
       val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
       val numBatches = 10
       val longBatchNumber = 3 // 3rd batch will take a long time
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
index 37ca0ce2f6a3..24ac914ec074 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
@@ -25,13 +25,14 @@ import org.apache.spark.streaming.scheduler.rate.RateEstimator
 
 class RateControllerSuite extends TestSuiteBase {
 
-  override def useManualClock: Boolean = false
+  override def extraSparkConf: Map[String, String] = Map(
+    "spark.streaming.clock" -> "org.apache.spark.util.SystemClock",
+    "spark.streaming.stopSparkContextByDefault" -> "false")
 
   override def batchDuration: Duration = Milliseconds(50)
 
   test("RateController - rate controller publishes updates after batches complete") {
-    val ssc = new StreamingContext(conf, batchDuration)
-    withStreamingContext(ssc) { ssc =>
+    withStreamingContext { ssc =>
       val dstream = new RateTestInputDStream(ssc)
       dstream.register()
       ssc.start()
@@ -43,8 +44,7 @@ class RateControllerSuite extends TestSuiteBase {
   }
 
   test("ReceiverRateController - published rates reach receivers") {
-    val ssc = new StreamingContext(conf, batchDuration)
-    withStreamingContext(ssc) { ssc =>
+    withStreamingContext { ssc =>
       val estimator = new ConstantEstimator(100)
       val dstream = new RateTestInputDStream(ssc) {
         override val rateController =
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
index df122ac090c3..a8a6228e1818 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
@@ -32,8 +32,10 @@ import org.apache.spark.streaming.receiver._
 
 /** Testsuite for receiver scheduling */
 class ReceiverTrackerSuite extends TestSuiteBase {
+
+  override def batchDuration: Duration = Milliseconds(100)
+
   test("send rate update to receivers") {
-    withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc =>
+    withStreamingContext { ssc =>
       val newRateLimit = 100L
       val inputDStream = new RateTestInputDStream(ssc)
       val tracker = new ReceiverTracker(ssc)
@@ -62,7 +64,7 @@ class ReceiverTrackerSuite extends TestSuiteBase {
   }
 
   test("should restart receiver after stopping it") {
-    withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc =>
+    withStreamingContext { ssc =>
       @volatile var startTimes = 0
       ssc.addStreamingListener(new StreamingListener {
         override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = {
@@ -82,10 +84,11 @@ class ReceiverTrackerSuite extends TestSuiteBase {
   }
 
   test("SPARK-11063: TaskSetManager should use Receiver RDD's preferredLocations") {
-    // Use ManualClock to prevent from starting batches so that we can make sure the only task is
-    // for starting the Receiver
-    val _conf = conf.clone.set("spark.streaming.clock", "org.apache.spark.util.ManualClock")
-    withStreamingContext(new StreamingContext(_conf, Milliseconds(100))) { ssc =>
+    withStreamingContext { ssc =>
+      // Make sure we are using a ManualClock: it prevents batches from starting, so the only
+      // task scheduled is the one that starts the Receiver
+      assert(ssc.sc.conf.get("spark.streaming.clock") == "org.apache.spark.util.ManualClock")
+
       @volatile var receiverTaskLocality: TaskLocality = null
       ssc.sparkContext.addSparkListener(new SparkListener {
         override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
@@ -105,7 +108,7 @@ class ReceiverTrackerSuite extends TestSuiteBase {
   test("get allocated executors") {
     // Test get allocated executors when 1 receiver is registered
-    withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc =>
+    withStreamingContext { ssc =>
       val input = ssc.receiverStream(new TestReceiver)
       val output = new TestOutputStream(input)
       output.register()
@@ -114,7 +117,7 @@ class ReceiverTrackerSuite extends TestSuiteBase {
     }
 
     // Test get allocated executors when there's no receiver registered
-    withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc =>
+    withStreamingContext { ssc =>
       val rdd = ssc.sc.parallelize(1 to 10)
       val input = new ConstantInputDStream(ssc, rdd)
       val output = new TestOutputStream(input)
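
One caveat that ReceiverInputDStreamSuite's `runTest` illustrates: with a shared context, anything written into `sc.conf` outlives the test, so mutations must be undone in a `finally` block. A sketch of that discipline as a reusable helper; the helper name is hypothetical and not part of this patch:

```scala
// Hypothetical helper, for illustration only; assumes it lives in a
// ReuseableSparkContext suite where `sc` is the shared SparkContext.
def withConfEntry[T](key: String, value: String)(body: => T): T = {
  try {
    sc.conf.set(key, value)
    body
  } finally {
    // Leave the shared context clean for whatever test runs next.
    sc.conf.remove(key)
  }
}
```
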