7 changes: 7 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.Source
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener}


@@ -190,4 +191,10 @@ abstract class TaskContext extends Serializable {
*/
private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit

/**
* Record that this task has failed due to a fetch failure from a remote host. This allows
* fetch-failure handling to get triggered by the driver, regardless of intervening user-code.
*/
private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit

}
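For context on the "intervening user-code" mentioned in the doc comment above, this is the kind of user pattern the new hook is meant to survive -- a hypothetical sketch (shuffledRdd is an assumed name, not part of this patch), modeled on the test RDD added later in this PR:

// Hypothetical user code: consuming the shuffle iterator inside a catch-all
// wraps the FetchFailedException in a plain RuntimeException, so without
// setFetchFailed the driver would only see a generic ExceptionFailure instead
// of re-running the map stage.
val hidden = shuffledRdd.mapPartitions { iter =>
  try {
    Iterator(iter.size)  // the remote fetch happens while consuming iter
  } catch {
    case t: Throwable =>
      throw new RuntimeException("user-level failure", t)
  }
}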
7 changes: 7 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.metrics.source.Source
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util._

private[spark] class TaskContextImpl(
@@ -56,6 +57,8 @@ private[spark] class TaskContextImpl(
// Whether the task has failed.
@volatile private var failed: Boolean = false

var fetchFailed: Option[FetchFailedException] = None
Contributor: @volatile private ?

override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
onCompleteCallbacks += listener
this
@@ -126,4 +129,8 @@ private[spark] class TaskContextImpl(
taskMetrics.registerAccumulator(a)
}

private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = {
this.fetchFailed = Some(fetchFailed)
}

}
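If the reviewer's "@volatile private ?" suggestion above were adopted, a minimal sketch of the field and its accessor might look like the following (an assumption about the eventual shape, not what this revision of the patch does):

// Sketch only: keep the field private and volatile, and expose a read-only
// accessor for the executor's TaskRunner.
@volatile private var _fetchFailed: Option[FetchFailedException] = None

private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailed

private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = {
  _fetchFailed = Some(fetchFailed)
}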
17 changes: 17 additions & 0 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -148,6 +148,8 @@ private[spark] class Executor(

startDriverHeartbeater()

private[executor] def numRunningTasks: Int = runningTasks.size()

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
val tr = new TaskRunner(context, taskDescription)
runningTasks.put(taskDescription.taskId, tr)
@@ -340,6 +342,14 @@
}
}
}
task.context.fetchFailed.foreach { fetchFailure =>
// uh-oh. it appears the user code has caught the fetch-failure without throwing any
// other exceptions. It's *possible* this is what the user meant to do (though highly
// unlikely). So we will log an error and keep going.
logError(s"TID ${taskId} completed successfully though internally it encountered " +
s"unrecoverable fetch failures! Most likely this means user code is incorrectly " +
s"swallowing Spark's internal exceptions", fetchFailure)
Contributor: Can you be explicit about what exception is getting swallowed here? (i.e., "incorrectly swallowing Spark's internal FetchFailedException") -- to possibly simplify debugging/fixing this issue for a user who runs into it.

}
val taskFinish = System.currentTimeMillis()
val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
@@ -405,6 +415,13 @@
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

case t: Throwable if task.context.fetchFailed.isDefined =>
Contributor: task and task.context can be null in case an exception is thrown before/while deserializing the task, or before the task is run (or if initialization of the context in task.run fails). In any of these cases, the if condition here will result in an NPE, and needs to be fixed.

Contributor Author: Oh, great point! Sorry I missed that. I've added a test case for this as well.

// there was a fetch failure in the task, but some user code wrapped that exception
// and threw something else. Regardless, we treat it as a fetch failure.
val reason = task.context.fetchFailed.get.toTaskFailedReason
Contributor: Tiny nit, but does it make sense to store the taskFailedReason (rather than the actual exception) in the task context?

setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
Contributor: nit: Probably log a similar message as above?

Contributor Author: Do you mean the msg I added about "TID ${taskId} completed successfully though internally it encountered unrecoverable fetch failures!"? I wouldn't think we'd want to log anything special here. I'm trying to make this a "normal" code path. The user is allowed to do this. (SparkSQL already does.)

We could log a warning, but then this change should be accompanied by auditing the code and making sure we never do this ourselves.

Contributor: Yes, something along those lines. And I agree, we should not be doing this ourselves either.


case _: TaskKilledException =>
logInfo(s"Executor killed $taskName (TID $taskId)")
setTaskFinishedAndClearInterruptStatus()
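One way to address the NPE concern raised above is sketched below (an assumption about how it could be handled inside TaskRunner, where task is a Task[Any] that may never have been deserialized; fetchFailureOf is a hypothetical helper name):

// Sketch only: resolve the fetch failure through Option so that a null task
// or a null task.context cannot throw an NPE inside the catch block.
def fetchFailureOf(task: Task[Any]): Option[FetchFailedException] =
  Option(task).flatMap(t => Option(t.context)).flatMap(_.fetchFailed)

// The guard then becomes:
//   case t: Throwable if fetchFailureOf(task).isDefined => ...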
9 changes: 3 additions & 6 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -17,19 +17,14 @@

package org.apache.spark.scheduler

import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
import java.util.Properties

import scala.collection.mutable
import scala.collection.mutable.HashMap

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.config.APP_CALLER_CONTEXT
import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util._

/**
@@ -137,6 +132,8 @@ private[spark] abstract class Task[T](
memoryManager.synchronized { memoryManager.notifyAll() }
}
} finally {
// though we unset the ThreadLocal here, the context itself is still queried directly
Contributor: nit: "the context member variable" instead of just "the context" (took me a min to parse this)

// in the TaskRunner to check for FetchFailedExceptions
TaskContext.unset()
}
}
@@ -156,7 +153,7 @@
var epoch: Long = -1

// Task context, to be initialized in run().
@transient protected var context: TaskContextImpl = _
@transient var context: TaskContextImpl = _

// The actual Thread on which the task is running, if any. Initialized in run().
@volatile @transient private var taskThread: Thread = _
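To make the comment about the ThreadLocal concrete, the flow inside Task.run() looks roughly like this (a simplified sketch, not the full method):

// Simplified sketch: TaskContext.unset() only clears the per-thread lookup
// used by user code; the `context` member variable is what the executor's
// TaskRunner later reads via task.context.fetchFailed.
TaskContext.setTaskContext(context)   // `context` was assigned just above
try {
  runTask(context)
} finally {
  TaskContext.unset()  // TaskContext.get() is now null, but task.context is not
}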
core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -17,7 +17,7 @@

package org.apache.spark.shuffle

import org.apache.spark.{FetchFailed, TaskFailedReason}
import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason}
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils

@@ -45,6 +45,12 @@ private[spark] class FetchFailedException(
this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause)
}

// SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code
// which intercepts this exception (possibly wrapping it), the Executor can still tell there was
// a fetch failure, and send the correct error message back to the driver. The TaskContext won't be
// defined if this is run on the driver (just in test cases) -- we can safely ignore it then.
Contributor: This last sentence is confusing. A task that runs locally on the driver can still hit fetch failures, right? Or are you saying the TaskContext will only be undefined in test cases?

Contributor Author: Sorry, I've reworded this. The issue is that we have test cases where the TaskContext isn't defined, and so we'd hit an NPE without the Option wrapper. But in general, the TaskContext should always be defined anytime we'd create a FetchFailure.

The alternative would be to track down the test cases without a TaskContext, and add one back.

Option(TaskContext.get()).map(_.setFetchFailed(this))
Contributor: Since creating an exception does not necessarily mean it should get thrown, we must explicitly add this expectation to the documentation/contract of the FetchFailedException constructor -- indicating that we expect it to be created only when it will be thrown immediately. This should be fine since FetchFailedException is private[spark] right now.

Contributor Author: Yes, good point. I added it to the docs; does it look OK?

I also considered making the call to TaskContext.setFetchFailed live outside the constructor, so that each site creating the exception would have to call it -- but that seemed more dangerous.


def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
Utils.exceptionString(this))
}
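To illustrate the constructor contract discussed above, a hypothetical snippet (assuming it runs inside a task, so TaskContext.get() is non-null; the BlockManagerId values are made up):

// Hypothetical: merely constructing the exception records the fetch failure in
// the task context, which is why it must only be created when it is about to
// be thrown.
val ffe = new FetchFailedException(
  bmAddress = BlockManagerId("exec-1", "hostA", 7337),
  shuffleId = 0,
  mapId = 0,
  reduceId = 0,
  message = "fetch failed")
// At this point the TaskContextImpl's fetchFailed field is already Some(ffe),
// even if ffe is never thrown.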
144 changes: 125 additions & 19 deletions core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -23,29 +23,34 @@ import java.util.concurrent.CountDownLatch

import scala.collection.mutable.Map

import org.mockito.Matchers._
import org.mockito.Mockito.{mock, when}
import org.mockito.ArgumentCaptor
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.{inOrder, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.mock.MockitoSugar

import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.memory.MemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.{FakeTask, TaskDescription}
import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.BlockManagerId

class ExecutorSuite extends SparkFunSuite {
class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar {

test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") {
// mock some objects to make Executor.launchTask() happy
val conf = new SparkConf
val serializer = new JavaSerializer(conf)
val mockEnv = mock(classOf[SparkEnv])
val mockRpcEnv = mock(classOf[RpcEnv])
val mockMetricsSystem = mock(classOf[MetricsSystem])
val mockMemoryManager = mock(classOf[MemoryManager])
val mockEnv = mock[SparkEnv]
val mockRpcEnv = mock[RpcEnv]
val mockMetricsSystem = mock[MetricsSystem]
val mockMemoryManager = mock[MemoryManager]
when(mockEnv.conf).thenReturn(conf)
when(mockEnv.serializer).thenReturn(serializer)
when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
@@ -55,16 +60,7 @@ class ExecutorSuite extends SparkFunSuite {
val fakeTaskMetrics = serializer.newInstance().serialize(TaskMetrics.registered).array()
val serializedTask = serializer.newInstance().serialize(
new FakeTask(0, 0, Nil, fakeTaskMetrics))
val taskDescription = new TaskDescription(
taskId = 0,
attemptNumber = 0,
executorId = "",
name = "",
index = 0,
addedFiles = Map[String, Long](),
addedJars = Map[String, Long](),
properties = new Properties,
serializedTask)
val taskDescription = fakeTaskDescription(serializedTask)

// we use latches to force the program to run in this order:
// +-----------------------------+---------------------------------------+
@@ -86,7 +82,7 @@

val executorSuiteHelper = new ExecutorSuiteHelper

val mockExecutorBackend = mock(classOf[ExecutorBackend])
val mockExecutorBackend = mock[ExecutorBackend]
when(mockExecutorBackend.statusUpdate(any(), any(), any()))
.thenAnswer(new Answer[Unit] {
var firstTime = true
@@ -133,6 +129,116 @@
}
}
}

test("SPARK-19276: Handle Fetch Failed for all intervening user code") {
Contributor: How about "Handle FetchFailedExceptions that are hidden by user exceptions"?

val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
val sc = new SparkContext(conf)

val serializer = SparkEnv.get.closureSerializer.newInstance()
val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size
val inputRDD = new FakeShuffleRDD(sc)
val secondRDD = new FetchFailureHidingRDD(sc, inputRDD)
val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
val task = new ResultTask(
stageId = 1,
stageAttemptId = 0,
taskBinary = taskBinary,
partition = secondRDD.partitions(0),
locs = Seq(),
outputId = 0,
localProperties = new Properties(),
serializedTaskMetrics = serializedTaskMetrics
)

val serTask = serializer.serialize(task)
val taskDescription = fakeTaskDescription(serTask)

Contributor: nit: too many empty lines


val mockBackend = mock[ExecutorBackend]
var executor: Executor = null
try {
executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
executor.launchTask(mockBackend, taskDescription)
val startTime = System.currentTimeMillis()
val maxTime = startTime + 5000
while (executor.numRunningTasks > 0 && System.currentTimeMillis() < maxTime) {
Contributor: I'd use eventually here, or at least System.nanoTime instead.

Thread.sleep(10)
}
val orderedMock = inOrder(mockBackend)
val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
orderedMock.verify(mockBackend)
.statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
orderedMock.verify(mockBackend)
.statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
// first statusUpdate for RUNNING has empty data
assert(statusCaptor.getAllValues().get(0).remaining() === 0)
// second update is more interesting
val failureData = statusCaptor.getAllValues.get(1)
val failReason = serializer.deserialize[TaskFailedReason](failureData)
assert(failReason.isInstanceOf[FetchFailed])
} finally {
if (executor != null) {
executor.stop()
}
}
}

private def fakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = {
new TaskDescription(
taskId = 0,
attemptNumber = 0,
executorId = "",
name = "",
index = 0,
addedFiles = Map[String, Long](),
addedJars = Map[String, Long](),
properties = new Properties,
serializedTask)
}

}

class FakeShuffleRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {
Contributor: How about FetchFailureThrowingShuffleRDD? (to make it obvious what the point of this is?)

override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
new Iterator[Int] {
override def hasNext: Boolean = true
override def next(): Int = {
throw new FetchFailedException(
bmAddress = BlockManagerId("1", "hostA", 1234),
shuffleId = 0,
mapId = 0,
reduceId = 0,
message = "fake fetch failure"
)
}
}
}
override protected def getPartitions: Array[Partition] = {
Array(new SimplePartition)
}
}

class SimplePartition extends Partition {
override def index: Int = 0
}

class FetchFailureHidingRDD(
sc: SparkContext,
val input: FakeShuffleRDD) extends RDD[Int](input) {
override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
val inItr = input.compute(split, context)
try {
Iterator(inItr.size)
} catch {
case t: Throwable =>
throw new RuntimeException("User Exception that hides the original exception", t)
}
}

override protected def getPartitions: Array[Partition] = {
Array(new SimplePartition)
}
}

// Helps to test("SPARK-15963")
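If the reviewer's suggestion to use eventually (in the comment on the wait loop above) were adopted, the manual sleep loop in the new test might be replaced with something like this sketch, assuming ScalaTest's Eventually and SpanSugar are in scope:

import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

// Sketch only: poll until the executor has drained its running tasks, failing
// the test after 5 seconds instead of hand-rolling a sleep loop.
eventually(timeout(5.seconds), interval(10.milliseconds)) {
  assert(executor.numRunningTasks === 0)
}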