diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index f3da04a7f55d0..d0f52b9453b40 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -552,7 +552,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     _cleaner =
       if (_conf.getBoolean("spark.cleaner.referenceTracking", true)) {
-        Some(new ContextCleaner(this))
+        val cleaner = new ContextCleaner(this)
+        cleaner.attachListener(dagScheduler)
+        Some(cleaner)
       } else {
         None
       }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index daf9b0f95273e..dc9fe3a371f04 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -68,7 +68,7 @@ class DAGScheduler(
     blockManagerMaster: BlockManagerMaster,
     env: SparkEnv,
     clock: Clock = new SystemClock())
-  extends Logging {
+  extends Logging with CleanerListener {
 
   def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
     this(
@@ -483,9 +483,6 @@ class DAGScheduler(
                   logDebug("Removing running stage %d".format(stageId))
                   runningStages -= stage
                 }
-                for ((k, v) <- shuffleToMapStage.find(_._2 == stage)) {
-                  shuffleToMapStage.remove(k)
-                }
                 if (waitingStages.contains(stage)) {
                   logDebug("Removing stage %d from waiting set.".format(stageId))
                   waitingStages -= stage
@@ -496,9 +493,12 @@ class DAGScheduler(
                 }
               }
               // data structures based on StageId
-              stageIdToStage -= stageId
-              logDebug("After removal of stage %d, remaining stages = %d"
-                .format(stageId, stageIdToStage.size))
+              // ShuffleMapStages aren't removed until their shuffle is cleaned by the ContextCleaner
+              if (stage.isInstanceOf[ResultStage]) {
+                stageIdToStage -= stageId
+                logDebug("After removal of stage %d, remaining stages = %d"
+                  .format(stageId, stageIdToStage.size))
+              }
             }
 
             jobSet -= job.jobId
@@ -1436,6 +1436,21 @@ class DAGScheduler(
     taskScheduler.stop()
   }
 
+  /**
+   * Called by the context cleaner when a shuffle is removed.
+   * @param shuffleId id of the shuffle whose map stage can now be forgotten
+   */
+  override def shuffleCleaned(shuffleId: Int): Unit = {
+    val stageOpt = shuffleToMapStage.remove(shuffleId)
+    stageOpt.foreach { stage => stageIdToStage -= stage.id }
+  }
+
+  // These are also called by the context cleaner, but the DAGScheduler doesn't need to act on them
+  override def accumCleaned(accId: Long): Unit = {}
+  override def broadcastCleaned(broadcastId: Long): Unit = {}
+  override def checkpointCleaned(rddId: Long): Unit = {}
+  override def rddCleaned(rddId: Int): Unit = {}
+
   // Start the event thread and register the metrics source at the end of the constructor
   env.metricsSystem.registerSource(metricsSource)
   eventProcessLoop.start()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 2e8688cf41d99..d5ee24414073b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.scheduler
 
+import scala.collection.mutable
 import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map}
 import scala.language.reflectiveCalls
 import scala.util.control.NonFatal
@@ -163,7 +164,7 @@ class DAGSchedulerSuite
     cancelledStages.clear()
     cacheLocations.clear()
     results.clear()
-    mapOutputTracker = new MapOutputTrackerMaster(conf)
+    mapOutputTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
     scheduler = new DAGScheduler(
       sc,
       taskScheduler,
@@ -171,6 +172,9 @@ class DAGSchedulerSuite
       mapOutputTracker,
       blockManagerMaster,
      sc.env)
+    // This is normally done during SparkContext creation, but this suite ignores the scheduler
+    // created by the SparkContext and builds its own, so we need to re-register the listener here
+    sc.cleaner.foreach { _.attachListener(scheduler) }
     dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler)
   }
 
@@ -695,6 +699,199 @@ class DAGSchedulerSuite
     assertDataStructuresEmpty()
   }
 
+  // Helper function to validate state when creating tests for task failures
+  def checkStageId(stageId: Int, attempt: Int, stageAttempt: TaskSet) {
+    assert(stageAttempt.stageId === stageId,
+      s": expected stage $stageId, instead was ${stageAttempt.stageId}")
+    assert(stageAttempt.stageAttemptId == attempt,
+      s": expected stage attempt $attempt, instead was ${stageAttempt.stageAttemptId}")
+  }
+
+  def makeCompletions(stageAttempt: TaskSet, reduceParts: Int): Seq[(Success.type, MapStatus)] = {
+    stageAttempt.tasks.zipWithIndex.map { case (task, idx) =>
+      (Success, makeMapStatus("host" + ('A' + idx).toChar, reduceParts))
+    }.toSeq
+  }
+
+  def setupStageAbortTest(sc: SparkContext) {
+    sc.listenerBus.addListener(new EndListener())
+    ended = false
+    jobResult = null
+  }
+
+  // Create a new Listener to confirm that the listenerBus sees the JobEnd message
+  // when we abort the stage. This message will also be consumed by the EventLoggingListener,
+  // so this will propagate up to the user.
+  var ended = false
+  var jobResult: JobResult = null
+
+  class EndListener extends SparkListener {
+    override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+      jobResult = jobEnd.jobResult
+      ended = true
+    }
+  }
+
+  // Helper functions to extract commonly used code in fetch failure test cases
+  /**
+   * Common code to get the next stage attempt, confirm it's the one we expect, and complete it
+   * successfully.
+   *
+   * @param stageId - The current stageId
+   * @param attemptIdx - The current attempt count
+   * @param numShufflePartitions - The number of partitions in the next stage
+   */
+  def completeNextShuffleMapSuccessfully(
+      stageId: Int,
+      attemptIdx: Int,
+      numShufflePartitions: Int): Unit = {
+    val stageAttempt = taskSets.last
+    checkStageId(stageId, attemptIdx, stageAttempt)
+    complete(stageAttempt, makeCompletions(stageAttempt, numShufflePartitions))
+  }
+
+  /**
+   * Common code to get the next stage attempt, confirm it's the one we expect, and complete it
+   * with a FetchFailure for every task.
+   *
+   * @param stageId - The current stageId
+   * @param attemptIdx - The current attempt count
+   * @param shuffleDep - The shuffle dependency of the stage with a fetch failure
+   */
+  def completeNextStageWithFetchFailure(
+      stageId: Int,
+      attemptIdx: Int,
+      shuffleDep: ShuffleDependency[_, _, _]): Unit = {
+    val stageAttempt = taskSets.last
+    checkStageId(stageId, attemptIdx, stageAttempt)
+
+    complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map { case (task, idx) =>
+      (FetchFailed(makeBlockManagerId("hostA"), shuffleDep.shuffleId, 0, idx, "ignored"), null)
+    }.toSeq)
+  }
+
+  /**
+   * Common code to get the next result stage attempt, confirm it's the one we expect, and
+   * complete it with a success, returning 42 by default.
+   *
+   * @param stageId - The current stageId
+   * @param attemptIdx - The current attempt count
+   */
+  def completeNextResultStageWithSuccess(
+      stageId: Int,
+      attemptIdx: Int,
+      resultFunc: Int => Int = _ => 42): Unit = {
+    val stageAttempt = taskSets.last
+    checkStageId(stageId, attemptIdx, stageAttempt)
+    assert(scheduler.stageIdToStage(stageId).isInstanceOf[ResultStage])
+    complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map { case (_, idx) =>
+      (Success, resultFunc(idx))
+    }.toSeq)
+  }
+
+  test("shuffle fetch failure in a reused shuffle dependency") {
+    // Run the first job successfully, which creates one shuffle dependency
+    val jobIdToStageIds = new mutable.HashMap[Int, Set[Int]]()
+    val listener = new SparkListener {
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        jobIdToStageIds(jobStart.jobId) = jobStart.stageIds.toSet
+      }
+    }
+    sc.addSparkListener(listener)
+
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+    sc.listenerBus.waitUntilEmpty(1000)
+    assert(jobIdToStageIds(0) === Set(0, 1))
+
+    completeNextShuffleMapSuccessfully(0, 0, 2)
+    completeNextResultStageWithSuccess(1, 0)
+    assert(results === Map(0 -> 42, 1 -> 42))
+    assertDataStructuresEmpty()
+
+    // Submit another job with the shared dependency, and have a fetch failure
+    val reduce2 = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduce2, Array(0, 1))
+    sc.listenerBus.waitUntilEmpty(1000)
+    assert(jobIdToStageIds(1) === Set(0, 2))
+    completeNextStageWithFetchFailure(2, 0, shuffleDep)
+    scheduler.resubmitFailedStages()
+    completeNextShuffleMapSuccessfully(0, 1, 2)
+    completeNextResultStageWithSuccess(2, 1, idx => idx + 1234)
+    assert(results === Map(0 -> 1234, 1 -> 1235))
+
+    assertDataStructuresEmpty()
+    assertEmptyAfterContextCleaner()
+  }
+
+  test("reused dependency with long lineage") {
+    val jobIdToStageIds = new mutable.HashMap[Int, Set[Int]]()
+    val listener = new SparkListener {
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        jobIdToStageIds(jobStart.jobId) = jobStart.stageIds.toSet
+      }
+    }
+    sc.addSparkListener(listener)
+
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd, null)
+    val reduceRdd = new MyRDD(sc, 4, List(shuffleDep1))
+    val shuffleDep2 = new ShuffleDependency(reduceRdd, null)
+    val reduceRdd2 = new MyRDD(sc, 6, List(shuffleDep2))
+    val shuffleDep3 = new ShuffleDependency(reduceRdd2, null)
+    val reduceRdd3 = new MyRDD(sc, 8, List(shuffleDep3))
+    submit(reduceRdd3, (0 until 8).toArray)
+    sc.listenerBus.waitUntilEmpty(1000)
+    assert(jobIdToStageIds(0) === Set(0, 1, 2, 3))
+
+    completeNextShuffleMapSuccessfully(0, 0, 4)
+    completeNextShuffleMapSuccessfully(1, 0, 6)
+    completeNextShuffleMapSuccessfully(2, 0, 8)
+    completeNextResultStageWithSuccess(3, 0)
+    assert(results === (0 until 8).map { _ -> 42 }.toMap)
+    results.clear()
+    assertDataStructuresEmpty()
+
+    // Submit another job with the shared dependency, and have a fetch failure
+    val reduce4 = new MyRDD(sc, 2, List(shuffleDep3))
+    submit(reduce4, Array(0, 1))
+    sc.listenerBus.waitUntilEmpty(1000)
+    assert(jobIdToStageIds(1) === Set(0, 1, 2, 4))
+    completeNextStageWithFetchFailure(4, 0, shuffleDep3)
+    scheduler.resubmitFailedStages()
+    completeNextShuffleMapSuccessfully(0, 1, 4)
+    completeNextShuffleMapSuccessfully(1, 1, 6)
+    completeNextShuffleMapSuccessfully(2, 1, 8)
+    completeNextResultStageWithSuccess(4, 1, idx => idx + 1234)
+    assert(results === Map(0 -> 1234, 1 -> 1235))
+    results.clear()
+    assertDataStructuresEmpty()
+    assertEmptyAfterContextCleaner()
+
+    // Now try submitting again, after we've cleaned out the shuffle data. This should be fine,
+    // we just need to rerun everything
+    val reduce5 = new MyRDD(sc, 8, List(shuffleDep3))
+    submit(reduce5, (0 until 8).toArray)
+    sc.listenerBus.waitUntilEmpty(1000)
+    assert(jobIdToStageIds(2) === Set(5, 6, 7, 8)) // new stages this time
+    completeNextShuffleMapSuccessfully(5, 0, 4)
+    completeNextShuffleMapSuccessfully(6, 0, 6)
+    completeNextShuffleMapSuccessfully(7, 0, 8)
+    completeNextResultStageWithSuccess(8, 0, idx => idx + 4321)
+    assert(results === (0 until 8).map { idx => idx -> (idx + 4321) }.toMap)
+
+    assertDataStructuresEmpty()
+    assertEmptyAfterContextCleaner()
+  }
+
   /**
    * Makes sure that failures of stage used by multiple jobs are correctly handled.
    *
@@ -1012,13 +1209,19 @@ class DAGSchedulerSuite
     assert(scheduler.failedStages.isEmpty)
     assert(scheduler.jobIdToActiveJob.isEmpty)
     assert(scheduler.jobIdToStageIds.isEmpty)
-    assert(scheduler.stageIdToStage.isEmpty)
     assert(scheduler.runningStages.isEmpty)
-    assert(scheduler.shuffleToMapStage.isEmpty)
     assert(scheduler.waitingStages.isEmpty)
     assert(scheduler.outputCommitCoordinator.isEmpty)
   }
 
+  private def assertEmptyAfterContextCleaner(): Unit = {
+    scheduler.shuffleToMapStage.foreach { case (shuffleId, _) =>
+      sc.cleaner.get.doCleanupShuffle(shuffleId, blocking = true)
+    }
+    assert(scheduler.stageIdToStage.isEmpty)
+    assert(scheduler.shuffleToMapStage.isEmpty)
+  }
+
   // Nothing in this test should break if the task info's fields are null, but
   // OutputCommitCoordinator requires the task info itself to not be null.
   private def createFakeTaskInfo(): TaskInfo = {
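
For readers who want to see the cleaner-to-scheduler wiring in isolation, here is a minimal, self-contained sketch of the pattern the patch relies on. The names ToyContextCleaner, ToyScheduler, registerShuffleMapStage, and ListenerWiringDemo are invented stand-ins, not Spark classes; only the shape of attachListener, shuffleCleaned, and doCleanupShuffle mirrors the diff above.

import scala.collection.mutable

// Stand-in for the CleanerListener callback used in the patch: only the shuffle hook matters here.
trait CleanerListener {
  def shuffleCleaned(shuffleId: Int): Unit
}

// Stand-in for ContextCleaner: notifies attached listeners when a shuffle is cleaned.
class ToyContextCleaner {
  private val listeners = mutable.ArrayBuffer.empty[CleanerListener]
  def attachListener(listener: CleanerListener): Unit = listeners += listener
  def doCleanupShuffle(shuffleId: Int): Unit = listeners.foreach(_.shuffleCleaned(shuffleId))
}

// Stand-in for the DAGScheduler's bookkeeping: map stages are kept until their shuffle is cleaned.
class ToyScheduler extends CleanerListener {
  case class Stage(id: Int)
  val shuffleToMapStage = mutable.HashMap.empty[Int, Stage]
  val stageIdToStage = mutable.HashMap.empty[Int, Stage]

  def registerShuffleMapStage(shuffleId: Int, stageId: Int): Unit = {
    val stage = Stage(stageId)
    shuffleToMapStage(shuffleId) = stage
    stageIdToStage(stageId) = stage
  }

  // Mirrors the new DAGScheduler.shuffleCleaned: forget the map stage for this shuffle.
  override def shuffleCleaned(shuffleId: Int): Unit = {
    shuffleToMapStage.remove(shuffleId).foreach { stage => stageIdToStage -= stage.id }
  }
}

object ListenerWiringDemo extends App {
  val cleaner = new ToyContextCleaner
  val scheduler = new ToyScheduler
  cleaner.attachListener(scheduler)          // analogous to the SparkContext change above

  scheduler.registerShuffleMapStage(shuffleId = 0, stageId = 0)
  assert(scheduler.stageIdToStage.nonEmpty)  // the stage survives job completion ...
  cleaner.doCleanupShuffle(0)                // ... until the shuffle itself is cleaned
  assert(scheduler.shuffleToMapStage.isEmpty && scheduler.stageIdToStage.isEmpty)
  println("map stage dropped once its shuffle was cleaned")
}

The design point is that ShuffleMapStage bookkeeping now follows the lifetime of the shuffle data rather than the lifetime of the job that created it, which is why the suite's assertEmptyAfterContextCleaner drives the cleaner explicitly before asserting that stageIdToStage and shuffleToMapStage are empty.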