21 changes: 11 additions & 10 deletions core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -198,7 +198,7 @@ private[spark] class TaskSetManager(
private[scheduler] var emittedTaskSizeWarning = false

/** Add a task to all the pending-task lists that it should be on. */
private def addPendingTask(index: Int) {
private[spark] def addPendingTask(index: Int) {
for (loc <- tasks(index).preferredLocations) {
loc match {
case e: ExecutorCacheTaskLocation =>
@@ -832,15 +832,6 @@ private[spark] class TaskSetManager(

sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info)

if (successful(index)) {
logInfo(s"Task ${info.id} in stage ${taskSet.id} (TID $tid) failed, but the task will not" +
s" be re-executed (either because the task failed with a shuffle data fetch failure," +
s" so the previous stage needs to be re-run, or because a different copy of the task" +
s" has already succeeded).")
} else {
addPendingTask(index)
}

if (!isZombie && reason.countTowardsTaskFailures) {
taskSetBlacklistHelperOpt.foreach(_.updateBlacklistForFailedTask(
info.host, info.executorId, index))
@@ -854,6 +845,16 @@ private[spark] class TaskSetManager(
return
}
}

if (successful(index)) {
logInfo(s"Task ${info.id} in stage ${taskSet.id} (TID $tid) failed, but the task will not" +
s" be re-executed (either because the task failed with a shuffle data fetch failure," +
s" so the previous stage needs to be re-run, or because a different copy of the task" +
s" has already succeeded).")
} else {
addPendingTask(index)
}

maybeFinishTaskSet()
}
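The net effect of the two hunks above is that handleFailedTask now updates the blacklist before it puts the failed task back on the pending lists. Below is a self-contained toy model, not Spark code: every name in it is invented, and it only illustrates why that ordering matters, namely that a resource offer from the failing executor arriving between the two steps can no longer hand the task straight back to that executor.

import scala.collection.mutable

object BlacklistOrderingSketch {
  // taskIndex -> executors the task may no longer run on
  private val blacklistedExecs = mutable.Map.empty[Int, mutable.Set[String]]
  private val pendingTasks = mutable.Queue.empty[Int]

  private def blacklist(index: Int, execId: String): Unit =
    blacklistedExecs.getOrElseUpdate(index, mutable.Set.empty) += execId

  private def addPendingTask(index: Int): Unit =
    pendingTasks.enqueue(index)

  // An offer from execId takes the first pending task that is not blacklisted on it.
  private def resourceOffer(execId: String): Option[Int] =
    pendingTasks.dequeueFirst(i => !blacklistedExecs.getOrElse(i, Set.empty[String]).contains(execId))

  def main(args: Array[String]): Unit = {
    // Old ordering: task 0 fails on executor1 and is re-queued first; an offer
    // from executor1 that arrives before the blacklist update wins the race.
    addPendingTask(0)
    assert(resourceOffer("executor1").contains(0)) // lands back on the bad executor
    blacklist(0, "executor1")                      // the blacklist update comes too late

    // Patched ordering: blacklist first, then re-queue; the same interleaved offer loses.
    blacklistedExecs.clear()
    pendingTasks.clear()
    blacklist(0, "executor1")
    addPendingTask(0)
    assert(resourceOffer("executor1").isEmpty)     // executor1 can no longer take task 0
    assert(resourceOffer("executor2").contains(0)) // another executor picks it up
    println("blacklist-before-requeue ordering holds")
  }
}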

core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.mockito.Matchers.{any, anyInt, anyString}
import org.mockito.Mockito.{mock, never, spy, verify, when}
import org.mockito.Mockito.{mock, never, spy, times, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer

@@ -1172,6 +1172,48 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logging
assert(blacklistTracker.isNodeBlacklisted("host1"))
}

test("update blacklist before adding pending task to avoid race condition") {
// When a task fails, the blacklist policy should be applied before the task is
// re-queued for retry; otherwise there is a race in which the task can be
// rescheduled on the same executor it was meant to be blacklisted from.
val conf = new SparkConf()
  .set(config.BLACKLIST_ENABLED, true)

// Create a scheduler with two executors and a task set containing a single task.
sc = new SparkContext("local", "test", conf)
val exec = "executor1"
val host = "host1"
val exec2 = "executor2"
val host2 = "host2"
sched = new FakeTaskScheduler(sc, (exec, host), (exec2, host2))
val taskSet = FakeTask.createTaskSet(1)

val clock = new ManualClock
val mockListenerBus = mock(classOf[LiveListenerBus])
val blacklistTracker = new BlacklistTracker(mockListenerBus, conf, None, clock)
val taskSetManager = new TaskSetManager(sched, taskSet, 1, Some(blacklistTracker))
Contributor:

Why are we using SystemClock for taskSetManager?

Author:

It seems all the tests in this file are using ManualClock, so I was following the convention here. This test doesn't validate anything specifically dependent on the clock/time.
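If the test ever did need deterministic time, the same ManualClock could also be handed to the manager itself. A minimal sketch of that variant, assuming the TaskSetManager constructor accepts a trailing Clock parameter that defaults to SystemClock (an assumption drawn from the exchange above, not something this diff shows):

// Hypothetical variant: pass the test's ManualClock to the manager as well.
val taskSetManager = new TaskSetManager(sched, taskSet, 1, Some(blacklistTracker), clock)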

val taskSetManagerSpy = spy(taskSetManager)

val taskDesc = taskSetManagerSpy.resourceOffer(exec, host, TaskLocality.ANY)

// Stub addPendingTask so that, at the moment it is invoked, we can assert the
// task has already been blacklisted on the executor it last ran on.
when(taskSetManagerSpy.addPendingTask(anyInt())).thenAnswer(
new Answer[Unit] {
override def answer(invocationOnMock: InvocationOnMock): Unit = {
val task = invocationOnMock.getArgumentAt(0, classOf[Int])
assert(taskSetManager.taskSetBlacklistHelperOpt.get.
isExecutorBlacklistedForTask(exec, task))
}
}
)

// Simulate a task failure with a fake exception.
val e = new ExceptionFailure("a", "b", Array(), "c", None)
taskSetManagerSpy.handleFailedTask(taskDesc.get.taskId, TaskState.FAILED, e)

verify(taskSetManagerSpy, times(1)).addPendingTask(anyInt())
}

private def createTaskResult(
id: Int,
accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = {