-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-19276][CORE] Fetch Failure handling robust to user error handling #16639
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
b93c37f
9635980
0a60aef
bbef893
730fd83
4494673
14f5125
08491c5
7840480
22da707
2a49705
84eae14
bee5621
ad47611
965506a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -148,6 +148,8 @@ private[spark] class Executor( | |
|
|
||
| startDriverHeartbeater() | ||
|
|
||
| private[executor] def numRunningTasks: Int = runningTasks.size() | ||
|
|
||
| def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = { | ||
| val tr = new TaskRunner(context, taskDescription) | ||
| runningTasks.put(taskDescription.taskId, tr) | ||
|
|
@@ -340,6 +342,14 @@ private[spark] class Executor( | |
| } | ||
| } | ||
| } | ||
| task.context.fetchFailed.foreach { fetchFailure => | ||
| // uh-oh. it appears the user code has caught the fetch-failure without throwing any | ||
| // other exceptions. It's *possible* this is what the user meant to do (though highly | ||
| // unlikely). So we will log an error and keep going. | ||
| logError(s"TID ${taskId} completed successfully though internally it encountered " + | ||
| s"unrecoverable fetch failures! Most likely this means user code is incorrectly " + | ||
| s"swallowing Spark's internal exceptions", fetchFailure) | ||
|
||
| } | ||
| val taskFinish = System.currentTimeMillis() | ||
| val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { | ||
| threadMXBean.getCurrentThreadCpuTime | ||
|
|
@@ -405,6 +415,13 @@ private[spark] class Executor( | |
| setTaskFinishedAndClearInterruptStatus() | ||
| execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) | ||
|
|
||
| case t: Throwable if task.context.fetchFailed.isDefined => | ||
|
||
| // there was a fetch failure in the task, but some user code wrapped that exception | ||
| // and threw something else. Regardless, we treat it as a fetch failure. | ||
| val reason = task.context.fetchFailed.get.toTaskFailedReason | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. tiny nit: but does it make sense to store the taskFailedReason (rather than the actual exception) in the task context? |
||
| setTaskFinishedAndClearInterruptStatus() | ||
| execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Probably log a similar message as above ?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you mean the msg I added about "TID ${taskId} completed successfully though internally it encountered unrecoverable fetch failures!"? I wouldn't think we'd want to log anything special here. I'm trying to make this a "normal" code path. The user is allowed to do this. (sparksql already does.) we could log a warning, but then this change should be accompanied by auditing the code and making sure we never do this ourselves.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, something along those lines ... |
||
|
|
||
| case _: TaskKilledException => | ||
| logInfo(s"Executor killed $taskName (TID $taskId)") | ||
| setTaskFinishedAndClearInterruptStatus() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,19 +17,14 @@ | |
|
|
||
| package org.apache.spark.scheduler | ||
|
|
||
| import java.io.{DataInputStream, DataOutputStream} | ||
| import java.nio.ByteBuffer | ||
| import java.util.Properties | ||
|
|
||
| import scala.collection.mutable | ||
| import scala.collection.mutable.HashMap | ||
|
|
||
| import org.apache.spark._ | ||
| import org.apache.spark.executor.TaskMetrics | ||
| import org.apache.spark.internal.config.APP_CALLER_CONTEXT | ||
| import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} | ||
| import org.apache.spark.metrics.MetricsSystem | ||
| import org.apache.spark.serializer.SerializerInstance | ||
| import org.apache.spark.util._ | ||
|
|
||
| /** | ||
|
|
@@ -137,6 +132,8 @@ private[spark] abstract class Task[T]( | |
| memoryManager.synchronized { memoryManager.notifyAll() } | ||
| } | ||
| } finally { | ||
| // though we unset the ThreadLocal here, the context itself is still queried directly | ||
|
||
| // in the TaskRunner to check for FetchFailedExceptions | ||
| TaskContext.unset() | ||
| } | ||
| } | ||
|
|
@@ -156,7 +153,7 @@ private[spark] abstract class Task[T]( | |
| var epoch: Long = -1 | ||
|
|
||
| // Task context, to be initialized in run(). | ||
| @transient protected var context: TaskContextImpl = _ | ||
| @transient var context: TaskContextImpl = _ | ||
|
|
||
| // The actual Thread on which the task is running, if any. Initialized in run(). | ||
| @volatile @transient private var taskThread: Thread = _ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,7 @@ | |
|
|
||
| package org.apache.spark.shuffle | ||
|
|
||
| import org.apache.spark.{FetchFailed, TaskFailedReason} | ||
| import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason} | ||
| import org.apache.spark.storage.BlockManagerId | ||
| import org.apache.spark.util.Utils | ||
|
|
||
|
|
@@ -45,6 +45,12 @@ private[spark] class FetchFailedException( | |
| this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause) | ||
| } | ||
|
|
||
| // SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code | ||
| // which intercepts this exception (possibly wrapping it), the Executor can still tell there was | ||
| // a fetch failure, and send the correct error msg back to the driver. The TaskContext won't be | ||
| // defined if this is run on the driver (just in test cases) -- we can safely ignore then. | ||
|
||
| Option(TaskContext.get()).map(_.setFetchFailed(this)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since creation of an Exception does not necessarily mean it should get thrown - we must explicitly add this expectation to the documentation/contract of FetchFailedException constructor - indicating that we expect it to be created only for it to be thrown immediately.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, good point. I added to the docs, does it look OK? I also considered making the call to |
||
|
|
||
| def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, | ||
| Utils.exceptionString(this)) | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,29 +23,34 @@ import java.util.concurrent.CountDownLatch | |
|
|
||
| import scala.collection.mutable.Map | ||
|
|
||
| import org.mockito.Matchers._ | ||
| import org.mockito.Mockito.{mock, when} | ||
| import org.mockito.ArgumentCaptor | ||
| import org.mockito.Matchers.{any, eq => meq} | ||
| import org.mockito.Mockito.{inOrder, when} | ||
| import org.mockito.invocation.InvocationOnMock | ||
| import org.mockito.stubbing.Answer | ||
| import org.scalatest.mock.MockitoSugar | ||
|
|
||
| import org.apache.spark._ | ||
| import org.apache.spark.TaskState.TaskState | ||
| import org.apache.spark.memory.MemoryManager | ||
| import org.apache.spark.metrics.MetricsSystem | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.rpc.RpcEnv | ||
| import org.apache.spark.scheduler.{FakeTask, TaskDescription} | ||
| import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription} | ||
| import org.apache.spark.serializer.JavaSerializer | ||
| import org.apache.spark.shuffle.FetchFailedException | ||
| import org.apache.spark.storage.BlockManagerId | ||
|
|
||
| class ExecutorSuite extends SparkFunSuite { | ||
| class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { | ||
|
|
||
| test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") { | ||
| // mock some objects to make Executor.launchTask() happy | ||
| val conf = new SparkConf | ||
| val serializer = new JavaSerializer(conf) | ||
| val mockEnv = mock(classOf[SparkEnv]) | ||
| val mockRpcEnv = mock(classOf[RpcEnv]) | ||
| val mockMetricsSystem = mock(classOf[MetricsSystem]) | ||
| val mockMemoryManager = mock(classOf[MemoryManager]) | ||
| val mockEnv = mock[SparkEnv] | ||
| val mockRpcEnv = mock[RpcEnv] | ||
| val mockMetricsSystem = mock[MetricsSystem] | ||
| val mockMemoryManager = mock[MemoryManager] | ||
| when(mockEnv.conf).thenReturn(conf) | ||
| when(mockEnv.serializer).thenReturn(serializer) | ||
| when(mockEnv.rpcEnv).thenReturn(mockRpcEnv) | ||
|
|
@@ -55,16 +60,7 @@ class ExecutorSuite extends SparkFunSuite { | |
| val fakeTaskMetrics = serializer.newInstance().serialize(TaskMetrics.registered).array() | ||
| val serializedTask = serializer.newInstance().serialize( | ||
| new FakeTask(0, 0, Nil, fakeTaskMetrics)) | ||
| val taskDescription = new TaskDescription( | ||
| taskId = 0, | ||
| attemptNumber = 0, | ||
| executorId = "", | ||
| name = "", | ||
| index = 0, | ||
| addedFiles = Map[String, Long](), | ||
| addedJars = Map[String, Long](), | ||
| properties = new Properties, | ||
| serializedTask) | ||
| val taskDescription = fakeTaskDescription(serializedTask) | ||
|
|
||
| // we use latches to force the program to run in this order: | ||
| // +-----------------------------+---------------------------------------+ | ||
|
|
@@ -86,7 +82,7 @@ class ExecutorSuite extends SparkFunSuite { | |
|
|
||
| val executorSuiteHelper = new ExecutorSuiteHelper | ||
|
|
||
| val mockExecutorBackend = mock(classOf[ExecutorBackend]) | ||
| val mockExecutorBackend = mock[ExecutorBackend] | ||
| when(mockExecutorBackend.statusUpdate(any(), any(), any())) | ||
| .thenAnswer(new Answer[Unit] { | ||
| var firstTime = true | ||
|
|
@@ -133,6 +129,116 @@ class ExecutorSuite extends SparkFunSuite { | |
| } | ||
| } | ||
| } | ||
|
|
||
| test("SPARK-19276: Handle Fetch Failed for all intervening user code") { | ||
|
||
| val conf = new SparkConf().setMaster("local").setAppName("executor suite test") | ||
| val sc = new SparkContext(conf) | ||
|
|
||
| val serializer = SparkEnv.get.closureSerializer.newInstance() | ||
| val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size | ||
| val inputRDD = new FakeShuffleRDD(sc) | ||
| val secondRDD = new FetchFailureHidingRDD(sc, inputRDD) | ||
| val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) | ||
| val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() | ||
| val task = new ResultTask( | ||
| stageId = 1, | ||
| stageAttemptId = 0, | ||
| taskBinary = taskBinary, | ||
| partition = secondRDD.partitions(0), | ||
| locs = Seq(), | ||
| outputId = 0, | ||
| localProperties = new Properties(), | ||
| serializedTaskMetrics = serializedTaskMetrics | ||
| ) | ||
|
|
||
| val serTask = serializer.serialize(task) | ||
| val taskDescription = fakeTaskDescription(serTask) | ||
|
|
||
|
||
|
|
||
| val mockBackend = mock[ExecutorBackend] | ||
| var executor: Executor = null | ||
| try { | ||
| executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true) | ||
| executor.launchTask(mockBackend, taskDescription) | ||
| val startTime = System.currentTimeMillis() | ||
| val maxTime = startTime + 5000 | ||
| while (executor.numRunningTasks > 0 && System.currentTimeMillis() < maxTime) { | ||
|
||
| Thread.sleep(10) | ||
| } | ||
| val orderedMock = inOrder(mockBackend) | ||
| val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) | ||
| orderedMock.verify(mockBackend) | ||
| .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture()) | ||
| orderedMock.verify(mockBackend) | ||
| .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture()) | ||
| // first statusUpdate for RUNNING has empty data | ||
| assert(statusCaptor.getAllValues().get(0).remaining() === 0) | ||
| // second update is more interesting | ||
| val failureData = statusCaptor.getAllValues.get(1) | ||
| val failReason = serializer.deserialize[TaskFailedReason](failureData) | ||
| assert(failReason.isInstanceOf[FetchFailed]) | ||
| } finally { | ||
| if (executor != null) { | ||
| executor.stop() | ||
| } | ||
| } | ||
| } | ||
|
|
||
| private def fakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = { | ||
| new TaskDescription( | ||
| taskId = 0, | ||
| attemptNumber = 0, | ||
| executorId = "", | ||
| name = "", | ||
| index = 0, | ||
| addedFiles = Map[String, Long](), | ||
| addedJars = Map[String, Long](), | ||
| properties = new Properties, | ||
| serializedTask) | ||
| } | ||
|
|
||
| } | ||
|
|
||
| class FakeShuffleRDD(sc: SparkContext) extends RDD[Int](sc, Nil) { | ||
|
||
| override def compute(split: Partition, context: TaskContext): Iterator[Int] = { | ||
| new Iterator[Int] { | ||
| override def hasNext: Boolean = true | ||
| override def next(): Int = { | ||
| throw new FetchFailedException( | ||
| bmAddress = BlockManagerId("1", "hostA", 1234), | ||
| shuffleId = 0, | ||
| mapId = 0, | ||
| reduceId = 0, | ||
| message = "fake fetch failure" | ||
| ) | ||
| } | ||
| } | ||
| } | ||
| override protected def getPartitions: Array[Partition] = { | ||
| Array(new SimplePartition) | ||
| } | ||
| } | ||
|
|
||
| class SimplePartition extends Partition { | ||
| override def index: Int = 0 | ||
| } | ||
|
|
||
| class FetchFailureHidingRDD( | ||
| sc: SparkContext, | ||
| val input: FakeShuffleRDD) extends RDD[Int](input) { | ||
| override def compute(split: Partition, context: TaskContext): Iterator[Int] = { | ||
| val inItr = input.compute(split, context) | ||
| try { | ||
| Iterator(inItr.size) | ||
| } catch { | ||
| case t: Throwable => | ||
| throw new RuntimeException("User Exception that hides the original exception", t) | ||
| } | ||
| } | ||
|
|
||
| override protected def getPartitions: Array[Partition] = { | ||
| Array(new SimplePartition) | ||
| } | ||
| } | ||
|
|
||
| // Helps to test("SPARK-15963") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@volatile private ?