core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala (21 changes: 11 additions & 10 deletions)
@@ -198,7 +198,7 @@ private[spark] class TaskSetManager(
   private[scheduler] var emittedTaskSizeWarning = false

   /** Add a task to all the pending-task lists that it should be on. */
-  private def addPendingTask(index: Int) {
+  private[spark] def addPendingTask(index: Int) {
     for (loc <- tasks(index).preferredLocations) {
       loc match {
         case e: ExecutorCacheTaskLocation =>
@@ -832,15 +832,6 @@ private[spark] class TaskSetManager(

     sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info)

-    if (successful(index)) {
-      logInfo(s"Task ${info.id} in stage ${taskSet.id} (TID $tid) failed, but the task will not" +
-        s" be re-executed (either because the task failed with a shuffle data fetch failure," +
-        s" so the previous stage needs to be re-run, or because a different copy of the task" +
-        s" has already succeeded).")
-    } else {
-      addPendingTask(index)
-    }
-
     if (!isZombie && reason.countTowardsTaskFailures) {
       taskSetBlacklistHelperOpt.foreach(_.updateBlacklistForFailedTask(
         info.host, info.executorId, index))
@@ -854,6 +845,16 @@ private[spark] class TaskSetManager(
         return
       }
     }
+
+    if (successful(index)) {
+      logInfo(s"Task ${info.id} in stage ${taskSet.id} (TID $tid) failed, but the task will not" +
+        s" be re-executed (either because the task failed with a shuffle data fetch failure," +
+        s" so the previous stage needs to be re-run, or because a different copy of the task" +
+        s" has already succeeded).")
+    } else {
+      addPendingTask(index)
+    }
+
     maybeFinishTaskSet()
   }
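In short, the patch moves the re-queue of a failed task so that it happens only after the per-task-set blacklist has been updated for that failure. A minimal sketch of the resulting ordering, reusing the names from the diff inside a hypothetical simplified wrapper (the real handleFailedTask does far more bookkeeping, including the isZombie and countTowardsTaskFailures checks):

    // Sketch only: simplified ordering of handleFailedTask after this patch.
    private def handleFailedTaskSketch(index: Int, info: TaskInfo): Unit = {
      // 1. Record the failure in the blacklist first, so the failing
      //    executor/host is excluded before the task can be rescheduled.
      taskSetBlacklistHelperOpt.foreach(_.updateBlacklistForFailedTask(
        info.host, info.executorId, index))

      // 2. Only then re-add the task to the pending lists (skipped when another
      //    copy already succeeded), closing the race in which the retry could
      //    be offered back to the very executor it just failed on.
      if (!successful(index)) {
        addPendingTask(index)
      }
    }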
core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer

 import org.mockito.Matchers.{any, anyInt, anyString}
-import org.mockito.Mockito.{mock, never, spy, verify, when}
+import org.mockito.Mockito.{mock, never, spy, times, verify, when}
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
@@ -1172,6 +1172,50 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logging {
     assert(blacklistTracker.isNodeBlacklisted("host1"))
   }

+  test("update blacklist before adding pending task to avoid race condition") {
+    // When a task fails, the blacklist policy should be applied before the task
+    // is retried; otherwise there is a race condition in which the retry can run
+    // on the same executor that it was meant to be blacklisted from.
+    val conf = new SparkConf().
+      set(config.BLACKLIST_ENABLED, true).
+      set(config.MAX_TASK_ATTEMPTS_PER_EXECUTOR, 1)
Contributor: The default value of config.MAX_TASK_ATTEMPTS_PER_EXECUTOR is 1, so we don't have to set it here.

Author: Yes, I added it to make the test's configuration inputs more explicit, but I can remove it if it's a default that is unlikely to change.
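If the explicit setting were dropped, the test would rely on that documented default; a minimal sketch of the equivalent configuration, assuming the default of 1 does not change across versions:

    // Equivalent configuration relying on the default (assumption: the default
    // value of MAX_TASK_ATTEMPTS_PER_EXECUTOR remains 1).
    val conf = new SparkConf().set(config.BLACKLIST_ENABLED, true)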


+    // Create a scheduler with two executors and a task set with a single task.
+    sc = new SparkContext("local", "test", conf)
+    val exec = "executor1"
+    val host = "host1"
+    val exec2 = "executor2"
+    val host2 = "host2"
+    sched = new FakeTaskScheduler(sc, (exec, host), (exec2, host2))
+    val taskSet = FakeTask.createTaskSet(1)
+
+    val clock = new ManualClock
+    val mockListenerBus = mock(classOf[LiveListenerBus])
+    val blacklistTracker = new BlacklistTracker(mockListenerBus, conf, None, clock)
+    val taskSetManager = new TaskSetManager(sched, taskSet, 1, Some(blacklistTracker))
Contributor: Why are we using SystemClock for taskSetManager?

Author: It seems all the tests in this file use ManualClock, so I was following convention here. This test doesn't validate anything that specifically depends on the clock/time.
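For reference, if the manager itself ever needed to run on the manual clock as well, the construction above could pass it through explicitly; a minimal sketch, assuming TaskSetManager's trailing optional clock parameter in this version:

    // Hypothetical alternative to the construction above: share the ManualClock
    // with the TaskSetManager via its optional clock argument, so the manager
    // and the BlacklistTracker advance together.
    val taskSetManager = new TaskSetManager(sched, taskSet, 1, Some(blacklistTracker), clock)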

+    val taskSetManagerSpy = spy(taskSetManager)
+
+    val taskDesc = taskSetManagerSpy.resourceOffer(exec, host, TaskLocality.ANY)
+
+    // Assert that, by the time the task is re-added to the pending list, it has
+    // been blacklisted on the executor it was last executed on.
+    when(taskSetManagerSpy.addPendingTask(anyInt())).thenAnswer(
+      new Answer[Unit] {
+        override def answer(invocationOnMock: InvocationOnMock): Unit = {
+          val task = invocationOnMock.getArgumentAt(0, classOf[Int])
+          assert(taskSetManager.taskSetBlacklistHelperOpt.get.
+            isExecutorBlacklistedForTask(exec, task))
+        }
+      }
+    )
+
+    // Simulate an out-of-memory error.
+    val e = new OutOfMemoryError
+    taskSetManagerSpy.handleFailedTask(
+      taskDesc.get.taskId, TaskState.FAILED, new ExceptionFailure(e, Seq()))
Contributor: nit: ExceptionFailure is a case class, so you may use:

  val endReason = ExceptionFailure("a", "b", Array(), "c", None)
  taskSetManagerSpy.handleFailedTask(taskDesc.get.taskId, TaskState.FAILED, endReason)

Author: Okay


+    verify(taskSetManagerSpy, times(1)).addPendingTask(anyInt())
+  }

private def createTaskResult(
id: Int,
accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = {