diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8c46a8432339..78b6b34b5d2b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1092,17 +1092,16 @@ class DAGScheduler( // the stage as completed here in case there are no tasks to run markStageAsFinished(stage, None) - val debugString = stage match { + stage match { case stage: ShuffleMapStage => - s"Stage ${stage} is actually done; " + - s"(available: ${stage.isAvailable}," + - s"available outputs: ${stage.numAvailableOutputs}," + - s"partitions: ${stage.numPartitions})" + logDebug(s"Stage ${stage} is actually done; " + + s"(available: ${stage.isAvailable}," + + s"available outputs: ${stage.numAvailableOutputs}," + + s"partitions: ${stage.numPartitions})") + markMapStageJobsAsFinished(stage) case stage : ResultStage => - s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})" + logDebug(s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})") } - logDebug(debugString) - submitWaitingChildStages(stage) } } @@ -1307,13 +1306,7 @@ class DAGScheduler( shuffleStage.findMissingPartitions().mkString(", ")) submitStage(shuffleStage) } else { - // Mark any map-stage jobs waiting on this stage as finished - if (shuffleStage.mapStageJobs.nonEmpty) { - val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) - for (job <- shuffleStage.mapStageJobs) { - markMapStageJobAsFinished(job, stats) - } - } + markMapStageJobsAsFinished(shuffleStage) submitWaitingChildStages(shuffleStage) } } @@ -1433,6 +1426,16 @@ class DAGScheduler( } } + private[scheduler] def markMapStageJobsAsFinished(shuffleStage: ShuffleMapStage): Unit = { + // Mark any map-stage jobs waiting on this stage as finished + if (shuffleStage.isAvailable && shuffleStage.mapStageJobs.nonEmpty) { + val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) + for (job <- shuffleStage.mapStageJobs) { + markMapStageJobAsFinished(job, stats) + } + } + } + /** * Responds to an executor being lost. This is called inside the event loop, so it assumes it can * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d812b5bd92c1..8b6ec37625ee 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2146,6 +2146,58 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assertDataStructuresEmpty() } + test("Trigger mapstage's job listener in submitMissingTasks") { + val rdd1 = new MyRDD(sc, 2, Nil) + val dep1 = new ShuffleDependency(rdd1, new HashPartitioner(2)) + val rdd2 = new MyRDD(sc, 2, List(dep1), tracker = mapOutputTracker) + val dep2 = new ShuffleDependency(rdd2, new HashPartitioner(2)) + + val listener1 = new SimpleListener + val listener2 = new SimpleListener + + submitMapStage(dep1, listener1) + submitMapStage(dep2, listener2) + + // Complete the stage0. + assert(taskSets(0).stageId === 0) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", rdd1.partitions.length)), + (Success, makeMapStatus("hostB", rdd1.partitions.length)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(listener1.results.size === 1) + + // When attempting stage1, trigger a fetch failure. + assert(taskSets(1).stageId === 1) + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostC", rdd2.partitions.length)), + (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + // Stage1 listener should not have a result yet + assert(listener2.results.size === 0) + + // Speculative task succeeded in stage1. + runEvent(makeCompletionEvent( + taskSets(1).tasks(1), + Success, + makeMapStatus("hostD", rdd2.partitions.length))) + // stage1 listener still should not have a result, though there's no missing partitions + // in it. Because stage1 has been failed and is not inside `runningStages` at this moment. + assert(listener2.results.size === 0) + + // Stage0 should now be running as task set 2; make its task succeed + assert(taskSets(2).stageId === 0) + complete(taskSets(2), Seq( + (Success, makeMapStatus("hostC", rdd2.partitions.length)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + Set(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + + // After stage0 is finished, stage1 will be submitted and found there is no missing + // partitions in it. Then listener got triggered. + assert(listener2.results.size === 1) + assertDataStructuresEmpty() + } + /** * In this test, we run a map stage where one of the executors fails but we still receive a * "zombie" complete message from that executor. We want to make sure the stage is not reported