137 changes: 83 additions & 54 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1552,6 +1552,20 @@ private[spark] class DAGScheduler(
// `findMissingPartitions()` returns all partitions every time.
stage match {
case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
// already executed at least once
if (sms.getNextAttemptId > 0) {
val stagesToRollback = collectSucceedingStages(sms)
rollBackStages(stagesToRollback)
// Stages which cannot be rolled back were aborted, which leads to removing the
// dependent job(s) from the active jobs set.
val numActiveJobsWithStageAfterRollback =
activeJobs.count(job => stagesToRollback.contains(job.finalStage))
if (numActiveJobsWithStageAfterRollback == 0) {
logInfo("All jobs depending on this indeterminate stage (" + stage + ") were " +
"aborted so this stage is not needed anymore.")
return
}
}
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
sms.shuffleDep.newShuffleMergeState()
case _ =>
@@ -2129,60 +2143,8 @@ private[spark] class DAGScheduler(
// guaranteed to be determinate, so the input data of the reducers will not change
// even if the map tasks are re-tried.
if (mapStage.isIndeterminate) {
// It's a little tricky to find all the succeeding stages of `mapStage`, because
// each stage only know its parents not children. Here we traverse the stages from
// the leaf nodes (the result stages of active jobs), and rollback all the stages
// in the stage chains that connect to the `mapStage`. To speed up the stage
// traversing, we collect the stages to rollback first. If a stage needs to
// rollback, all its succeeding stages need to rollback to.
val stagesToRollback = HashSet[Stage](mapStage)

def collectStagesToRollback(stageChain: List[Stage]): Unit = {
if (stagesToRollback.contains(stageChain.head)) {
stageChain.drop(1).foreach(s => stagesToRollback += s)
} else {
stageChain.head.parents.foreach { s =>
collectStagesToRollback(s :: stageChain)
}
}
}

def generateErrorMessage(stage: Stage): String = {
"A shuffle map stage with indeterminate output was failed and retried. " +
s"However, Spark cannot rollback the $stage to re-process the input data, " +
"and has to fail this job. Please eliminate the indeterminacy by " +
"checkpointing the RDD before repartition and try again."
}

activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil))

// The stages will be rolled back after checking
val rollingBackStages = HashSet[Stage](mapStage)
stagesToRollback.foreach {
case mapStage: ShuffleMapStage =>
val numMissingPartitions = mapStage.findMissingPartitions().length
if (numMissingPartitions < mapStage.numTasks) {
if (sc.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) {
val reason = "A shuffle map stage with indeterminate output was failed " +
"and retried. However, Spark can only do this while using the new " +
"shuffle block fetching protocol. Please check the config " +
"'spark.shuffle.useOldFetchProtocol', see more detail in " +
"SPARK-27665 and SPARK-25341."
abortStage(mapStage, reason, None)
} else {
rollingBackStages += mapStage
}
}

case resultStage: ResultStage if resultStage.activeJob.isDefined =>
val numMissingPartitions = resultStage.findMissingPartitions().length
if (numMissingPartitions < resultStage.numTasks) {
// TODO: support to rollback result tasks.
abortStage(resultStage, generateErrorMessage(resultStage), None)
}

case _ =>
}
val stagesToRollback = collectSucceedingStages(mapStage)
val rollingBackStages = rollBackStages(stagesToRollback)
logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output failed; " +
log"we will roll back and rerun the following stages, which include itself and all of its " +
log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
@@ -2346,6 +2308,73 @@ private[spark] class DAGScheduler(
}
}

private def collectSucceedingStages(mapStage: ShuffleMapStage): HashSet[Stage] = {
// TODO: perhaps materialize this if we are going to compute it often enough?
// It's a little tricky to find all the succeeding stages of `mapStage`, because
// each stage only knows its parents, not its children. Here we traverse the stages
// from the leaf nodes (the result stages of active jobs) and collect every stage
// in the stage chains that connect to `mapStage`: if a stage succeeds `mapStage`,
// then all of its succeeding stages succeed it too.
val succeedingStages = HashSet[Stage](mapStage)

def collectSucceedingStagesInternal(stageChain: List[Stage]): Unit = {
if (succeedingStages.contains(stageChain.head)) {
stageChain.drop(1).foreach(s => succeedingStages += s)
} else {
stageChain.head.parents.foreach { s =>
collectSucceedingStagesInternal(s :: stageChain)
}
}
}
activeJobs.foreach(job => collectSucceedingStagesInternal(job.finalStage :: Nil))
succeedingStages
}
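
// Illustrative sketch (not part of this PR): a self-contained model of the traversal
// implemented by collectSucceedingStages above. `SimpleStage` and `collectSucceeding`
// are hypothetical names used only for this example; the real scheduler works on
// `Stage` objects that likewise only know their parents. Starting from the leaves
// (the final stages of active jobs), we walk each chain of parents; once a chain
// reaches the target stage, every stage collected along that chain is downstream of
// the target and is added to the result set.
object SucceedingStagesSketch {
  import scala.collection.mutable.HashSet

  final case class SimpleStage(id: Int, parents: List[SimpleStage])

  def collectSucceeding(target: SimpleStage, leaves: Seq[SimpleStage]): HashSet[SimpleStage] = {
    val succeeding = HashSet[SimpleStage](target)
    def walk(chain: List[SimpleStage]): Unit = {
      if (succeeding.contains(chain.head)) {
        // Everything below the head of this chain succeeds the target.
        chain.drop(1).foreach(succeeding += _)
      } else {
        chain.head.parents.foreach(p => walk(p :: chain))
      }
    }
    leaves.foreach(leaf => walk(leaf :: Nil))
    succeeding
  }

  def main(args: Array[String]): Unit = {
    // DAG: s0 -> s1 -> s3 (result stage) and s2 -> s3; only s0's descendants are collected.
    val s0 = SimpleStage(0, Nil)
    val s1 = SimpleStage(1, List(s0))
    val s2 = SimpleStage(2, Nil)
    val s3 = SimpleStage(3, List(s1, s2))
    println(collectSucceeding(s0, Seq(s3)).map(_.id).toList.sorted) // List(0, 1, 3)
  }
}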

/**
 * Rolls back the given stages where possible: shuffle map stages that can be rolled back are
 * collected and returned, while stages that cannot be rolled back are aborted together with
 * the jobs that depend on them.
 *
 * @param stagesToRollback stages to roll back
 * @return the shuffle map stages which need to be, and can be, rolled back
 */
private def rollBackStages(stagesToRollback: HashSet[Stage]): HashSet[Stage] = {

def generateErrorMessage(stage: Stage): String = {
"A shuffle map stage with indeterminate output was failed and retried. " +
s"However, Spark cannot rollback the $stage to re-process the input data, " +
"and has to fail this job. Please eliminate the indeterminacy by " +
"checkpointing the RDD before repartition and try again."
}

// The stages will be rolled back after checking
val rollingBackStages = HashSet[Stage]()
stagesToRollback.foreach {
case mapStage: ShuffleMapStage =>
val numMissingPartitions = mapStage.findMissingPartitions().length
if (numMissingPartitions < mapStage.numTasks) {
if (sc.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) {
val reason = "A shuffle map stage with indeterminate output was failed " +
"and retried. However, Spark can only do this while using the new " +
"shuffle block fetching protocol. Please check the config " +
"'spark.shuffle.useOldFetchProtocol', see more detail in " +
"SPARK-27665 and SPARK-25341."
abortStage(mapStage, reason, None)
} else {
rollingBackStages += mapStage
}
}

case resultStage: ResultStage if resultStage.activeJob.isDefined =>
val numMissingPartitions = resultStage.findMissingPartitions().length
if (numMissingPartitions < resultStage.numTasks) {
// TODO: support to rollback result tasks.
abortStage(resultStage, generateErrorMessage(resultStage), None)
}

case _ =>
}

rollingBackStages
}
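
// Illustrative sketch (not part of this PR): the per-stage decision that rollBackStages
// makes above, modeled as a standalone function. `decide`, `Outcome` and the parameter
// names are hypothetical and exist only for this example; "partially complete" stands for
// findMissingPartitions().length < numTasks in the real code.
object RollbackDecisionSketch {
  sealed trait Outcome
  case object RollBack extends Outcome    // rerun the shuffle map stage from scratch
  case object Abort extends Outcome       // cannot roll back safely; abort the stage and its jobs
  case object NothingToDo extends Outcome // no output produced yet, a plain retry is enough

  def decide(
      isResultStage: Boolean,
      partiallyComplete: Boolean,
      usesOldFetchProtocol: Boolean): Outcome = {
    if (!partiallyComplete) NothingToDo
    else if (isResultStage) Abort          // rolling back result tasks is not supported yet
    else if (usesOldFetchProtocol) Abort   // rollback needs the new fetch protocol (SPARK-27665)
    else RollBack
  }

  def main(args: Array[String]): Unit = {
    println(decide(isResultStage = false, partiallyComplete = true, usesOldFetchProtocol = false)) // RollBack
    println(decide(isResultStage = true, partiallyComplete = true, usesOldFetchProtocol = false))  // Abort
    println(decide(isResultStage = false, partiallyComplete = false, usesOldFetchProtocol = true)) // NothingToDo
  }
}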

/**
* Whether executor is decommissioning or decommissioned.
* Return true when: