Skip to content
Closed
136 changes: 82 additions & 54 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1552,6 +1552,20 @@ private[spark] class DAGScheduler(
// `findMissingPartitions()` returns all partitions every time.
stage match {
case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
// already executed at least once
if (sms.getNextAttemptId > 0) {
val stagesToRollback = collectSucceedingStages(sms)
rollBackStages(stagesToRollback)
// stages which cannot be rolled back were aborted, which leads to removing
// the dependent job(s) from the active jobs set
val numActiveJobsWithStageAfterRollback =
activeJobs.count(job => stagesToRollback.contains(job.finalStage))
if (numActiveJobsWithStageAfterRollback == 0) {
logInfo(log"All jobs depending on the indeterminate stage " +
log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.")
return
}
}
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
sms.shuffleDep.newShuffleMergeState()
case _ =>
Expand Down Expand Up @@ -2129,60 +2143,8 @@ private[spark] class DAGScheduler(
// guaranteed to be determinate, so the input data of the reducers will not change
// even if the map tasks are re-tried.
if (mapStage.isIndeterminate) {
// It's a little tricky to find all the succeeding stages of `mapStage`, because
// each stage only know its parents not children. Here we traverse the stages from
// the leaf nodes (the result stages of active jobs), and rollback all the stages
// in the stage chains that connect to the `mapStage`. To speed up the stage
// traversing, we collect the stages to rollback first. If a stage needs to
// rollback, all its succeeding stages need to rollback to.
val stagesToRollback = HashSet[Stage](mapStage)

def collectStagesToRollback(stageChain: List[Stage]): Unit = {
if (stagesToRollback.contains(stageChain.head)) {
stageChain.drop(1).foreach(s => stagesToRollback += s)
} else {
stageChain.head.parents.foreach { s =>
collectStagesToRollback(s :: stageChain)
}
}
}

def generateErrorMessage(stage: Stage): String = {
"A shuffle map stage with indeterminate output was failed and retried. " +
s"However, Spark cannot rollback the $stage to re-process the input data, " +
"and has to fail this job. Please eliminate the indeterminacy by " +
"checkpointing the RDD before repartition and try again."
}

activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil))

// The stages will be rolled back after checking
val rollingBackStages = HashSet[Stage](mapStage)
stagesToRollback.foreach {
case mapStage: ShuffleMapStage =>
val numMissingPartitions = mapStage.findMissingPartitions().length
if (numMissingPartitions < mapStage.numTasks) {
if (sc.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) {
val reason = "A shuffle map stage with indeterminate output was failed " +
"and retried. However, Spark can only do this while using the new " +
"shuffle block fetching protocol. Please check the config " +
"'spark.shuffle.useOldFetchProtocol', see more detail in " +
"SPARK-27665 and SPARK-25341."
abortStage(mapStage, reason, None)
} else {
rollingBackStages += mapStage
}
}

case resultStage: ResultStage if resultStage.activeJob.isDefined =>
val numMissingPartitions = resultStage.findMissingPartitions().length
if (numMissingPartitions < resultStage.numTasks) {
// TODO: support to rollback result tasks.
abortStage(resultStage, generateErrorMessage(resultStage), None)
}

case _ =>
}
val stagesToRollback = collectSucceedingStages(mapStage)
val rollingBackStages = rollBackStages(stagesToRollback)
logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output was failed, " +
log"we will roll back and rerun below stages which include itself and all its " +
log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
Expand Down Expand Up @@ -2346,6 +2308,72 @@ private[spark] class DAGScheduler(
}
}

/**
 * Collects `mapStage` together with every stage that sits on a parent chain leading from the
 * final stage of an active job to `mapStage`.
 *
 * A stage only knows its parents, not its children, so the search starts at the final stage
 * of each active job and walks towards the parents. As soon as the walk reaches a stage that
 * is already collected, every stage passed on the way there is collected as well: if a stage
 * needs to be rolled back, all of its succeeding stages need to be rolled back too.
 *
 * @param mapStage the shuffle map stage whose succeeding stages are looked up
 * @return a set containing `mapStage` and the succeeding stages found from the active jobs
 */
private def collectSucceedingStages(mapStage: ShuffleMapStage): HashSet[Stage] = {
  // TODO: perhaps materialize this if we are going to compute it often enough ?
  val collected = HashSet[Stage](mapStage)

  // `current` is the stage under inspection; `descendants` holds the stages already walked
  // through on the way from a final stage to `current` (nearest one first).
  def walk(current: Stage, descendants: List[Stage]): Unit = {
    if (collected.contains(current)) {
      collected ++= descendants
    } else {
      current.parents.foreach(parent => walk(parent, current :: descendants))
    }
  }

  activeJobs.foreach(job => walk(job.finalStage, Nil))
  collected
}

/**
 * Decides, for every stage in `stagesToRollback`, whether it can be rolled back, and aborts
 * the ones that cannot.
 *
 *  - A shuffle map stage with at least one available output is aborted when
 *    `config.SHUFFLE_USE_OLD_FETCH_PROTOCOL` evaluates to true; otherwise it is added to the
 *    returned set.
 *  - A result stage with an active job is aborted when fewer partitions are missing than the
 *    stage has tasks.
 *  - Any other stage is left untouched.
 *
 * @param stagesToRollback stages to roll back
 * @return Shuffle map stages which need and can be rolled back
 */
private def rollBackStages(stagesToRollback: HashSet[Stage]): HashSet[Stage] = {
  // Filled while checking; only shuffle map stages that passed the check end up here.
  val rollbackCandidates = HashSet[Stage]()

  stagesToRollback.foreach {
    case shuffleStage: ShuffleMapStage if shuffleStage.numAvailableOutputs > 0 =>
      if (!sc.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) {
        rollbackCandidates += shuffleStage
      } else {
        val reason = "A shuffle map stage with indeterminate output was failed " +
          "and retried. However, Spark can only do this while using the new " +
          "shuffle block fetching protocol. Please check the config " +
          "'spark.shuffle.useOldFetchProtocol', see more detail in " +
          "SPARK-27665 and SPARK-25341."
        abortStage(shuffleStage, reason, None)
      }

    case finalStage: ResultStage if finalStage.activeJob.isDefined =>
      val missingCount = finalStage.findMissingPartitions().length
      if (missingCount < finalStage.numTasks) {
        // TODO: support to rollback result tasks.
        val reason = "A shuffle map stage with indeterminate output was failed and retried. " +
          s"However, Spark cannot rollback the $finalStage to re-process the input data, " +
          "and has to fail this job. Please eliminate the indeterminacy by " +
          "checkpointing the RDD before repartition and try again."
        abortStage(finalStage, reason, None)
      }

    case _ =>
  }

  rollbackCandidates
}

/**
* Whether executor is decommissioning or decommissioned.
* Return true when:
Expand Down
150 changes: 125 additions & 25 deletions core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
package org.apache.spark.scheduler

import java.util.{ArrayList => JArrayList, Collections => JCollections, Properties}
import java.util.concurrent.{CountDownLatch, Delayed, ScheduledFuture, TimeUnit}
import java.util.concurrent.{CountDownLatch, Delayed, LinkedBlockingQueue, ScheduledFuture, TimeUnit}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong, AtomicReference}

import scala.annotation.meta.param
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.jdk.CollectionConverters._
import scala.language.reflectiveCalls
import scala.util.control.NonFatal
Expand Down Expand Up @@ -56,28 +56,31 @@ class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)

dagScheduler.setEventProcessLoop(this)

private var isProcessing = false
private val eventQueue = new ListBuffer[DAGSchedulerEvent]()

private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]()

override def post(event: DAGSchedulerEvent): Unit = {
if (isProcessing) {
// `DAGSchedulerEventProcessLoop` is guaranteed to process events sequentially. So we should
// buffer events for sequent processing later instead of processing them recursively.
eventQueue += event
} else {
try {
isProcessing = true
// Forward event to `onReceive` directly to avoid processing event asynchronously.
onReceive(event)
} catch {
case NonFatal(e) => onError(e)
} finally {
isProcessing = false
}
if (eventQueue.nonEmpty) {
post(eventQueue.remove(0))
}
// `DAGSchedulerEventProcessLoop` is guaranteed to process events sequentially in the main test
// thread, similarly to how it is done in production using the "dag-scheduler-event-loop".
// So we should buffer events for sequential processing later instead of executing them
// on the thread calling post() (which might be the "dag-scheduler-message" thread for some
// events posted by the DAGScheduler itself).
eventQueue.put(event)
}

/**
 * Drains `eventQueue`: polls one event at a time and hands it to `onReciveWithErrorHandler`
 * until `poll()` returns null. Events enqueued while an earlier event is being handled are
 * picked up by the same call, because the queue is polled again after each event.
 */
def runEvents(): Unit = {
  Iterator
    .continually(eventQueue.poll())
    .takeWhile(_ != null)
    .foreach(onReciveWithErrorHandler)
}

/**
 * Passes `event` straight to `onReceive` on the calling thread and reports any non-fatal
 * throwable to `onError`. Fatal throwables are not intercepted.
 */
private def onReciveWithErrorHandler(event: DAGSchedulerEvent): Unit = {
  // Forward event to `onReceive` directly to avoid processing event asynchronously.
  // `Try` captures exactly the throwables matched by `NonFatal`.
  scala.util.Try(onReceive(event)).failed.foreach(onError)
}

Expand Down Expand Up @@ -306,7 +309,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
var broadcastManager: BroadcastManager = null
var securityMgr: SecurityManager = null
var scheduler: DAGScheduler = null
var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null
var dagEventProcessLoopTester: DAGSchedulerEventProcessLoopTester = null

/**
* Set of cache locations to return from our mock BlockManagerMaster.
Expand Down Expand Up @@ -479,6 +482,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
// Ensure the initialization of various components
sc
dagEventProcessLoopTester.post(event)
dagEventProcessLoopTester.runEvents()
}

/**
Expand Down Expand Up @@ -1190,11 +1194,12 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
/**
 * Takes the last entry of `taskSets`, checks it against `stageId` and `attemptIdx` via
 * `checkStageId`, and then completes every task of that task set with a `FetchFailed` result
 * that carries `shuffleDep`'s shuffle id and the task's index.
 *
 * NOTE(review): the end of the parameter list and the `FetchFailed` tuple each appear twice
 * below (once with the literal "hostA", once with `srcHost`). This looks like both sides of
 * a diff; confirm which variant is intended before relying on this text.
 */
private def completeNextStageWithFetchFailure(
stageId: Int,
attemptIdx: Int,
shuffleDep: ShuffleDependency[_, _, _]): Unit = {
shuffleDep: ShuffleDependency[_, _, _],
srcHost: String = "hostA"): Unit = {
val stageAttempt = taskSets.last
checkStageId(stageId, attemptIdx, stageAttempt)
complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map { case (task, idx) =>
(FetchFailed(makeBlockManagerId("hostA"), shuffleDep.shuffleId, 0L, 0, idx, "ignored"), null)
(FetchFailed(makeBlockManagerId(srcHost), shuffleDep.shuffleId, 0L, 0, idx, "ignored"), null)
}.toSeq)
}

Expand Down Expand Up @@ -2251,6 +2256,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
assert(completedStage === List(0, 1))

Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
dagEventProcessLoopTester.runEvents()
// map stage resubmitted
assert(scheduler.runningStages.size === 1)
val mapStage = scheduler.runningStages.head
Expand Down Expand Up @@ -2286,6 +2292,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
sc.listenerBus.waitUntilEmpty()

Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
dagEventProcessLoopTester.runEvents()
// map stage is running by resubmitted, result stage is waiting
// map tasks and the origin result task 1.0 are running
assert(scheduler.runningStages.size == 1, "Map stage should be running")
Expand Down Expand Up @@ -3125,6 +3132,95 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
assert(countSubmittedMapStageAttempts() === 2)
}

/**
 * Submits a job with the following dependency graph and completes both shuffle map stages
 * successfully:
 *
 *   (determinate)      (indeterminate)
 *   shuffleMapRdd0     shuffleMapRdd1
 *              \         /
 *               \       /
 *               finalRdd
 *
 * Both shuffle map RDDs have 2 partitions, and their stages are completed with the hosts
 * "hostA" and "hostB".
 *
 * @return the shuffle dependencies on (shuffleMapRdd0, shuffleMapRdd1), in that order
 */
def constructMixedDeterminateDependencies():
    (ShuffleDependency[_, _, _], ShuffleDependency[_, _, _]) = {
  val numPartitions = 2

  val shuffleMapRdd0 = new MyRDD(sc, numPartitions, Nil, indeterminate = false)
  val shuffleDep0 = new ShuffleDependency(shuffleMapRdd0, new HashPartitioner(numPartitions))

  val shuffleMapRdd1 =
    new MyRDD(sc, numPartitions, Nil, tracker = mapOutputTracker, indeterminate = true)
  val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(numPartitions))

  val finalRdd =
    new MyRDD(sc, numPartitions, List(shuffleDep0, shuffleDep1), tracker = mapOutputTracker)

  submit(finalRdd, Array(0, 1))
  // Submitting the job must have created a map stage for the determinate shuffle.
  assert(scheduler.shuffleIdToMapStage.contains(shuffleDep0.shuffleId))

  // Finish both shuffle map stages.
  completeShuffleMapStageSuccessfully(0, 0, numPartitions, Seq("hostA", "hostB"))
  completeShuffleMapStageSuccessfully(1, 0, numPartitions, Seq("hostA", "hostB"))
  // Look the outputs up by the real shuffle ids instead of assuming they are 0 and 1.
  assert(mapOutputTracker.findMissingPartitions(shuffleDep0.shuffleId) === Some(Seq.empty))
  assert(mapOutputTracker.findMissingPartitions(shuffleDep1.shuffleId) === Some(Seq.empty))

  (shuffleDep0, shuffleDep1)
}

// Scenario: the result stage has not produced any output yet when a fetch failure arrives, so
// both shuffle map stages are re-run and the result stage is then completed without a
// recorded failure reason.
test("SPARK-51272: re-submit of an indeterminate stage without partial result can succeed") {
  // Only the dependency on the determinate shuffle map stage is needed here.
  val determinateShuffleDep = constructMixedDeterminateDependencies()._1
  val resultStage = scheduler.stageIdToStage(2).asInstanceOf[ResultStage]

  // the fetch failure is from the determinate shuffle map stage but this leads to
  // executor lost and removing the shuffle files generated by the indeterminate stage too
  completeNextStageWithFetchFailure(resultStage.id, 0, determinateShuffleDep, "hostA")

  // wait for the delayed resubmission to be posted, then process it on this thread
  Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
  dagEventProcessLoopTester.runEvents()
  assert(scheduler.runningStages.size === 2)
  assert(scheduler.runningStages.forall(_.isInstanceOf[ShuffleMapStage]))

  completeShuffleMapStageSuccessfully(0, 1, 2, Seq("hostA", "hostB"))
  completeShuffleMapStageSuccessfully(1, 1, 2, Seq("hostA", "hostB"))
  assert(scheduler.runningStages.size === 1)
  assert(scheduler.runningStages.head === resultStage)
  assert(resultStage.latestInfo.failureReason.isEmpty)

  completeNextResultStageWithSuccess(resultStage.id, 1)
}

// Scenario: one result task has already succeeded when a fetch failure arrives, so the result
// stage is aborted with the "cannot rollback" reason and no job stays active.
test("SPARK-51272: re-submit of an indeterminate stage with partial result will fail") {
  // Only the dependency on the determinate shuffle map stage is needed here.
  val shuffleDep0 = constructMixedDeterminateDependencies()._1
  val resultStage = scheduler.stageIdToStage(2).asInstanceOf[ResultStage]

  runEvent(makeCompletionEvent(taskSets(2).tasks(0), Success, 42))
  // the fetch failure is from the determinate shuffle map stage but this leads to
  // executor lost and removing the shuffle files generated by the indeterminate stage too
  runEvent(makeCompletionEvent(
    taskSets(2).tasks(1),
    FetchFailed(makeBlockManagerId("hostA"), shuffleDep0.shuffleId, 0L, 0, 0, "ignored"),
    null))

  dagEventProcessLoopTester.runEvents()
  // resubmission has not yet happened, so job is still running
  assert(scheduler.activeJobs.nonEmpty)
  Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
  dagEventProcessLoopTester.runEvents()

  // all dependent jobs have been failed
  assert(scheduler.runningStages.size === 0)
  assert(scheduler.activeJobs.isEmpty, "Aborting the stage aborts the job as well.")
  assert(resultStage.latestInfo.failureReason.isDefined)
  assert(resultStage.latestInfo.failureReason.exists(
    _.contains("A shuffle map stage with indeterminate output was failed and retried. " +
      "However, Spark cannot rollback the ResultStage")))
}

private def constructIndeterminateStageFetchFailed(): (Int, Int) = {
val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true)

Expand Down Expand Up @@ -4884,6 +4980,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
// wait resubmit
sc.listenerBus.waitUntilEmpty()
Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
dagEventProcessLoopTester.runEvents()

// stage0 retry
val stage0Retry = taskSets.filter(_.stageId == 1)
Expand Down Expand Up @@ -4984,6 +5081,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti

// the stages will now get resubmitted due to the failure
Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
dagEventProcessLoopTester.runEvents()

// parent map stage resubmitted
assert(scheduler.runningStages.size === 1)
Expand All @@ -5003,6 +5101,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
result = MapStatus(BlockManagerId("hostF-exec1", "hostF", 12345),
Array.fill[Long](2)(2), mapTaskId = taskIdCount)))
Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
dagEventProcessLoopTester.runEvents()

// The retries should succeed
sc.listenerBus.waitUntilEmpty()
Expand All @@ -5012,6 +5111,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
// This will add 3 new stages.
submit(reduceRdd, Array(0, 1))
Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
dagEventProcessLoopTester.runEvents()

// Only the last stage needs to execute, and those tasks - so completed stages should not
// change.
Expand Down