From b93c37f1b0dfd3f1293b7a3df8beacc2fec7a33d Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 18 Jan 2017 15:55:50 -0600 Subject: [PATCH 01/12] [SPARK-19276][CORE] Fetch Failure handling robust to user error handling Fault-tolerance in spark requires special handling of shuffle fetch failures. The Executor would catch FetchFailedException and send a special msg back to the driver. However, intervening user code could intercept that exception, and wrap it with something else. This even happens in SparkSQL. So rather than checking the exception directly, we'll store the fetch failure directly in the TaskContext, where users can't touch it. This includes a test case which failed before the fix. --- .../scala/org/apache/spark/TaskContext.scala | 7 + .../org/apache/spark/TaskContextImpl.scala | 7 + .../org/apache/spark/executor/Executor.scala | 17 +++ .../org/apache/spark/scheduler/Task.scala | 10 +- .../spark/shuffle/FetchFailedException.scala | 8 +- .../apache/spark/executor/ExecutorSuite.scala | 144 +++++++++++++++--- 6 files changed, 167 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 0fd777ed1282..f0867ecb16ea 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.Source +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener} @@ -190,4 +191,10 @@ abstract class TaskContext extends Serializable { */ private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit + /** + * Record that this task has failed due to a fetch failure from a remote host. This allows + * fetch-failure handling to get triggered by the driver, regardless of intervening user-code. + */ + private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit + } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index c904e083911c..e5610c997791 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ private[spark] class TaskContextImpl( @@ -56,6 +57,8 @@ private[spark] class TaskContextImpl( // Whether the task has failed. 
@volatile private var failed: Boolean = false + var fetchFailed: Option[FetchFailedException] = None + override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { onCompleteCallbacks += listener this @@ -126,4 +129,8 @@ private[spark] class TaskContextImpl( taskMetrics.registerAccumulator(a) } + private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = { + this.fetchFailed = Some(fetchFailed) + } + } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index db5d0d85ceb8..93fa0854546c 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -148,6 +148,8 @@ private[spark] class Executor( startDriverHeartbeater() + private[executor] def numRunningTasks: Int = runningTasks.size() + def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = { val tr = new TaskRunner(context, taskDescription) runningTasks.put(taskDescription.taskId, tr) @@ -340,6 +342,14 @@ private[spark] class Executor( } } } + task.context.fetchFailed.foreach { fetchFailure => + // uh-oh. it appears the user code has caught the fetch-failure without throwing any + // other exceptions. Its *possible* this is what the user meant to do (though highly + // unlikely). So we will log an error and keep going. + logError(s"TID ${taskId} completed successfully though internally it encountered " + + s"unrecoverable fetch failures! Most likely this means user code is incorrectly " + + s"swallowing Spark's internal exceptions", fetchFailure) + } val taskFinish = System.currentTimeMillis() val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { threadMXBean.getCurrentThreadCpuTime @@ -405,6 +415,13 @@ private[spark] class Executor( setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + case t: Throwable if task.context.fetchFailed.isDefined => + // tbere was a fetch failure in the task, but some user code wrapped that exception + // and threw something else. Regardless, we treat it as a fetch failure. 
+ val reason = task.context.fetchFailed.get.toTaskFailedReason + setTaskFinishedAndClearInterruptStatus() + execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + case _: TaskKilledException => logInfo(s"Executor killed $taskName (TID $taskId)") setTaskFinishedAndClearInterruptStatus() diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 7b726d5659e9..3e353627bbf2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -17,19 +17,15 @@ package org.apache.spark.scheduler -import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer import java.util.Properties -import scala.collection.mutable -import scala.collection.mutable.HashMap - import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.config.APP_CALLER_CONTEXT import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ /** @@ -137,6 +133,8 @@ private[spark] abstract class Task[T]( memoryManager.synchronized { memoryManager.notifyAll() } } } finally { + // though we unset the ThreadLocal here, the context itself is still queried directly + // in the TaskRunner to check for FetchFailedExceptions TaskContext.unset() } } @@ -156,7 +154,7 @@ private[spark] abstract class Task[T]( var epoch: Long = -1 // Task context, to be initialized in run(). - @transient protected var context: TaskContextImpl = _ + @transient var context: TaskContextImpl = _ // The actual Thread on which the task is running, if any. Initialized in run(). @volatile @transient private var taskThread: Thread = _ diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 498c12e196ce..15de964e8d4f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle -import org.apache.spark.{FetchFailed, TaskFailedReason} +import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason} import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -45,6 +45,12 @@ private[spark] class FetchFailedException( this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause) } + // SPARK-19267. We set the fetch failure in the task context, so that even if there is user-code + // which intercepts this exception (possibly wrapping it), the Executor can still tell there was + // a fetch failure, and send the correct error msg back to the driver. The TaskContext won't be + // defined if this is run on the driver (just in test cases) -- we can safely ignore then. 
+ Option(TaskContext.get()).map(_.setFetchFailed(this)) + def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, Utils.exceptionString(this)) } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index f94baaa30d18..9d14cadeda64 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -23,29 +23,34 @@ import java.util.concurrent.CountDownLatch import scala.collection.mutable.Map -import org.mockito.Matchers._ -import org.mockito.Mockito.{mock, when} +import org.mockito.ArgumentCaptor +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito.{inOrder, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.mock.MockitoSugar import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.memory.MemoryManager import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcEnv -import org.apache.spark.scheduler.{FakeTask, TaskDescription} +import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.storage.BlockManagerId -class ExecutorSuite extends SparkFunSuite { +class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") { // mock some objects to make Executor.launchTask() happy val conf = new SparkConf val serializer = new JavaSerializer(conf) - val mockEnv = mock(classOf[SparkEnv]) - val mockRpcEnv = mock(classOf[RpcEnv]) - val mockMetricsSystem = mock(classOf[MetricsSystem]) - val mockMemoryManager = mock(classOf[MemoryManager]) + val mockEnv = mock[SparkEnv] + val mockRpcEnv = mock[RpcEnv] + val mockMetricsSystem = mock[MetricsSystem] + val mockMemoryManager = mock[MemoryManager] when(mockEnv.conf).thenReturn(conf) when(mockEnv.serializer).thenReturn(serializer) when(mockEnv.rpcEnv).thenReturn(mockRpcEnv) @@ -55,16 +60,7 @@ class ExecutorSuite extends SparkFunSuite { val fakeTaskMetrics = serializer.newInstance().serialize(TaskMetrics.registered).array() val serializedTask = serializer.newInstance().serialize( new FakeTask(0, 0, Nil, fakeTaskMetrics)) - val taskDescription = new TaskDescription( - taskId = 0, - attemptNumber = 0, - executorId = "", - name = "", - index = 0, - addedFiles = Map[String, Long](), - addedJars = Map[String, Long](), - properties = new Properties, - serializedTask) + val taskDescription = fakeTaskDescription(serializedTask) // we use latches to force the program to run in this order: // +-----------------------------+---------------------------------------+ @@ -86,7 +82,7 @@ class ExecutorSuite extends SparkFunSuite { val executorSuiteHelper = new ExecutorSuiteHelper - val mockExecutorBackend = mock(classOf[ExecutorBackend]) + val mockExecutorBackend = mock[ExecutorBackend] when(mockExecutorBackend.statusUpdate(any(), any(), any())) .thenAnswer(new Answer[Unit] { var firstTime = true @@ -133,6 +129,116 @@ class ExecutorSuite extends SparkFunSuite { } } } + + test("SPARK-19276: Handle Fetch Failed for all intervening user code") { + val conf = new SparkConf().setMaster("local").setAppName("executor suite test") + val sc = new SparkContext(conf) 
+ + val serializer = SparkEnv.get.closureSerializer.newInstance() + val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size + val inputRDD = new FakeShuffleRDD(sc) + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD) + val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) + val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() + val task = new ResultTask( + stageId = 1, + stageAttemptId = 0, + taskBinary = taskBinary, + partition = secondRDD.partitions(0), + locs = Seq(), + outputId = 0, + localProperties = new Properties(), + serializedTaskMetrics = serializedTaskMetrics + ) + + val serTask = serializer.serialize(task) + val taskDescription = fakeTaskDescription(serTask) + + + val mockBackend = mock[ExecutorBackend] + var executor: Executor = null + try { + executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true) + executor.launchTask(mockBackend, taskDescription) + val startTime = System.currentTimeMillis() + val maxTime = startTime + 5000 + while (executor.numRunningTasks > 0 && System.currentTimeMillis() < maxTime) { + Thread.sleep(10) + } + val orderedMock = inOrder(mockBackend) + val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) + orderedMock.verify(mockBackend) + .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture()) + orderedMock.verify(mockBackend) + .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture()) + // first statusUpdate for RUNNING has empty data + assert(statusCaptor.getAllValues().get(0).remaining() === 0) + // second update is more interesting + val failureData = statusCaptor.getAllValues.get(1) + val failReason = serializer.deserialize[TaskFailedReason](failureData) + assert(failReason.isInstanceOf[FetchFailed]) + } finally { + if (executor != null) { + executor.stop() + } + } + } + + private def fakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = { + new TaskDescription( + taskId = 0, + attemptNumber = 0, + executorId = "", + name = "", + index = 0, + addedFiles = Map[String, Long](), + addedJars = Map[String, Long](), + properties = new Properties, + serializedTask) + } + +} + +class FakeShuffleRDD(sc: SparkContext) extends RDD[Int](sc, Nil) { + override def compute(split: Partition, context: TaskContext): Iterator[Int] = { + new Iterator[Int] { + override def hasNext: Boolean = true + override def next(): Int = { + throw new FetchFailedException( + bmAddress = BlockManagerId("1", "hostA", 1234), + shuffleId = 0, + mapId = 0, + reduceId = 0, + message = "fake fetch failure" + ) + } + } + } + override protected def getPartitions: Array[Partition] = { + Array(new SimplePartition) + } +} + +class SimplePartition extends Partition { + override def index: Int = 0 +} + +class FetchFailureHidingRDD( + sc: SparkContext, + val input: FakeShuffleRDD) extends RDD[Int](input) { + override def compute(split: Partition, context: TaskContext): Iterator[Int] = { + val inItr = input.compute(split, context) + try { + Iterator(inItr.size) + } catch { + case t: Throwable => + throw new RuntimeException("User Exception that hides the original exception", t) + } + } + + override protected def getPartitions: Array[Partition] = { + Array(new SimplePartition) + } } // Helps to test("SPARK-15963") From 9635980fca20d18b44fa5085996ad43cbf3f3bb5 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 19 Jan 2017 00:59:09 -0600 Subject: [PATCH 02/12] cleanup --- core/src/main/scala/org/apache/spark/scheduler/Task.scala | 1 - 
1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 3e353627bbf2..189cb7c849cf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -25,7 +25,6 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.config.APP_CALLER_CONTEXT import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ /** From 0a60aefb9edb6bea0390ac9d481c5318f3b0dff8 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 19 Jan 2017 08:42:21 -0600 Subject: [PATCH 03/12] fix mima --- project/MimaExcludes.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index e0ee00e6826a..df5f2490e137 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -46,7 +46,10 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.streaming.scheduler.StreamingListener.onStreamingStarted"), // [SPARK-19148][SQL] do not expose the external table concept in Catalog - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.createTable") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.createTable"), + + // [SPARK-19267] Fetch Failure handling robust to user error handling + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.setFetchFailed") ) // Exclude rules for 2.1.x From bbef893d49cfd5fa51467dff326c0d15218491aa Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 19 Jan 2017 11:13:26 -0600 Subject: [PATCH 04/12] review feedback --- .../org/apache/spark/TaskContextImpl.scala | 8 +- .../org/apache/spark/executor/Executor.scala | 6 +- .../spark/shuffle/FetchFailedException.scala | 7 +- .../apache/spark/executor/ExecutorSuite.scala | 128 ++++++++++++------ 4 files changed, 100 insertions(+), 49 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index e5610c997791..d9cf48dbccc5 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -57,7 +57,9 @@ private[spark] class TaskContextImpl( // Whether the task has failed. @volatile private var failed: Boolean = false - var fetchFailed: Option[FetchFailedException] = None + // If there was a fetch failure in the task, we store it here, to make sure user-code doesn't + // hide the exception. 
See SPARK-19276 + @volatile private var _fetchFailed: Option[FetchFailedException] = None override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { onCompleteCallbacks += listener @@ -130,7 +132,9 @@ private[spark] class TaskContextImpl( } private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = { - this.fetchFailed = Some(fetchFailed) + this._fetchFailed = Some(fetchFailed) } + private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailed + } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 93fa0854546c..025d2d74704c 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -415,7 +415,7 @@ private[spark] class Executor( setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - case t: Throwable if task.context.fetchFailed.isDefined => + case t: Throwable if hasFetchFailure => // tbere was a fetch failure in the task, but some user code wrapped that exception // and threw something else. Regardless, we treat it as a fetch failure. val reason = task.context.fetchFailed.get.toTaskFailedReason @@ -477,6 +477,10 @@ private[spark] class Executor( runningTasks.remove(taskId) } } + + private def hasFetchFailure: Boolean = { + task != null && task.context != null && task.context.fetchFailed.isDefined + } } /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 15de964e8d4f..a94ee450ae43 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -26,6 +26,11 @@ import org.apache.spark.util.Utils * back to DAGScheduler (through TaskEndReason) so we'd resubmit the previous stage. * * Note that bmAddress can be null. + * + * To prevent user code from hiding this fetch failure, in the constructor we call + * [[TaskContext.setFetchFailed()]]. This means that you *must* throw this exception immediately + * after creating it -- you cannot create it, check some condition, and then decide to ignore it + * (or risk triggering any other exceptions). See SPARK-19276. */ private[spark] class FetchFailedException( bmAddress: BlockManagerId, @@ -45,7 +50,7 @@ private[spark] class FetchFailedException( this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause) } - // SPARK-19267. We set the fetch failure in the task context, so that even if there is user-code + // SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code // which intercepts this exception (possibly wrapping it), the Executor can still tell there was // a fetch failure, and send the correct error msg back to the driver. The TaskContext won't be // defined if this is run on the driver (just in test cases) -- we can safely ignore then. 
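The failure mode this guards against can be sketched with ordinary driver-side code: a broad try/catch that eagerly consumes a shuffled iterator rethrows any FetchFailedException wrapped in a RuntimeException, so an exception-type check in the Executor would miss it. The job below is an illustrative sketch only (RDD contents and names are made up, mirroring the FetchFailureHidingRDD used in the tests), not code from the patch:

import org.apache.spark.{SparkConf, SparkContext}

object HiddenFetchFailureExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("hidden-fetch-failure"))
    val shuffled = sc.parallelize(1 to 100, 4).map(i => (i % 4, i)).groupByKey()
    val sizes = shuffled.mapPartitions { iter =>
      try {
        // Eagerly consuming the iterator is what would trigger a remote fetch; if that
        // fetch failed, the FetchFailedException would be swallowed by the catch below,
        // and only the record in the TaskContext would reveal it to the Executor.
        Iterator(iter.size)
      } catch {
        case e: Exception =>
          throw new RuntimeException("user wrapper that hides the original exception", e)
      }
    }
    println(sizes.collect().toSeq)
    sc.stop()
  }
}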
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 9d14cadeda64..cd708eee6624 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.executor +import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.nio.ByteBuffer import java.util.Properties -import java.util.concurrent.CountDownLatch +import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.collection.mutable.Map @@ -47,19 +48,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug // mock some objects to make Executor.launchTask() happy val conf = new SparkConf val serializer = new JavaSerializer(conf) - val mockEnv = mock[SparkEnv] - val mockRpcEnv = mock[RpcEnv] - val mockMetricsSystem = mock[MetricsSystem] - val mockMemoryManager = mock[MemoryManager] - when(mockEnv.conf).thenReturn(conf) - when(mockEnv.serializer).thenReturn(serializer) - when(mockEnv.rpcEnv).thenReturn(mockRpcEnv) - when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem) - when(mockEnv.memoryManager).thenReturn(mockMemoryManager) - when(mockEnv.closureSerializer).thenReturn(serializer) - val fakeTaskMetrics = serializer.newInstance().serialize(TaskMetrics.registered).array() - val serializedTask = serializer.newInstance().serialize( - new FakeTask(0, 0, Nil, fakeTaskMetrics)) + val env = mockEnv(conf, serializer) + val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0)) val taskDescription = fakeTaskDescription(serializedTask) // we use latches to force the program to run in this order: @@ -98,8 +88,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val taskState = invocationOnMock.getArguments()(1).asInstanceOf[TaskState] executorSuiteHelper.taskState = taskState val taskEndReason = invocationOnMock.getArguments()(2).asInstanceOf[ByteBuffer] - executorSuiteHelper.testFailedReason - = serializer.newInstance().deserialize(taskEndReason) + executorSuiteHelper.testFailedReason = + serializer.newInstance().deserialize(taskEndReason) // let the main test thread check `taskState` and `testFailedReason` executorSuiteHelper.latch3.countDown() } @@ -108,16 +98,20 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug var executor: Executor = null try { - executor = new Executor("id", "localhost", mockEnv, userClassPath = Nil, isLocal = true) + executor = new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true) // the task will be launched in a dedicated worker thread executor.launchTask(mockExecutorBackend, taskDescription) - executorSuiteHelper.latch1.await() + if (!executorSuiteHelper.latch1.await(5, TimeUnit.SECONDS)) { + fail("executor did not send first status update in time") + } // we know the task will be started, but not yet deserialized, because of the latches we // use in mockExecutorBackend. 
executor.killAllTasks(true) executorSuiteHelper.latch2.countDown() - executorSuiteHelper.latch3.await() + if (!executorSuiteHelper.latch3.await(5, TimeUnit.SECONDS)) { + fail("executor did not send second status update in time") + } // `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED` assert(executorSuiteHelper.testFailedReason === TaskKilled) @@ -155,35 +149,42 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val taskDescription = fakeTaskDescription(serTask) - val mockBackend = mock[ExecutorBackend] - var executor: Executor = null - try { - executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true) - executor.launchTask(mockBackend, taskDescription) - val startTime = System.currentTimeMillis() - val maxTime = startTime + 5000 - while (executor.numRunningTasks > 0 && System.currentTimeMillis() < maxTime) { - Thread.sleep(10) - } - val orderedMock = inOrder(mockBackend) - val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) - orderedMock.verify(mockBackend) - .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture()) - orderedMock.verify(mockBackend) - .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture()) - // first statusUpdate for RUNNING has empty data - assert(statusCaptor.getAllValues().get(0).remaining() === 0) - // second update is more interesting - val failureData = statusCaptor.getAllValues.get(1) - val failReason = serializer.deserialize[TaskFailedReason](failureData) - assert(failReason.isInstanceOf[FetchFailed]) - } finally { - if (executor != null) { - executor.stop() - } + val failReason = runTaskAndGetFailReason(taskDescription) + assert(failReason.isInstanceOf[FetchFailed]) + } + + test("Gracefully handle error in task deserialization") { + val conf = new SparkConf + val serializer = new JavaSerializer(conf) + val env = mockEnv(conf, serializer) + val serializedTask = serializer.newInstance().serialize(new NonDeserializableTask) + val taskDescription = fakeTaskDescription(serializedTask) + + val failReason = runTaskAndGetFailReason(taskDescription) + failReason match { + case ef: ExceptionFailure => + assert(ef.exception.isDefined) + assert(ef.exception.get.getMessage() === "failure in deserialization") + case _ => + fail("unexpected failure type: $failReason") } } + private def mockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = { + val mockEnv = mock[SparkEnv] + val mockRpcEnv = mock[RpcEnv] + val mockMetricsSystem = mock[MetricsSystem] + val mockMemoryManager = mock[MemoryManager] + when(mockEnv.conf).thenReturn(conf) + when(mockEnv.serializer).thenReturn(serializer) + when(mockEnv.rpcEnv).thenReturn(mockRpcEnv) + when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem) + when(mockEnv.memoryManager).thenReturn(mockMemoryManager) + when(mockEnv.closureSerializer).thenReturn(serializer) + SparkEnv.set(mockEnv) + mockEnv + } + private def fakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = { new TaskDescription( taskId = 0, @@ -197,6 +198,36 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug serializedTask) } + private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = { + val mockBackend = mock[ExecutorBackend] + var executor: Executor = null + try { + executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true) + // the task will be launched in a dedicated worker thread + executor.launchTask(mockBackend, 
taskDescription) + val startTime = System.currentTimeMillis() + val maxTime = startTime + 5000 + while (executor.numRunningTasks > 0 && System.currentTimeMillis() < maxTime) { + Thread.sleep(10) + } + assert(executor.numRunningTasks === 0) + } finally { + if (executor != null) { + executor.stop() + } + } + val orderedMock = inOrder(mockBackend) + val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) + orderedMock.verify(mockBackend) + .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture()) + orderedMock.verify(mockBackend) + .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture()) + // first statusUpdate for RUNNING has empty data + assert(statusCaptor.getAllValues().get(0).remaining() === 0) + // second update is more interesting + val failureData = statusCaptor.getAllValues.get(1) + SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData) + } } class FakeShuffleRDD(sc: SparkContext) extends RDD[Int](sc, Nil) { @@ -251,3 +282,10 @@ private class ExecutorSuiteHelper { @volatile var taskState: TaskState = _ @volatile var testFailedReason: TaskFailedReason = _ } + +private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable { + def writeExternal(out: ObjectOutput): Unit = {} + def readExternal(in: ObjectInput): Unit = { + throw new RuntimeException("failure in deserialization") + } +} From 4494673dd02efa1aa3b0dea79c6bd7a2e51111d4 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 19 Jan 2017 13:55:16 -0600 Subject: [PATCH 05/12] fix use of LocalSparkContext --- .../test/scala/org/apache/spark/executor/ExecutorSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index cd708eee6624..ecd97080d1d5 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -126,7 +126,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug test("SPARK-19276: Handle Fetch Failed for all intervening user code") { val conf = new SparkConf().setMaster("local").setAppName("executor suite test") - val sc = new SparkContext(conf) + sc = new SparkContext(conf) val serializer = SparkEnv.get.closureSerializer.newInstance() val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size From 14f5125f78ee8accac570603ea31c2692641935f Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 14 Feb 2017 10:32:20 -0600 Subject: [PATCH 06/12] review feedback --- .../org/apache/spark/TaskContextImpl.scala | 6 +-- .../org/apache/spark/executor/Executor.scala | 17 ++++---- .../org/apache/spark/scheduler/Task.scala | 4 +- .../apache/spark/executor/ExecutorSuite.scala | 41 ++++++++++--------- 4 files changed, 35 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index d9cf48dbccc5..dc0d12878550 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -59,7 +59,7 @@ private[spark] class TaskContextImpl( // If there was a fetch failure in the task, we store it here, to make sure user-code doesn't // hide the exception. 
See SPARK-19276 - @volatile private var _fetchFailed: Option[FetchFailedException] = None + @volatile private var _fetchFailedException: Option[FetchFailedException] = None override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { onCompleteCallbacks += listener @@ -132,9 +132,9 @@ private[spark] class TaskContextImpl( } private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = { - this._fetchFailed = Some(fetchFailed) + this._fetchFailedException = Option(fetchFailed) } - private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailed + private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 025d2d74704c..3d29a1b9bbeb 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -348,7 +348,7 @@ private[spark] class Executor( // unlikely). So we will log an error and keep going. logError(s"TID ${taskId} completed successfully though internally it encountered " + s"unrecoverable fetch failures! Most likely this means user code is incorrectly " + - s"swallowing Spark's internal exceptions", fetchFailure) + s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure) } val taskFinish = System.currentTimeMillis() val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { @@ -410,15 +410,16 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { - case ffe: FetchFailedException => - val reason = ffe.toTaskFailedReason - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - case t: Throwable if hasFetchFailure => - // tbere was a fetch failure in the task, but some user code wrapped that exception - // and threw something else. Regardless, we treat it as a fetch failure. val reason = task.context.fetchFailed.get.toTaskFailedReason + if (!t.isInstanceOf[FetchFailedException]) { + // there was a fetch failure in the task, but some user code wrapped that exception + // and threw something else. Regardless, we treat it as a fetch failure. + logWarning(s"TID ${taskId} encountered a ${classOf[FetchFailedException]} and " + + s"failed, but did not directly throw the ${classOf[FetchFailedException]}. " + + s"Spark is still handling the fetch failure, but these exceptions should not be " + + s"intercepted by user code.") + } setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 189cb7c849cf..70213722aae4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -132,8 +132,8 @@ private[spark] abstract class Task[T]( memoryManager.synchronized { memoryManager.notifyAll() } } } finally { - // though we unset the ThreadLocal here, the context itself is still queried directly - // in the TaskRunner to check for FetchFailedExceptions + // Though we unset the ThreadLocal here, the context member variable itself is still queried + // directly in the TaskRunner to check for FetchFailedExceptions. 
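The comment above is easier to see with a small, generic Scala sketch (not Spark code): clearing a ThreadLocal slot does not erase state stored on the object it pointed to, and the TaskRunner keeps its own direct reference to the TaskContextImpl.

object ThreadLocalVsFieldSketch extends App {
  class Ctx { @volatile var fetchFailed: Option[Throwable] = None }
  val slot = new ThreadLocal[Ctx]          // stands in for the TaskContext ThreadLocal

  val ctx = new Ctx                        // the TaskRunner also holds a direct reference like this
  slot.set(ctx)
  slot.get.fetchFailed = Some(new Exception("fetch failed"))  // recorded through the ThreadLocal
  slot.remove()                            // analogous to TaskContext.unset()
  assert(ctx.fetchFailed.isDefined)        // still visible through the direct reference
}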
TaskContext.unset() } } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index ecd97080d1d5..4079a168465d 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -23,12 +23,14 @@ import java.util.Properties import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.collection.mutable.Map +import scala.concurrent.duration._ import org.mockito.ArgumentCaptor import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.{inOrder, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.concurrent.Eventually import org.scalatest.mock.MockitoSugar import org.apache.spark._ @@ -42,15 +44,15 @@ import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId -class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { +class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually { test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") { // mock some objects to make Executor.launchTask() happy val conf = new SparkConf val serializer = new JavaSerializer(conf) - val env = mockEnv(conf, serializer) + val env = createMockEnv(conf, serializer) val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0)) - val taskDescription = fakeTaskDescription(serializedTask) + val taskDescription = createFakeTaskDescription(serializedTask) // we use latches to force the program to run in this order: // +-----------------------------+---------------------------------------+ @@ -124,13 +126,16 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } } - test("SPARK-19276: Handle Fetch Failed for all intervening user code") { + test("SPARK-19276: Handle FetchFailedExceptions that are hidden by user exceptions") { val conf = new SparkConf().setMaster("local").setAppName("executor suite test") sc = new SparkContext(conf) - val serializer = SparkEnv.get.closureSerializer.newInstance() val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size - val inputRDD = new FakeShuffleRDD(sc) + + // Submit a job where a fetch failure is thrown, but user code has a try/catch which hides + // the fetch failure. The executor should still tell the driver that the task failed due to a + // fetch failure, not a generic exception from user code. 
+ val inputRDD = new FetchFailureThrowingRDD(sc) val secondRDD = new FetchFailureHidingRDD(sc, inputRDD) val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() @@ -146,8 +151,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug ) val serTask = serializer.serialize(task) - val taskDescription = fakeTaskDescription(serTask) - + val taskDescription = createFakeTaskDescription(serTask) val failReason = runTaskAndGetFailReason(taskDescription) assert(failReason.isInstanceOf[FetchFailed]) @@ -156,9 +160,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug test("Gracefully handle error in task deserialization") { val conf = new SparkConf val serializer = new JavaSerializer(conf) - val env = mockEnv(conf, serializer) + val env = createMockEnv(conf, serializer) val serializedTask = serializer.newInstance().serialize(new NonDeserializableTask) - val taskDescription = fakeTaskDescription(serializedTask) + val taskDescription = createFakeTaskDescription(serializedTask) val failReason = runTaskAndGetFailReason(taskDescription) failReason match { @@ -166,11 +170,11 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug assert(ef.exception.isDefined) assert(ef.exception.get.getMessage() === "failure in deserialization") case _ => - fail("unexpected failure type: $failReason") + fail(s"unexpected failure type: $failReason") } } - private def mockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = { + private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = { val mockEnv = mock[SparkEnv] val mockRpcEnv = mock[RpcEnv] val mockMetricsSystem = mock[MetricsSystem] @@ -185,7 +189,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug mockEnv } - private def fakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = { + private def createFakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = { new TaskDescription( taskId = 0, attemptNumber = 0, @@ -205,12 +209,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true) // the task will be launched in a dedicated worker thread executor.launchTask(mockBackend, taskDescription) - val startTime = System.currentTimeMillis() - val maxTime = startTime + 5000 - while (executor.numRunningTasks > 0 && System.currentTimeMillis() < maxTime) { - Thread.sleep(10) + eventually(timeout(5 seconds), interval(10 milliseconds)) { + assert(executor.numRunningTasks === 0) } - assert(executor.numRunningTasks === 0) } finally { if (executor != null) { executor.stop() @@ -230,7 +231,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } } -class FakeShuffleRDD(sc: SparkContext) extends RDD[Int](sc, Nil) { +class FetchFailureThrowingRDD(sc: SparkContext) extends RDD[Int](sc, Nil) { override def compute(split: Partition, context: TaskContext): Iterator[Int] = { new Iterator[Int] { override def hasNext: Boolean = true @@ -256,7 +257,7 @@ class SimplePartition extends Partition { class FetchFailureHidingRDD( sc: SparkContext, - val input: FakeShuffleRDD) extends RDD[Int](input) { + val input: FetchFailureThrowingRDD) extends RDD[Int](input) { override def compute(split: Partition, context: TaskContext): Iterator[Int] = { val inItr = input.compute(split, 
context) try { From 08491c5b52555924431571dfe8beb694d5721161 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 14 Feb 2017 11:14:20 -0600 Subject: [PATCH 07/12] move task deserialization case into its own issue --- .../apache/spark/executor/ExecutorSuite.scala | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 4079a168465d..4e44c91fd313 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -157,23 +157,6 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug assert(failReason.isInstanceOf[FetchFailed]) } - test("Gracefully handle error in task deserialization") { - val conf = new SparkConf - val serializer = new JavaSerializer(conf) - val env = createMockEnv(conf, serializer) - val serializedTask = serializer.newInstance().serialize(new NonDeserializableTask) - val taskDescription = createFakeTaskDescription(serializedTask) - - val failReason = runTaskAndGetFailReason(taskDescription) - failReason match { - case ef: ExceptionFailure => - assert(ef.exception.isDefined) - assert(ef.exception.get.getMessage() === "failure in deserialization") - case _ => - fail(s"unexpected failure type: $failReason") - } - } - private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = { val mockEnv = mock[SparkEnv] val mockRpcEnv = mock[RpcEnv] From 2a497057616951ec066b1bbf10e0f7328e4d8572 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 28 Feb 2017 10:45:51 -0600 Subject: [PATCH 08/12] review feedback --- .../main/scala/org/apache/spark/executor/Executor.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 2f586b0f725f..1ceea5e85424 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -415,10 +415,11 @@ private[spark] class Executor( if (!t.isInstanceOf[FetchFailedException]) { // there was a fetch failure in the task, but some user code wrapped that exception // and threw something else. Regardless, we treat it as a fetch failure. - logWarning(s"TID ${taskId} encountered a ${classOf[FetchFailedException]} and " + - s"failed, but did not directly throw the ${classOf[FetchFailedException]}. " + - s"Spark is still handling the fetch failure, but these exceptions should not be " + - s"intercepted by user code.") + val fetchFailedCls = classOf[FetchFailedException].getName + logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " + + s"failed, but the ${fetchFailedCls} was hidden by another " + + s"exception. 
Spark is handling this like a fetch failure and ignoring the " + + s"other exception: $t") } setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) From 84eae146daf6ae89aa997d272714aeeb98eb818c Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 28 Feb 2017 10:46:05 -0600 Subject: [PATCH 09/12] review feedback --- core/src/main/scala/org/apache/spark/executor/Executor.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 1ceea5e85424..9057dd171ddc 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -410,7 +410,7 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { - case t: Throwable if hasFetchFailure => + case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => val reason = task.context.fetchFailed.get.toTaskFailedReason if (!t.isInstanceOf[FetchFailedException]) { // there was a fetch failure in the task, but some user code wrapped that exception From bee562174b284f5985f08cdeb9ced985970a01d6 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 28 Feb 2017 11:02:38 -0600 Subject: [PATCH 10/12] unit test for OOM with a fetchfailure --- .../org/apache/spark/executor/Executor.scala | 8 ++- .../apache/spark/executor/ExecutorSuite.scala | 67 +++++++++++++++++-- 2 files changed, 65 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9057dd171ddc..790c1ae94247 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -18,6 +18,7 @@ package org.apache.spark.executor import java.io.{File, NotSerializableException} +import java.lang.Thread.UncaughtExceptionHandler import java.lang.management.ManagementFactory import java.net.{URI, URL} import java.nio.ByteBuffer @@ -52,7 +53,8 @@ private[spark] class Executor( executorHostname: String, env: SparkEnv, userClassPath: Seq[URL] = Nil, - isLocal: Boolean = false) + isLocal: Boolean = false, + uncaughtExceptionHandler: UncaughtExceptionHandler = SparkUncaughtExceptionHandler) extends Logging { logInfo(s"Starting executor ID $executorId on host $executorHostname") @@ -78,7 +80,7 @@ private[spark] class Executor( // Setup an uncaught exception handler for non-local mode. // Make any thread terminations due to uncaught exceptions kill the entire // executor process to avoid surprising stalls. - Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler) + Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler) } // Start worker thread pool @@ -472,7 +474,7 @@ private[spark] class Executor( // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. 
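As a point of reference for the guard below (and for the !Utils.isFatalError(t) guard added to the fetch-failure case above), the fatal/non-fatal split follows scala.util.control.NonFatal: a user RuntimeException wrapping a fetch failure is non-fatal, while an OutOfMemoryError is fatal. A rough, self-contained sketch -- not the actual Utils implementation, which treats a few additional control-flow throwables as non-fatal:

import scala.util.control.NonFatal

object FatalErrorSketch extends App {
  // Approximation only: NonFatal does not match VirtualMachineError subclasses
  // such as OutOfMemoryError, so they are treated as fatal here.
  def looksFatal(t: Throwable): Boolean = t match {
    case NonFatal(_) => false
    case _           => true
  }

  assert(!looksFatal(new RuntimeException("user wrapper around a fetch failure")))
  assert(looksFatal(new OutOfMemoryError("OOM while handling another exception")))
}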
if (Utils.isFatalError(t)) { - SparkUncaughtExceptionHandler.uncaughtException(t) + uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t) } } finally { diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index e16fb0287f8b..b731ec92d9e3 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.executor import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.lang.Thread.UncaughtExceptionHandler import java.nio.ByteBuffer import java.util.Properties import java.util.concurrent.{CountDownLatch, TimeUnit} @@ -27,7 +28,7 @@ import scala.concurrent.duration._ import org.mockito.ArgumentCaptor import org.mockito.Matchers.{any, eq => meq} -import org.mockito.Mockito.{inOrder, when} +import org.mockito.Mockito.{inOrder, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.concurrent.Eventually @@ -136,7 +137,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug // the fetch failure. The executor should still tell the driver that the task failed due to a // fetch failure, not a generic exception from user code. val inputRDD = new FetchFailureThrowingRDD(sc) - val secondRDD = new FetchFailureHidingRDD(sc, inputRDD) + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false) val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() val task = new ResultTask( @@ -157,6 +158,44 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug assert(failReason.isInstanceOf[FetchFailed]) } + test("SPARK-19276: OOMs correctly handled with a FetchFailure") { + // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it + // may be a false positive. And we should call the uncaught exception handler. + val conf = new SparkConf().setMaster("local").setAppName("executor suite test") + sc = new SparkContext(conf) + val serializer = SparkEnv.get.closureSerializer.newInstance() + val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size + + // Submit a job where a fetch failure is thrown, but user code has a try/catch which hides + // the fetch failure. The executor should still tell the driver that the task failed due to a + // fetch failure, not a generic exception from user code. 
+ val inputRDD = new FetchFailureThrowingRDD(sc) + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true) + val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) + val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() + val task = new ResultTask( + stageId = 1, + stageAttemptId = 0, + taskBinary = taskBinary, + partition = secondRDD.partitions(0), + locs = Seq(), + outputId = 0, + localProperties = new Properties(), + serializedTaskMetrics = serializedTaskMetrics + ) + + val serTask = serializer.serialize(task) + val taskDescription = createFakeTaskDescription(serTask) + + val (failReason, uncaughtExceptionHandler) = + runTaskGetFailReasonAndExceptionHandler(taskDescription) + assert(failReason.isInstanceOf[ExceptionFailure]) + val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]) + verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture()) + assert(exceptionCaptor.getAllValues.size === 1) + assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError]) + } + test("Gracefully handle error in task deserialization") { val conf = new SparkConf val serializer = new JavaSerializer(conf) @@ -203,13 +242,20 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = { + runTaskGetFailReasonAndExceptionHandler(taskDescription)._1 + } + + private def runTaskGetFailReasonAndExceptionHandler( + taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = { val mockBackend = mock[ExecutorBackend] + val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler] var executor: Executor = null try { - executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true) + executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true, + uncaughtExceptionHandler = mockUncaughtExceptionHandler) // the task will be launched in a dedicated worker thread executor.launchTask(mockBackend, taskDescription) - eventually(timeout(5 seconds), interval(10 milliseconds)) { + eventually(timeout(5.seconds), interval(10.milliseconds)) { assert(executor.numRunningTasks === 0) } } finally { @@ -227,7 +273,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug assert(statusCaptor.getAllValues().get(0).remaining() === 0) // second update is more interesting val failureData = statusCaptor.getAllValues.get(1) - SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData) + val failReason = + SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData) + (failReason, mockUncaughtExceptionHandler) } } @@ -257,14 +305,19 @@ class SimplePartition extends Partition { class FetchFailureHidingRDD( sc: SparkContext, - val input: FetchFailureThrowingRDD) extends RDD[Int](input) { + val input: FetchFailureThrowingRDD, + throwOOM: Boolean) extends RDD[Int](input) { override def compute(split: Partition, context: TaskContext): Iterator[Int] = { val inItr = input.compute(split, context) try { Iterator(inItr.size) } catch { case t: Throwable => - throw new RuntimeException("User Exception that hides the original exception", t) + if (throwOOM) { + throw new OutOfMemoryError("OOM while handling another exception") + } else { + throw new RuntimeException("User Exception that hides the original exception", t) + } } } From 
ad47611e47210d86c3d8609414b2d2c1b59ececd Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 28 Feb 2017 11:11:25 -0600 Subject: [PATCH 11/12] reword comment --- .../scala/org/apache/spark/shuffle/FetchFailedException.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index a94ee450ae43..265a8acfa8d6 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -52,8 +52,8 @@ private[spark] class FetchFailedException( // SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code // which intercepts this exception (possibly wrapping it), the Executor can still tell there was - // a fetch failure, and send the correct error msg back to the driver. The TaskContext won't be - // defined if this is run on the driver (just in test cases) -- we can safely ignore then. + // a fetch failure, and send the correct error msg back to the driver. We wrap with an Option + // because the TaskContext is not defined in some test cases. Option(TaskContext.get()).map(_.setFetchFailed(this)) def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, From 965506a1acdea1ede67df24605cc361edd155d7a Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 2 Mar 2017 11:47:31 -0600 Subject: [PATCH 12/12] update comments --- .../scala/org/apache/spark/executor/ExecutorSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index b731ec92d9e3..8150fff2d018 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -166,9 +166,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val serializer = SparkEnv.get.closureSerializer.newInstance() val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size - // Submit a job where a fetch failure is thrown, but user code has a try/catch which hides - // the fetch failure. The executor should still tell the driver that the task failed due to a - // fetch failure, not a generic exception from user code. + // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat + // the fetch failure as a false positive, and just do normal OOM handling. val inputRDD = new FetchFailureThrowingRDD(sc) val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true) val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) @@ -189,6 +188,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val (failReason, uncaughtExceptionHandler) = runTaskGetFailReasonAndExceptionHandler(taskDescription) + // make sure the task failure just looks like a OOM, not a fetch failure assert(failReason.isInstanceOf[ExceptionFailure]) val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]) verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())