Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ private[spark] class TaskSchedulerImpl(
// on this class.
private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]

// Protected by `this`
private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
val taskIdToExecutorId = new HashMap[Long, String]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM
* all tasks.
*/
def badHostBackend(): Unit = {
val task = backend.beginTask()
val host = backend.executorIdToExecutor(task.executorId).host
val (taskDescription, _) = backend.beginTask()
val host = backend.executorIdToExecutor(taskDescription.executorId).host
if (host == badHost) {
backend.taskFailed(task, new RuntimeException("I'm a bad host!"))
backend.taskFailed(taskDescription, new RuntimeException("I'm a bad host!"))
} else {
backend.taskSuccess(task, 42)
backend.taskSuccess(taskDescription, 42)
}
}

Expand All @@ -48,7 +48,6 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM
val duration = Duration(1, SECONDS)
Await.ready(jobFuture, duration)
}
assert(results.isEmpty)
assertDataStructuresEmpty(noFailure = false)
}

Expand All @@ -68,7 +67,6 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM
val duration = Duration(3, SECONDS)
Await.ready(jobFuture, duration)
}
assert(results.isEmpty)
assertDataStructuresEmpty(noFailure = false)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,26 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa
}
}

/**
* A map from partition -> results for all tasks of a job when you call this test framework's
* [[submit]] method. Two important considerations:
*
* 1. If there is a job failure, results may or may not be empty. If any tasks succeed before
* the job has failed, they will get included in `results`. Instead, check for job failure by
* checking [[failure]]. (Also see [[assertDataStructuresEmpty()]])
*
* 2. This only gets cleared between tests. So you'll need to do special handling if you submit
* more than one job in one test.
*/
val results = new HashMap[Int, Any]()

/**
* If a call to [[submit]] results in a job failure, this will hold the exception, else it will
* be null.
*
* As with [[results]], this only gets cleared between tests, so care must be taken if you are
* submitting more than one job in one test.
*/
var failure: Throwable = _

/**
Expand All @@ -113,6 +132,11 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa
}
}

/**
* Helper to run a few common asserts after a job has completed, in particular checks on some
* internal data structures used for bookkeeping. This only does a very minimal check for whether
* the job failed or succeeded -- often you will want extra asserts on [[results]] or [[failure]].
*/
protected def assertDataStructuresEmpty(noFailure: Boolean = true): Unit = {
if (noFailure) {
if (failure != null) {
Expand All @@ -133,6 +157,8 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa
// when the job succeeds
assert(taskScheduler.runningTaskSets.isEmpty)
assert(!backend.hasTasks)
} else {
assert(failure != null)
}
assert(scheduler.activeJobs.isEmpty)
}
Expand Down Expand Up @@ -217,10 +243,10 @@ private[spark] abstract class MockBackend(
* Test backends should call this to get a task that has been assigned to them by the scheduler.
* Each task should be responded to with either [[taskSuccess]] or [[taskFailed]].
*/
def beginTask(): TaskDescription = {
def beginTask(): (TaskDescription, Task[_]) = {
synchronized {
val toRun = assignedTasksWaitingToRun.remove(assignedTasksWaitingToRun.size - 1)
runningTasks += toRun
runningTasks += toRun._1.taskId
toRun
}
}
Expand Down Expand Up @@ -255,7 +281,7 @@ private[spark] abstract class MockBackend(
taskScheduler.statusUpdate(task.taskId, state, resultBytes)
if (TaskState.isFinished(state)) {
synchronized {
runningTasks -= task
runningTasks -= task.taskId
executorIdToExecutor(task.executorId).freeCores += taskScheduler.CPUS_PER_TASK
freeCores += taskScheduler.CPUS_PER_TASK
}
Expand All @@ -264,9 +290,9 @@ private[spark] abstract class MockBackend(
}

// protected by this
private val assignedTasksWaitingToRun = new ArrayBuffer[TaskDescription](10000)
private val assignedTasksWaitingToRun = new ArrayBuffer[(TaskDescription, Task[_])](10000)
// protected by this
private val runningTasks = ArrayBuffer[TaskDescription]()
private val runningTasks = HashSet[Long]()

def hasTasks: Boolean = synchronized {
assignedTasksWaitingToRun.nonEmpty || runningTasks.nonEmpty
Expand Down Expand Up @@ -307,10 +333,19 @@ private[spark] abstract class MockBackend(
*/
override def reviveOffers(): Unit = {
val offers: Seq[WorkerOffer] = generateOffers()
val newTasks = taskScheduler.resourceOffers(offers).flatten
val newTaskDescriptions = taskScheduler.resourceOffers(offers).flatten
// Look up the Task for each TaskDescription now, since that requires a lock on
// TaskSchedulerImpl. Doing it here prevents individual tests that need the Task from
// introducing a race by doing the lookup themselves.
val newTasks = taskScheduler.synchronized {
newTaskDescriptions.map { taskDescription =>
val taskSet = taskScheduler.taskIdToTaskSetManager(taskDescription.taskId).taskSet
val task = taskSet.tasks(taskDescription.index)
(taskDescription, task)
}
}
synchronized {
newTasks.foreach { task =>
executorIdToExecutor(task.executorId).freeCores -= taskScheduler.CPUS_PER_TASK
newTasks.foreach { case (taskDescription, _) =>
executorIdToExecutor(taskDescription.executorId).freeCores -= taskScheduler.CPUS_PER_TASK
}
freeCores -= newTasks.size * taskScheduler.CPUS_PER_TASK
assignedTasksWaitingToRun ++= newTasks
Expand Down Expand Up @@ -437,8 +472,8 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor
*/
testScheduler("super simple job") {
def runBackend(): Unit = {
val task = backend.beginTask()
backend.taskSuccess(task, 42)
val (taskDescripition, _) = backend.beginTask()
backend.taskSuccess(taskDescripition, 42)
}
withBackend(runBackend _) {
val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray)
Expand Down Expand Up @@ -473,9 +508,7 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor
val d = join(30, b, c)

def runBackend(): Unit = {
val taskDescription = backend.beginTask()
val taskSet = taskScheduler.taskIdToTaskSetManager(taskDescription.taskId).taskSet
val task = taskSet.tasks(taskDescription.index)
val (taskDescription, task) = backend.beginTask()

// make sure the required map output is available
task.stageId match {
Expand Down Expand Up @@ -515,9 +548,7 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor
val stageToAttempts = new HashMap[Int, HashSet[Int]]()

def runBackend(): Unit = {
val taskDescription = backend.beginTask()
val taskSet = taskScheduler.taskIdToTaskSetManager(taskDescription.taskId).taskSet
val task = taskSet.tasks(taskDescription.index)
val (taskDescription, task) = backend.beginTask()
stageToAttempts.getOrElseUpdate(task.stageId, new HashSet()) += task.stageAttemptId

// make sure the required map output is available
Expand Down Expand Up @@ -549,16 +580,15 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor

testScheduler("job failure after 4 attempts") {
def runBackend(): Unit = {
val task = backend.beginTask()
backend.taskFailed(task, new RuntimeException("test task failure"))
val (taskDescription, _) = backend.beginTask()
backend.taskFailed(taskDescription, new RuntimeException("test task failure"))
}
// Run the job against the backend defined above, which reports every task as failed, and
// verify that the job failure carries the task's error message.
withBackend(runBackend _) {
  val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray)
  val duration = Duration(1, SECONDS)
  // Only wait for the job to terminate here; the outcome is inspected through `failure` below.
  Await.ready(jobFuture, duration)
  // The bare `contains(...)` expression was evaluated and its result discarded, so it verified
  // nothing. Wrap it in an assert so a wrong failure message actually fails the test.
  assert(failure.getMessage.contains("test task failure"))
}
assert(results.isEmpty)
assertDataStructuresEmpty(noFailure = false)
}
}