4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -552,7 +552,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli

_cleaner =
if (_conf.getBoolean("spark.cleaner.referenceTracking", true)) {
Some(new ContextCleaner(this))
val cleaner = new ContextCleaner(this)
cleaner.attachListener(dagScheduler)
Some(cleaner)
} else {
None
}
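For context, a rough sketch of the hookup this hunk relies on: ContextCleaner.attachListener registers a CleanerListener whose callbacks fire as references are garbage-collected. The trait shape below is inferred from the overrides this PR adds to DAGScheduler further down, so treat it as an approximation rather than a verbatim copy of org.apache.spark.CleanerListener.

    // Approximate listener contract (method names taken from the overrides below):
    private[spark] trait CleanerListener {
      def rddCleaned(rddId: Int): Unit
      def shuffleCleaned(shuffleId: Int): Unit
      def broadcastCleaned(broadcastId: Long): Unit
      def accumCleaned(accId: Long): Unit
      def checkpointCleaned(rddId: Long): Unit
    }

With the change above, the DAGScheduler is attached as one of these listeners at SparkContext creation, so it is notified whenever the cleaner removes a shuffle.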
29 changes: 22 additions & 7 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -68,7 +68,7 @@ class DAGScheduler(
blockManagerMaster: BlockManagerMaster,
env: SparkEnv,
clock: Clock = new SystemClock())
extends Logging {
extends Logging with CleanerListener {

def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
this(
@@ -483,9 +483,6 @@ class DAGScheduler(
logDebug("Removing running stage %d".format(stageId))
runningStages -= stage
}
for ((k, v) <- shuffleToMapStage.find(_._2 == stage)) {
shuffleToMapStage.remove(k)
}
if (waitingStages.contains(stage)) {
logDebug("Removing stage %d from waiting set.".format(stageId))
waitingStages -= stage
@@ -496,9 +493,12 @@
}
}
// data structures based on StageId
stageIdToStage -= stageId
logDebug("After removal of stage %d, remaining stages = %d"
.format(stageId, stageIdToStage.size))
// ShuffleMapStages aren't removed until the shuffle is cleaned
if (stage.isInstanceOf[ResultStage]) {
stageIdToStage -= stageId
logDebug("After removal of stage %d, remaining stages = %d"
.format(stageId, stageIdToStage.size))
}
}

jobSet -= job.jobId
@@ -1436,6 +1436,21 @@ class DAGScheduler(
taskScheduler.stop()
}

/**
* Called by the context cleaner when a shuffle is removed.
* @param shuffleId the id of the cleaned shuffle, whose map stage bookkeeping can now be dropped
*/
override def shuffleCleaned(shuffleId: Int): Unit = {
val stageOpt = shuffleToMapStage.remove(shuffleId)
stageOpt.foreach { stage => stageIdToStage -= stage.id }
}

// These are also called by the context cleaner, but the DAGScheduler has nothing to clean up for them
override def accumCleaned(accId: Long): Unit = {}
override def broadcastCleaned(broadcastId: Long): Unit = {}
override def checkpointCleaned(rddId: Long): Unit = {}
override def rddCleaned(rddId: Int): Unit = {}

// Start the event thread and register the metrics source at the end of the constructor
env.metricsSystem.registerSource(metricsSource)
eventProcessLoop.start()
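To see how shuffleCleaned is reached, here is a hedged sketch of the cleaner side of the path. It assumes the usual structure of ContextCleaner.doCleanupShuffle; everything except the listener dispatch is paraphrased from the Spark 1.x source, so treat the details as approximate.

    // Paraphrased sketch of ContextCleaner.doCleanupShuffle, not verbatim:
    def doCleanupShuffle(shuffleId: Int, blocking: Boolean): Unit = {
      logDebug("Cleaning shuffle " + shuffleId)
      mapOutputTrackerMaster.unregisterShuffle(shuffleId)    // drop map output metadata
      blockManagerMaster.removeShuffle(shuffleId, blocking)  // drop shuffle blocks
      // The dispatch this PR relies on: each attached listener (now including the
      // DAGScheduler) is told the shuffle is gone, so the corresponding
      // shuffleToMapStage and stageIdToStage entries can finally be pruned.
      listeners.foreach(_.shuffleCleaned(shuffleId))
      logInfo("Cleaned shuffle " + shuffleId)
    }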
209 changes: 206 additions & 3 deletions core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -17,6 +17,7 @@

package org.apache.spark.scheduler

import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map}
import scala.language.reflectiveCalls
import scala.util.control.NonFatal
@@ -163,14 +164,17 @@ class DAGSchedulerSuite
cancelledStages.clear()
cacheLocations.clear()
results.clear()
mapOutputTracker = new MapOutputTrackerMaster(conf)
mapOutputTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
scheduler = new DAGScheduler(
sc,
taskScheduler,
sc.listenerBus,
mapOutputTracker,
blockManagerMaster,
sc.env)
// This is normally done during SparkContext creation, but since we are ignoring the
// scheduler in the SparkContext and creating our own, we need to re-register here.
sc.cleaner.foreach { _.attachListener(scheduler) }
dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler)
}

@@ -695,6 +699,199 @@ class DAGSchedulerSuite
assertDataStructuresEmpty()
}


// Helper function to validate state when creating tests for task failures
def checkStageId(stageId: Int, attempt: Int, stageAttempt: TaskSet): Unit = {
assert(stageAttempt.stageId === stageId,
s"expected stage $stageId, instead was ${stageAttempt.stageId}")
assert(stageAttempt.stageAttemptId === attempt,
s"expected stage attempt $attempt, instead was ${stageAttempt.stageAttemptId}")
}

def makeCompletions(stageAttempt: TaskSet, reduceParts: Int): Seq[(Success.type, MapStatus)] = {
stageAttempt.tasks.zipWithIndex.map { case (_, idx) =>
(Success, makeMapStatus("host" + ('A' + idx).toChar, reduceParts))
}.toSeq
}

def setupStageAbortTest(sc: SparkContext): Unit = {
sc.listenerBus.addListener(new EndListener())
ended = false
jobResult = null
}

// Create a new Listener to confirm that the listenerBus sees the JobEnd message
// when we abort the stage. This message will also be consumed by the EventLoggingListener
// so this will propagate up to the user.
var ended = false
var jobResult: JobResult = null

class EndListener extends SparkListener {
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
jobResult = jobEnd.jobResult
ended = true
}
}

// Helper functions that factor out code shared by the fetch-failure test cases
/**
* Common code to get the next stage attempt, confirm it's the one we expect, and complete it
* successfully.
*
* @param stageId - The current stageId
* @param attemptIdx - The current attempt count
* @param numShufflePartitions - The number of partitions in the next stage
*/
def completeNextShuffleMapSuccessfully(
stageId: Int,
attemptIdx: Int,
numShufflePartitions: Int): Unit = {
val stageAttempt = taskSets.last
checkStageId(stageId, attemptIdx, stageAttempt)
complete(stageAttempt, makeCompletions(stageAttempt, numShufflePartitions))
}

/**
* Common code to get the next stage attempt, confirm it's the one we expect, and complete it
* with a FetchFailed on every task.
*
* @param stageId - The current stageId
* @param attemptIdx - The current attempt count
* @param shuffleDep - The shuffle dependency of the stage with a fetch failure
*/
def completeNextStageWithFetchFailure(
stageId: Int,
attemptIdx: Int,
shuffleDep: ShuffleDependency[_, _, _]): Unit = {
val stageAttempt = taskSets.last
checkStageId(stageId, attemptIdx, stageAttempt)

complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map { case (_, idx) =>
(FetchFailed(makeBlockManagerId("hostA"), shuffleDep.shuffleId, 0, idx, "ignored"), null)
}.toSeq)
}

/**
* Common code to get the next result stage attempt, confirm it's the one we expect, and
* complete it with a success, returning the value produced by resultFunc (42 by default).
*
* @param stageId - The current stageId
* @param attemptIdx - The current attempt count
* @param resultFunc - Maps a partition index to its result value (defaults to 42)
*/
def completeNextResultStageWithSuccess(
stageId: Int,
attemptIdx: Int,
resultFunc: Int => Int = _ => 42): Unit = {
val stageAttempt = taskSets.last
checkStageId(stageId, attemptIdx, stageAttempt)
assert(scheduler.stageIdToStage(stageId).isInstanceOf[ResultStage])
complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map { case (_, idx) =>
(Success, resultFunc(idx))
}.toSeq)
}



test("shuffle fetch failure in a reused shuffle dependency") {
// Run the first job successfully, which creates one shuffle dependency

val jobIdToStageIds = new mutable.HashMap[Int, Set[Int]]()
val listener = new SparkListener {
override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
jobIdToStageIds(jobStart.jobId) = jobStart.stageIds.toSet
}
}
sc.addSparkListener(listener)

val shuffleMapRdd = new MyRDD(sc, 2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
submit(reduceRdd, Array(0, 1))
sc.listenerBus.waitUntilEmpty(1000)
assert(jobIdToStageIds(0) === Set(0, 1))

completeNextShuffleMapSuccessfully(0, 0, 2)
completeNextResultStageWithSuccess(1, 0)
assert(results === Map(0 -> 42, 1 -> 42))
assertDataStructuresEmpty()

// submit another job w/ the shared dependency, and have a fetch failure
val reduce2 = new MyRDD(sc, 2, List(shuffleDep))
submit(reduce2, Array(0, 1))
sc.listenerBus.waitUntilEmpty(1000)
assert(jobIdToStageIds(1) === Set(0, 2))
completeNextStageWithFetchFailure(2, 0, shuffleDep)
scheduler.resubmitFailedStages()
completeNextShuffleMapSuccessfully(0, 1, 2)
completeNextResultStageWithSuccess(2, 1, idx => idx + 1234)
assert(results === Map(0 -> 1234, 1 -> 1235))

assertDataStructuresEmpty()
assertEmptyAfterContextCleaner()
}


test("reused dependency with long lineage") {
val jobIdToStageIds = new mutable.HashMap[Int, Set[Int]]()
val listener = new SparkListener {
override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
jobIdToStageIds(jobStart.jobId) = jobStart.stageIds.toSet
}
}
sc.addSparkListener(listener)

val shuffleMapRdd = new MyRDD(sc, 2, Nil)
val shuffleDep1 = new ShuffleDependency(shuffleMapRdd, null)
val reduceRdd = new MyRDD(sc, 4, List(shuffleDep1))
val shuffleDep2 = new ShuffleDependency(reduceRdd, null)
val reduceRdd2 = new MyRDD(sc, 6, List(shuffleDep2))
val shuffleDep3 = new ShuffleDependency(reduceRdd2, null)
val reduceRdd3 = new MyRDD(sc, 8, List(shuffleDep3))
submit(reduceRdd3, (0 until 8).toArray)
sc.listenerBus.waitUntilEmpty(1000)
assert(jobIdToStageIds(0) === Set(0, 1, 2, 3))

completeNextShuffleMapSuccessfully(0, 0, 4)
completeNextShuffleMapSuccessfully(1, 0, 6)
completeNextShuffleMapSuccessfully(2, 0, 8)
completeNextResultStageWithSuccess(3, 0)
assert(results === (0 until 8).map { _ -> 42 }.toMap)
results.clear()
assertDataStructuresEmpty()

// submit another job w/ the shared dependency, and have a fetch failure
val reduce4 = new MyRDD(sc, 2, List(shuffleDep3))
submit(reduce4, Array(0, 1))
sc.listenerBus.waitUntilEmpty(1000)
assert(jobIdToStageIds(1) === Set(0, 1, 2, 4))
completeNextStageWithFetchFailure(4, 0, shuffleDep3)
scheduler.resubmitFailedStages()
completeNextShuffleMapSuccessfully(0, 1, 4)
completeNextShuffleMapSuccessfully(1, 1, 6)
completeNextShuffleMapSuccessfully(2, 1, 8)
completeNextResultStageWithSuccess(4, 1, idx => idx + 1234)
assert(results === Map(0 -> 1234, 1 -> 1235))
results.clear()
assertDataStructuresEmpty()
assertEmptyAfterContextCleaner()

// Now try submitting again, after the shuffle data has been cleaned out. This should be
// fine; everything just needs to be rerun.

val reduce5 = new MyRDD(sc, 8, List(shuffleDep3))
submit(reduce5, (0 until 8).toArray)
sc.listenerBus.waitUntilEmpty(1000)
assert(jobIdToStageIds(2) === Set(5, 6, 7, 8)) // new stages this time
completeNextShuffleMapSuccessfully(5, 0, 4)
completeNextShuffleMapSuccessfully(6, 0, 6)
completeNextShuffleMapSuccessfully(7, 0, 8)
completeNextResultStageWithSuccess(8, 0, idx => idx + 4321)
assert(results === (0 until 8).map { idx => idx -> (idx + 4321) }.toMap)

assertDataStructuresEmpty()
assertEmptyAfterContextCleaner()
}

/**
* Makes sure that failures of stage used by multiple jobs are correctly handled.
*
@@ -1012,13 +1209,19 @@ class DAGSchedulerSuite
assert(scheduler.failedStages.isEmpty)
assert(scheduler.jobIdToActiveJob.isEmpty)
assert(scheduler.jobIdToStageIds.isEmpty)
assert(scheduler.stageIdToStage.isEmpty)
assert(scheduler.runningStages.isEmpty)
assert(scheduler.shuffleToMapStage.isEmpty)
assert(scheduler.waitingStages.isEmpty)
assert(scheduler.outputCommitCoordinator.isEmpty)
}

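// Simulates the ContextCleaner cleaning every shuffle that is still registered, then checks
// that the stage bookkeeping this PR keeps alive until cleanup is fully released.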
private def assertEmptyAfterContextCleaner(): Unit = {
scheduler.shuffleToMapStage.foreach { case (shuffleId, _) =>
sc.cleaner.get.doCleanupShuffle(shuffleId, blocking = true)
}
assert(scheduler.stageIdToStage.isEmpty)
assert(scheduler.shuffleToMapStage.isEmpty)
}

// Nothing in this test should break if the task info's fields are null, but
// OutputCommitCoordinator requires the task info itself to not be null.
private def createFakeTaskInfo(): TaskInfo = {