diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index aee92ba928b4..baf0ed4df530 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1552,6 +1552,26 @@ private[spark] class DAGScheduler(
     // `findMissingPartitions()` returns all partitions every time.
     stage match {
       case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
+        // already executed at least once
+        if (sms.getNextAttemptId > 0) {
+          // While we previously validated possible rollbacks during the handling of a
+          // FetchFailure, where we were fetching from indeterminate source map stages, this
+          // later check covers additional cases like recalculating an indeterminate stage after
+          // an executor loss. Moreover, because this check occurs later in the process, if a
+          // result stage task has successfully completed, we can detect this and abort the job,
+          // as rolling back a result stage is not possible.
+          val stagesToRollback = collectSucceedingStages(sms)
+          abortStageWithInvalidRollBack(stagesToRollback)
+          // Stages which cannot be rolled back were aborted, which leads to removing the
+          // dependent job(s) from the active jobs set.
+          val numActiveJobsWithStageAfterRollback =
+            activeJobs.count(job => stagesToRollback.contains(job.finalStage))
+          if (numActiveJobsWithStageAfterRollback == 0) {
+            logInfo(log"All jobs depending on the indeterminate stage " +
+              log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.")
+            return
+          }
+        }
         mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
         sms.shuffleDep.newShuffleMergeState()
       case _ =>
@@ -2129,60 +2149,8 @@ private[spark] class DAGScheduler(
           // guaranteed to be determinate, so the input data of the reducers will not change
           // even if the map tasks are re-tried.
           if (mapStage.isIndeterminate) {
-            // It's a little tricky to find all the succeeding stages of `mapStage`, because
-            // each stage only know its parents not children. Here we traverse the stages from
-            // the leaf nodes (the result stages of active jobs), and rollback all the stages
-            // in the stage chains that connect to the `mapStage`. To speed up the stage
-            // traversing, we collect the stages to rollback first. If a stage needs to
-            // rollback, all its succeeding stages need to rollback to.
-            val stagesToRollback = HashSet[Stage](mapStage)
-
-            def collectStagesToRollback(stageChain: List[Stage]): Unit = {
-              if (stagesToRollback.contains(stageChain.head)) {
-                stageChain.drop(1).foreach(s => stagesToRollback += s)
-              } else {
-                stageChain.head.parents.foreach { s =>
-                  collectStagesToRollback(s :: stageChain)
-                }
-              }
-            }
-
-            def generateErrorMessage(stage: Stage): String = {
-              "A shuffle map stage with indeterminate output was failed and retried. " +
-                s"However, Spark cannot rollback the $stage to re-process the input data, " +
-                "and has to fail this job. Please eliminate the indeterminacy by " +
-                "checkpointing the RDD before repartition and try again."
-            }
-
-            activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil))
-
-            // The stages will be rolled back after checking
-            val rollingBackStages = HashSet[Stage](mapStage)
-            stagesToRollback.foreach {
-              case mapStage: ShuffleMapStage =>
-                val numMissingPartitions = mapStage.findMissingPartitions().length
-                if (numMissingPartitions < mapStage.numTasks) {
-                  if (sc.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) {
-                    val reason = "A shuffle map stage with indeterminate output was failed " +
-                      "and retried. However, Spark can only do this while using the new " +
-                      "shuffle block fetching protocol. Please check the config " +
-                      "'spark.shuffle.useOldFetchProtocol', see more detail in " +
-                      "SPARK-27665 and SPARK-25341."
-                    abortStage(mapStage, reason, None)
-                  } else {
-                    rollingBackStages += mapStage
-                  }
-                }
-
-              case resultStage: ResultStage if resultStage.activeJob.isDefined =>
-                val numMissingPartitions = resultStage.findMissingPartitions().length
-                if (numMissingPartitions < resultStage.numTasks) {
-                  // TODO: support to rollback result tasks.
-                  abortStage(resultStage, generateErrorMessage(resultStage), None)
-                }
-
-              case _ =>
-            }
+            val stagesToRollback = collectSucceedingStages(mapStage)
+            val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback)
             logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output was failed, " +
               log"we will roll back and rerun below stages which include itself and all its " +
               log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
@@ -2346,6 +2314,74 @@ private[spark] class DAGScheduler(
     }
   }
 
+  private def collectSucceedingStages(mapStage: ShuffleMapStage): HashSet[Stage] = {
+    // TODO: perhaps materialize this if we are going to compute it often enough?
+    // It's a little tricky to find all the succeeding stages of `mapStage`, because
+    // each stage only knows its parents, not its children. Here we traverse the stages from
+    // the leaf nodes (the result stages of active jobs), and roll back all the stages
+    // in the stage chains that connect to the `mapStage`. To speed up the stage
+    // traversal, we collect the stages to roll back first. If a stage needs to
+    // roll back, all its succeeding stages need to roll back too.
+    val succeedingStages = HashSet[Stage](mapStage)
+
+    def collectSucceedingStagesInternal(stageChain: List[Stage]): Unit = {
+      if (succeedingStages.contains(stageChain.head)) {
+        stageChain.drop(1).foreach(s => succeedingStages += s)
+      } else {
+        stageChain.head.parents.foreach { s =>
+          collectSucceedingStagesInternal(s :: stageChain)
+        }
+      }
+    }
+    activeJobs.foreach(job => collectSucceedingStagesInternal(job.finalStage :: Nil))
+    succeedingStages
+  }
+
+  /**
+   * Abort stages for which a rollback is requested but cannot be performed.
+   *
+   * @param stagesToRollback stages to roll back
+   * @return shuffle map stages which need to be and can be rolled back
+   */
+  private def abortStageWithInvalidRollBack(stagesToRollback: HashSet[Stage]): HashSet[Stage] = {
+
+    def generateErrorMessage(stage: Stage): String = {
+      "A shuffle map stage with indeterminate output was failed and retried. " +
+        s"However, Spark cannot rollback the $stage to re-process the input data, " +
+        "and has to fail this job. Please eliminate the indeterminacy by " +
+        "checkpointing the RDD before repartition and try again."
+    }
+
+    // The stages will be rolled back after checking
+    val rollingBackStages = HashSet[Stage]()
+    stagesToRollback.foreach {
+      case mapStage: ShuffleMapStage =>
+        if (mapStage.numAvailableOutputs > 0) {
+          if (sc.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) {
+            val reason = "A shuffle map stage with indeterminate output was failed " +
+              "and retried. However, Spark can only do this while using the new " +
+              "shuffle block fetching protocol. Please check the config " +
+              "'spark.shuffle.useOldFetchProtocol', see more detail in " +
+              "SPARK-27665 and SPARK-25341."
+            abortStage(mapStage, reason, None)
+          } else {
+            rollingBackStages += mapStage
+          }
+        }
+
+      case resultStage: ResultStage if resultStage.activeJob.isDefined =>
+        val numMissingPartitions = resultStage.findMissingPartitions().length
+        if (numMissingPartitions < resultStage.numTasks) {
+          // TODO: support to rollback result tasks.
+          abortStage(resultStage, generateErrorMessage(resultStage), None)
+        }
+
+      case _ =>
+    }
+
+    rollingBackStages
+  }
+
   /**
    * Whether executor is decommissioning or decommissioned.
    * Return true when:
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 3e507df706ba..d4e90be7c66d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -18,11 +18,11 @@ package org.apache.spark.scheduler
 
 import java.util.{ArrayList => JArrayList, Collections => JCollections, Properties}
-import java.util.concurrent.{CountDownLatch, Delayed, ScheduledFuture, TimeUnit}
+import java.util.concurrent.{CountDownLatch, Delayed, LinkedBlockingQueue, ScheduledFuture, TimeUnit}
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong, AtomicReference}
 
 import scala.annotation.meta.param
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
 import scala.jdk.CollectionConverters._
 import scala.language.reflectiveCalls
 import scala.util.control.NonFatal
@@ -56,28 +56,31 @@ class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
 
   dagScheduler.setEventProcessLoop(this)
 
-  private var isProcessing = false
-  private val eventQueue = new ListBuffer[DAGSchedulerEvent]()
-
+  private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]()
   override def post(event: DAGSchedulerEvent): Unit = {
-    if (isProcessing) {
-      // `DAGSchedulerEventProcessLoop` is guaranteed to process events sequentially. So we should
-      // buffer events for sequent processing later instead of processing them recursively.
-      eventQueue += event
-    } else {
-      try {
-        isProcessing = true
-        // Forward event to `onReceive` directly to avoid processing event asynchronously.
-        onReceive(event)
-      } catch {
-        case NonFatal(e) => onError(e)
-      } finally {
-        isProcessing = false
-      }
-      if (eventQueue.nonEmpty) {
-        post(eventQueue.remove(0))
-      }
+    // `DAGSchedulerEventProcessLoop` is guaranteed to process events sequentially in the main test
+    // thread, similarly to how it is done in production using the "dag-scheduler-event-loop".
+    // So we buffer events for subsequent processing instead of executing them on the
+    // thread calling post() (which might be the "dag-scheduler-message" thread for some
+    // events posted by the DAGScheduler itself).
+    eventQueue.put(event)
+  }
+
+  def runEvents(): Unit = {
+    var dagEvent = eventQueue.poll()
+    while (dagEvent != null) {
+      onReceiveWithErrorHandler(dagEvent)
+      dagEvent = eventQueue.poll()
+    }
+  }
+
+  private def onReceiveWithErrorHandler(event: DAGSchedulerEvent): Unit = {
+    try {
+      // Forward event to `onReceive` directly to avoid processing event asynchronously.
+      onReceive(event)
+    } catch {
+      case NonFatal(e) => onError(e)
     }
   }
@@ -306,7 +309,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
   var broadcastManager: BroadcastManager = null
   var securityMgr: SecurityManager = null
   var scheduler: DAGScheduler = null
-  var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null
+  var dagEventProcessLoopTester: DAGSchedulerEventProcessLoopTester = null
 
   /**
    * Set of cache locations to return from our mock BlockManagerMaster.
@@ -479,6 +482,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     // Ensure the initialization of various components
     sc
     dagEventProcessLoopTester.post(event)
+    dagEventProcessLoopTester.runEvents()
   }
 
   /**
@@ -1190,11 +1194,12 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
   private def completeNextStageWithFetchFailure(
       stageId: Int,
       attemptIdx: Int,
-      shuffleDep: ShuffleDependency[_, _, _]): Unit = {
+      shuffleDep: ShuffleDependency[_, _, _],
+      srcHost: String = "hostA"): Unit = {
     val stageAttempt = taskSets.last
     checkStageId(stageId, attemptIdx, stageAttempt)
     complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map { case (task, idx) =>
-      (FetchFailed(makeBlockManagerId("hostA"), shuffleDep.shuffleId, 0L, 0, idx, "ignored"), null)
+      (FetchFailed(makeBlockManagerId(srcHost), shuffleDep.shuffleId, 0L, 0, idx, "ignored"), null)
     }.toSeq)
   }
 
@@ -2251,6 +2256,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assert(completedStage === List(0, 1))
 
     Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
+    dagEventProcessLoopTester.runEvents()
     // map stage resubmitted
     assert(scheduler.runningStages.size === 1)
     val mapStage = scheduler.runningStages.head
@@ -2286,6 +2292,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     sc.listenerBus.waitUntilEmpty()
 
     Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
+    dagEventProcessLoopTester.runEvents()
     // map stage is running by resubmitted, result stage is waiting
     // map tasks and the origin result task 1.0 are running
     assert(scheduler.runningStages.size == 1, "Map stage should be running")
@@ -3125,6 +3132,92 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assert(countSubmittedMapStageAttempts() === 2)
   }
 
+  /**
+   * This function creates the following dependency graph:
+   *
+   *   (determinate)      (indeterminate)
+   *  shuffleMapRdd0      shuffleMapRDD1
+   *           \               /
+   *            \             /
+   *               finalRdd
+   *
+   * Both ShuffleMapRdds will become ShuffleMapStages with 2 partitions, executed on
+   * hostA_exec and hostB_exec.
+   */
+  def constructMixedDeterminateDependencies():
+      (ShuffleDependency[_, _, _], ShuffleDependency[_, _, _]) = {
+    val numPartitions = 2
+    val shuffleMapRdd0 = new MyRDD(sc, numPartitions, Nil, indeterminate = false)
+    val shuffleDep0 = new ShuffleDependency(shuffleMapRdd0, new HashPartitioner(2))
+
+    val shuffleMapRdd1 =
+      new MyRDD(sc, numPartitions, Nil, tracker = mapOutputTracker, indeterminate = true)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2))
+
+    val finalRdd =
+      new MyRDD(sc, numPartitions, List(shuffleDep0, shuffleDep1), tracker = mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    // Finish the first shuffle map stage.
+    completeShuffleMapStageSuccessfully(0, 0, numPartitions, Seq("hostA", "hostB"))
+    completeShuffleMapStageSuccessfully(1, 0, numPartitions, Seq("hostA", "hostB"))
+    assert(mapOutputTracker.findMissingPartitions(0) === Some(Seq.empty))
+    assert(mapOutputTracker.findMissingPartitions(1) === Some(Seq.empty))
+
+    (shuffleDep0, shuffleDep1)
+  }
+
+  test("SPARK-51272: re-submit of an indeterminate stage without partial result can succeed") {
+    val shuffleDeps = constructMixedDeterminateDependencies()
+    val resultStage = scheduler.stageIdToStage(2).asInstanceOf[ResultStage]
+
+    // the fetch failure is from the determinate shuffle map stage, but it leads to an
+    // executor loss and removes the shuffle files generated by the indeterminate stage too
+    completeNextStageWithFetchFailure(resultStage.id, 0, shuffleDeps._1, "hostA")
+
+    Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
+    dagEventProcessLoopTester.runEvents()
+    assert(scheduler.runningStages.size === 2)
+    assert(scheduler.runningStages.forall(_.isInstanceOf[ShuffleMapStage]))
+
+    completeShuffleMapStageSuccessfully(0, 1, 2, Seq("hostA", "hostB"))
+    completeShuffleMapStageSuccessfully(1, 1, 2, Seq("hostA", "hostB"))
+    assert(scheduler.runningStages.size === 1)
+    assert(scheduler.runningStages.head === resultStage)
+    assert(resultStage.latestInfo.failureReason.isEmpty)
+
+    completeNextResultStageWithSuccess(resultStage.id, 1)
+  }
+
+  test("SPARK-51272: re-submit of an indeterminate stage with partial result will fail") {
+    val shuffleDeps = constructMixedDeterminateDependencies()
+    val resultStage = scheduler.stageIdToStage(2).asInstanceOf[ResultStage]
+
+    runEvent(makeCompletionEvent(taskSets(2).tasks(0), Success, 42))
+    // the fetch failure is from the determinate shuffle map stage, but it leads to an
+    // executor loss and removes the shuffle files generated by the indeterminate stage too
+    runEvent(makeCompletionEvent(
+      taskSets(2).tasks(1),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleDeps._1.shuffleId, 0L, 0, 0, "ignored"),
+      null))
+
+    dagEventProcessLoopTester.runEvents()
+    // resubmission has not yet happened, so the job is still running
+    assert(scheduler.activeJobs.nonEmpty)
+    Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
+    dagEventProcessLoopTester.runEvents()
+
+    // all dependent jobs have failed
+    assert(scheduler.runningStages.size === 0)
+    assert(scheduler.activeJobs.isEmpty)
+    assert(resultStage.latestInfo.failureReason.isDefined)
+    assert(resultStage.latestInfo.failureReason.get.
+      contains("A shuffle map stage with indeterminate output was failed and retried. " +
" + + "However, Spark cannot rollback the ResultStage")) + assert(scheduler.activeJobs.isEmpty, "Aborting the stage aborts the job as well.") + } + private def constructIndeterminateStageFetchFailed(): (Int, Int) = { val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true) @@ -4884,6 +4977,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti // wait resubmit sc.listenerBus.waitUntilEmpty() Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2) + dagEventProcessLoopTester.runEvents() // stage0 retry val stage0Retry = taskSets.filter(_.stageId == 1) @@ -4984,6 +5078,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti // the stages will now get resubmitted due to the failure Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2) + dagEventProcessLoopTester.runEvents() // parent map stage resubmitted assert(scheduler.runningStages.size === 1) @@ -5003,6 +5098,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti result = MapStatus(BlockManagerId("hostF-exec1", "hostF", 12345), Array.fill[Long](2)(2), mapTaskId = taskIdCount))) Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2) + dagEventProcessLoopTester.runEvents() // The retries should succeed sc.listenerBus.waitUntilEmpty() @@ -5012,6 +5108,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti // This will add 3 new stages. submit(reduceRdd, Array(0, 1)) Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2) + dagEventProcessLoopTester.runEvents() // Only the last stage needs to execute, and those tasks - so completed stages should not // change.