Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1261,18 +1261,20 @@ class DAGScheduler(
s"has failed the maximum allowable number of " +
s"times: ${Stage.MAX_CONSECUTIVE_FETCH_FAILURES}. " +
s"Most recent failure reason: ${failureMessage}", None)
} else if (failedStages.isEmpty) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled.
// TODO: Cancel running tasks in the stage
logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
s"$failedStage (${failedStage.name}) due to fetch failure")
messageScheduler.schedule(new Runnable {
override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
}, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
} else {
if (failedStages.isEmpty) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled.
// TODO: Cancel running tasks in the stage
logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
s"$failedStage (${failedStage.name}) due to fetch failure")
messageScheduler.schedule(new Runnable {
override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
}, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
}
failedStages += failedStage
failedStages += mapStage
}
failedStages += failedStage
failedStages += mapStage
// Mark the map whose fetch failed as broken in the map stage
if (mapId != -1) {
mapStage.removeOutputLoc(mapId, bmAddress)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.scheduler

import java.util.Properties
import java.util.concurrent.atomic.AtomicBoolean

import scala.annotation.meta.param
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
Expand All @@ -31,7 +32,7 @@ import org.apache.spark._
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.shuffle.MetadataFetchFailedException
import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException}
import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, LongAccumulator, Utils}

Expand Down Expand Up @@ -2105,6 +2106,61 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
assert(scheduler.getShuffleDependencies(rddE) === Set(shuffleDepA, shuffleDepC))
}

test("SPARK-17644: After one stage is aborted for too many failed attempts, subsequent stages" +
"still behave correctly on fetch failures") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the JIRA number here? That helps with tracking tests in the future. So something like "[SPARK-17644] After one stage is aborted..."

// Runs a job that always encounters a fetch failure, so should eventually be aborted
def runJobWithPersistentFetchFailure: Unit = {
val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).map(x => (x, 1)).groupByKey()
val shuffleHandle =
rdd1.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle
rdd1.map {
case (x, _) if (x == 1) =>
throw new FetchFailedException(
BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test")
case (x, _) => x
}.count()
}

// Runs a job that encounters a single fetch failure but succeeds on the second attempt
def runJobWithTemporaryFetchFailure: Unit = {
object FailThisAttempt {
val _fail = new AtomicBoolean(true)
}
val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).map(x => (x, 1)).groupByKey()
val shuffleHandle =
rdd1.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle
rdd1.map {
case (x, _) if (x == 1) && FailThisAttempt._fail.getAndSet(false) =>
throw new FetchFailedException(
BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test")
}
}

failAfter(10.seconds) {
val e = intercept[SparkException] {
runJobWithPersistentFetchFailure
}
assert(e.getMessage.contains("org.apache.spark.shuffle.FetchFailedException"))
}

// Run a second job that will fail due to a fetch failure.
// This job will hang without the fix for SPARK-17644.
failAfter(10.seconds) {
val e = intercept[SparkException] {
runJobWithPersistentFetchFailure
}
assert(e.getMessage.contains("org.apache.spark.shuffle.FetchFailedException"))
}

failAfter(10.seconds) {
try {
runJobWithTemporaryFetchFailure
} catch {
case e: Throwable => fail("A job with one fetch failure should eventually succeed")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you don't need to specifically catch an exception and call fail, the test will automatically fail from an unhandled exception

}
}
}

/**
* Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
* Note that this checks only the host and not the executor ID.
Expand Down