Skip to content
Closed
7 changes: 7 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.Source
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener}


Expand Down Expand Up @@ -190,4 +191,10 @@ abstract class TaskContext extends Serializable {
*/
private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit

/**
* Record that this task has failed due to a fetch failure from a remote host. This allows
* fetch-failure handling to get triggered by the driver, regardless of intervening user-code.
*/
private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit

}
11 changes: 11 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.metrics.source.Source
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util._

private[spark] class TaskContextImpl(
Expand Down Expand Up @@ -56,6 +57,10 @@ private[spark] class TaskContextImpl(
// Whether the task has failed.
@volatile private var failed: Boolean = false

// If there was a fetch failure in the task, we store it here, to make sure user-code doesn't
// hide the exception. See SPARK-19276
@volatile private var _fetchFailedException: Option[FetchFailedException] = None

override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
onCompleteCallbacks += listener
this
Expand Down Expand Up @@ -126,4 +131,10 @@ private[spark] class TaskContextImpl(
taskMetrics.registerAccumulator(a)
}

private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = {
this._fetchFailedException = Option(fetchFailed)
}

private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException

}
33 changes: 28 additions & 5 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.executor

import java.io.{File, NotSerializableException}
import java.lang.Thread.UncaughtExceptionHandler
import java.lang.management.ManagementFactory
import java.net.{URI, URL}
import java.nio.ByteBuffer
Expand Down Expand Up @@ -52,7 +53,8 @@ private[spark] class Executor(
executorHostname: String,
env: SparkEnv,
userClassPath: Seq[URL] = Nil,
isLocal: Boolean = false)
isLocal: Boolean = false,
uncaughtExceptionHandler: UncaughtExceptionHandler = SparkUncaughtExceptionHandler)
extends Logging {

logInfo(s"Starting executor ID $executorId on host $executorHostname")
Expand All @@ -78,7 +80,7 @@ private[spark] class Executor(
// Setup an uncaught exception handler for non-local mode.
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler)
Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler)
}

// Start worker thread pool
Expand Down Expand Up @@ -342,6 +344,14 @@ private[spark] class Executor(
}
}
}
task.context.fetchFailed.foreach { fetchFailure =>
// uh-oh. it appears the user code has caught the fetch-failure without throwing any
// other exceptions. Its *possible* this is what the user meant to do (though highly
// unlikely). So we will log an error and keep going.
logError(s"TID ${taskId} completed successfully though internally it encountered " +
s"unrecoverable fetch failures! Most likely this means user code is incorrectly " +
s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
}
val taskFinish = System.currentTimeMillis()
val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
Expand Down Expand Up @@ -402,8 +412,17 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

} catch {
case ffe: FetchFailedException =>
val reason = ffe.toTaskFailedReason
case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
val reason = task.context.fetchFailed.get.toTaskFailedReason
if (!t.isInstanceOf[FetchFailedException]) {
// there was a fetch failure in the task, but some user code wrapped that exception
// and threw something else. Regardless, we treat it as a fetch failure.
val fetchFailedCls = classOf[FetchFailedException].getName
logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " +
s"failed, but the ${fetchFailedCls} was hidden by another " +
s"exception. Spark is handling this like a fetch failure and ignoring the " +
s"other exception: $t")
}
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

Expand Down Expand Up @@ -455,13 +474,17 @@ private[spark] class Executor(
// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
if (Utils.isFatalError(t)) {
SparkUncaughtExceptionHandler.uncaughtException(t)
uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
}

} finally {
runningTasks.remove(taskId)
}
}

private def hasFetchFailure: Boolean = {
task != null && task.context != null && task.context.fetchFailed.isDefined
}
}

/**
Expand Down
9 changes: 3 additions & 6 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,14 @@

package org.apache.spark.scheduler

import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
import java.util.Properties

import scala.collection.mutable
import scala.collection.mutable.HashMap

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.config.APP_CALLER_CONTEXT
import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util._

/**
Expand Down Expand Up @@ -137,6 +132,8 @@ private[spark] abstract class Task[T](
memoryManager.synchronized { memoryManager.notifyAll() }
}
} finally {
// Though we unset the ThreadLocal here, the context member variable itself is still queried
// directly in the TaskRunner to check for FetchFailedExceptions.
TaskContext.unset()
}
}
Expand All @@ -156,7 +153,7 @@ private[spark] abstract class Task[T](
var epoch: Long = -1

// Task context, to be initialized in run().
@transient protected var context: TaskContextImpl = _
@transient var context: TaskContextImpl = _

// The actual Thread on which the task is running, if any. Initialized in run().
@volatile @transient private var taskThread: Thread = _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.shuffle

import org.apache.spark.{FetchFailed, TaskFailedReason}
import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason}
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils

Expand All @@ -26,6 +26,11 @@ import org.apache.spark.util.Utils
* back to DAGScheduler (through TaskEndReason) so we'd resubmit the previous stage.
*
* Note that bmAddress can be null.
*
* To prevent user code from hiding this fetch failure, in the constructor we call
* [[TaskContext.setFetchFailed()]]. This means that you *must* throw this exception immediately
* after creating it -- you cannot create it, check some condition, and then decide to ignore it
* (or risk triggering any other exceptions). See SPARK-19276.
*/
private[spark] class FetchFailedException(
bmAddress: BlockManagerId,
Expand All @@ -45,6 +50,12 @@ private[spark] class FetchFailedException(
this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause)
}

// SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code
// which intercepts this exception (possibly wrapping it), the Executor can still tell there was
// a fetch failure, and send the correct error msg back to the driver. We wrap with an Option
// because the TaskContext is not defined in some test cases.
Option(TaskContext.get()).map(_.setFetchFailed(this))

def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
Utils.exceptionString(this))
}
Expand Down
139 changes: 134 additions & 5 deletions core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.executor

import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.lang.Thread.UncaughtExceptionHandler
import java.nio.ByteBuffer
import java.util.Properties
import java.util.concurrent.{CountDownLatch, TimeUnit}
Expand All @@ -27,7 +28,7 @@ import scala.concurrent.duration._

import org.mockito.ArgumentCaptor
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.{inOrder, when}
import org.mockito.Mockito.{inOrder, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.concurrent.Eventually
Expand All @@ -37,9 +38,12 @@ import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.memory.MemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.{FakeTask, TaskDescription}
import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.BlockManagerId

class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually {

Expand Down Expand Up @@ -123,6 +127,75 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
}
}

test("SPARK-19276: Handle FetchFailedExceptions that are hidden by user exceptions") {
val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
sc = new SparkContext(conf)
val serializer = SparkEnv.get.closureSerializer.newInstance()
val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size

// Submit a job where a fetch failure is thrown, but user code has a try/catch which hides
// the fetch failure. The executor should still tell the driver that the task failed due to a
// fetch failure, not a generic exception from user code.
val inputRDD = new FetchFailureThrowingRDD(sc)
val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false)
val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
val task = new ResultTask(
stageId = 1,
stageAttemptId = 0,
taskBinary = taskBinary,
partition = secondRDD.partitions(0),
locs = Seq(),
outputId = 0,
localProperties = new Properties(),
serializedTaskMetrics = serializedTaskMetrics
)

val serTask = serializer.serialize(task)
val taskDescription = createFakeTaskDescription(serTask)

val failReason = runTaskAndGetFailReason(taskDescription)
assert(failReason.isInstanceOf[FetchFailed])
}

test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
// when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
// may be a false positive. And we should call the uncaught exception handler.
val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
sc = new SparkContext(conf)
val serializer = SparkEnv.get.closureSerializer.newInstance()
val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size

// Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat
// the fetch failure as a false positive, and just do normal OOM handling.
val inputRDD = new FetchFailureThrowingRDD(sc)
val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true)
val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
val task = new ResultTask(
stageId = 1,
stageAttemptId = 0,
taskBinary = taskBinary,
partition = secondRDD.partitions(0),
locs = Seq(),
outputId = 0,
localProperties = new Properties(),
serializedTaskMetrics = serializedTaskMetrics
)

val serTask = serializer.serialize(task)
val taskDescription = createFakeTaskDescription(serTask)

val (failReason, uncaughtExceptionHandler) =
runTaskGetFailReasonAndExceptionHandler(taskDescription)
// make sure the task failure just looks like a OOM, not a fetch failure
assert(failReason.isInstanceOf[ExceptionFailure])
val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
assert(exceptionCaptor.getAllValues.size === 1)
assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
}

test("Gracefully handle error in task deserialization") {
val conf = new SparkConf
val serializer = new JavaSerializer(conf)
Expand Down Expand Up @@ -169,13 +242,20 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
}

private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
runTaskGetFailReasonAndExceptionHandler(taskDescription)._1
}

private def runTaskGetFailReasonAndExceptionHandler(
taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = {
val mockBackend = mock[ExecutorBackend]
val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
var executor: Executor = null
try {
executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
uncaughtExceptionHandler = mockUncaughtExceptionHandler)
// the task will be launched in a dedicated worker thread
executor.launchTask(mockBackend, taskDescription)
eventually(timeout(5 seconds), interval(10 milliseconds)) {
eventually(timeout(5.seconds), interval(10.milliseconds)) {
assert(executor.numRunningTasks === 0)
}
} finally {
Expand All @@ -193,7 +273,56 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
assert(statusCaptor.getAllValues().get(0).remaining() === 0)
// second update is more interesting
val failureData = statusCaptor.getAllValues.get(1)
SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData)
val failReason =
SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData)
(failReason, mockUncaughtExceptionHandler)
}
}

class FetchFailureThrowingRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {
override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
new Iterator[Int] {
override def hasNext: Boolean = true
override def next(): Int = {
throw new FetchFailedException(
bmAddress = BlockManagerId("1", "hostA", 1234),
shuffleId = 0,
mapId = 0,
reduceId = 0,
message = "fake fetch failure"
)
}
}
}
override protected def getPartitions: Array[Partition] = {
Array(new SimplePartition)
}
}

class SimplePartition extends Partition {
override def index: Int = 0
}

class FetchFailureHidingRDD(
sc: SparkContext,
val input: FetchFailureThrowingRDD,
throwOOM: Boolean) extends RDD[Int](input) {
override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
val inItr = input.compute(split, context)
try {
Iterator(inItr.size)
} catch {
case t: Throwable =>
if (throwOOM) {
throw new OutOfMemoryError("OOM while handling another exception")
} else {
throw new RuntimeException("User Exception that hides the original exception", t)
}
}
}

override protected def getPartitions: Array[Partition] = {
Array(new SimplePartition)
}
}

Expand Down
3 changes: 3 additions & 0 deletions project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ object MimaExcludes {
// [SPARK-14272][ML] Add logLikelihood in GaussianMixtureSummary
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.GaussianMixtureSummary.this"),

// [SPARK-19267] Fetch Failure handling robust to user error handling
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.setFetchFailed"),

// [SPARK-19069] [CORE] Expose task 'status' and 'duration' in spark history server REST API.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$10"),
Expand Down