7 changes: 7 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.Source
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener}


@@ -190,4 +191,10 @@ abstract class TaskContext extends Serializable {
*/
private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit

/**
* Record that this task has failed due to a fetch failure from a remote host. This allows
* fetch-failure handling to get triggered by the driver, regardless of intervening user-code.
*/
private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit

}
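
For context, a hypothetical sketch (shuffledRdd and the lambda are placeholders, not code from this PR) of the kind of user code SPARK-19276 guards against: the task body catches and discards the fetch failure, so only a hook like setFetchFailed lets the executor still report it to the driver.

  // Hypothetical user code: forcing the shuffle read inside a try/catch swallows the
  // FetchFailedException, so without the new hook the driver would never see the failure
  // and could not regenerate the lost map output.
  shuffledRdd.mapPartitions { iter =>
    try {
      iter.toList.iterator   // materializing the iterator triggers the shuffle fetch
    } catch {
      case e: Exception => Iterator.empty   // hides the FetchFailedException from Spark
    }
  }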
11 changes: 11 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.metrics.source.Source
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util._

private[spark] class TaskContextImpl(
@@ -56,6 +57,10 @@ private[spark] class TaskContextImpl(
// Whether the task has failed.
@volatile private var failed: Boolean = false

// If there was a fetch failure in the task, we store it here, to make sure user-code doesn't
// hide the exception. See SPARK-19276
@volatile private var _fetchFailedException: Option[FetchFailedException] = None

override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
onCompleteCallbacks += listener
this
@@ -126,4 +131,10 @@ private[spark] class TaskContextImpl(
taskMetrics.registerAccumulator(a)
}

private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = {
this._fetchFailedException = Option(fetchFailed)
}

private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException

}
26 changes: 24 additions & 2 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -148,6 +148,8 @@ private[spark] class Executor(

startDriverHeartbeater()

private[executor] def numRunningTasks: Int = runningTasks.size()

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
val tr = new TaskRunner(context, taskDescription)
runningTasks.put(taskDescription.taskId, tr)
@@ -340,6 +342,14 @@ private[spark] class Executor(
}
}
}
task.context.fetchFailed.foreach { fetchFailure =>
// uh-oh. it appears the user code has caught the fetch-failure without throwing any
// other exceptions. It's *possible* this is what the user meant to do (though highly
// unlikely). So we will log an error and keep going.
logError(s"TID ${taskId} completed successfully though internally it encountered " +
s"unrecoverable fetch failures! Most likely this means user code is incorrectly " +
s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
}
val taskFinish = System.currentTimeMillis()
val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
@@ -400,8 +410,16 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

} catch {
case ffe: FetchFailedException =>
val reason = ffe.toTaskFailedReason
case t: Throwable if hasFetchFailure =>
@kayousterhout (Contributor), Feb 7, 2017:

Can the above case be eliminated with the addition of this one?

Contributor Author:

Yeah, as this is now, you could eliminate this -- I left it separate for now just to highlight that we can differentiate two special cases, which we could handle in a few different ways.

1. FetchFailed is thrown, and the task fails, but it's not the outer-most exception.

It seems clear that in this case we should fail the task with a FetchFailure. But do we also want to log an error or something indicating bad user code? Kinda minor, but might be a good idea. (Suggested by @mridulm above as well, I think.)

1a) Or the FetchFailed isn't part of the thrown exception at all. As I mentioned in my response to your other question, I'd like to consider this exactly the same as (1).

2. FetchFailed is thrown, but totally swallowed so the task succeeds.

Should we succeed the task, or fail it? I don't really know how this would happen. It seems really unlikely the user meant to do this. But then again, maybe the user did? I chose to just log an error but still succeed the task. (@markhamstra commented about this on the jira as well.)

It's pretty easy to change the code for whatever the desired behavior is -- just waiting for a clear decision.

Contributor:

I agree with Mridul's comment on (1) (that it would be nice to log a warning in this case) and your assessment of (2). To handle (1), you could have just this one case, and then log a warning if !t.isInstanceOf[FetchFailedException].

val reason = task.context.fetchFailed.get.toTaskFailedReason
Contributor:

Tiny nit: does it make sense to store the taskFailedReason (rather than the actual exception) in the task context?

if (!t.isInstanceOf[FetchFailedException]) {
// there was a fetch failure in the task, but some user code wrapped that exception
// and threw something else. Regardless, we treat it as a fetch failure.
logWarning(s"TID ${taskId} encountered a ${classOf[FetchFailedException]} and " +
s"failed, but did not directly throw the ${classOf[FetchFailedException]}. " +
s"Spark is still handling the fetch failure, but these exceptions should not be " +
s"intercepted by user code.")
Contributor Author:

@mridulm @kayousterhout how is this msg? Open to other suggestions -- I'm not sure exactly what to recommend to the user instead.

Contributor:

I worry that this is slightly misleading because there's not necessarily anything bad happening here (e.g., in the SQL case), and the user-thrown exception is getting permanently lost. What about something more like:

  logWarning(s"TID ${taskId} encountered a ${classOf[FetchFailedException]} and " +
    s"failed, but the ${classOf[FetchFailedException]} was hidden by another " +
    s"exception: $t.  Spark is handling this like a fetch failure and ignoring the " +
    s"other exception.")

Contributor:

@kayousterhout While I like the message, Spark SQL should not be catching that exception to begin with anyway.

Btw, the impact of ignoring the exception here also needs to be considered ... the "catch Throwable" block does some interesting things for accumulator updates and isFatalError. At least the latter must be handled here (an OOM being raised, for example) -- not sure about accumulator updates ...

Contributor:

I agree with @mridulm that it looks like these lines (473-475 below) need to be added here:

  if (Utils.isFatalError(t)) {
    SparkUncaughtExceptionHandler.uncaughtException(t)
  }

I'm less sure about the accumulator updates. It looks like the old code doesn't report accumulators for fetch failed exceptions, but it's not clear to me why we'd report them for some kinds of exceptions but not others. The simplest thing to do seems to be the current approach (since it roughly maintains the old behavior of not updating accumulators for fetch failures), but I don't have a good sense for why this is or is not correct.

Contributor Author:

Thanks, I like that msg better. I changed it slightly so the original exception is at the end; otherwise it's hard to tell where the original exception ends and you are back to the error msg. Here's what the new msg looks like from the test case now:

  17/02/27 16:33:43.953 Executor task launch worker for task 0 WARN Executor: TID 0 encountered a org.apache.spark.shuffle.FetchFailedException and failed, but the org.apache.spark.shuffle.FetchFailedException was hidden by another exception.  Spark is handling this like a fetch failure and ignoring the other exception: java.lang.RuntimeException: User Exception that hides the original exception

You have a good point about the uncaught exception handler; I have added that back. I wondered whether I should add those lines inside the case t: Throwable if hasFetchFailure block, or make it a condition for the case itself: case t: Throwable if hasFetchFailure && !Utils.isFatalError(t). I decided to make it part of the condition, since that is more like the old behavior, and a fetch failure that happens during an OOM may not be real.

I also looked into adding a unit test for this handling -- it requires some refactoring, potentially more work than it's worth, so I put it in a separate commit.

I'd rather avoid changing the behavior for accumulators here. Accumulators have such weird semantics that it's not clear what they should do; we can fix that separately if we really want to.
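
For reference, a rough sketch of the combined handler the author describes (hedged -- the exact code is in the updated diff, and the warning text is elided here):

  case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
    // fatal errors (e.g. OOM) fall through to the generic handler, as in the old behavior
    val reason = task.context.fetchFailed.get.toTaskFailedReason
    if (!t.isInstanceOf[FetchFailedException]) {
      logWarning(...)  // warning text elided; see the message discussed above
    }
    // then mark the task failed and report `reason` to the driver, as in the diff below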

}
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
Contributor:

Nit: probably log a similar message as above?

Contributor Author:

Do you mean the msg I added about "TID ${taskId} completed successfully though internally it encountered unrecoverable fetch failures!"? I wouldn't think we'd want to log anything special here. I'm trying to make this a "normal" code path -- the user is allowed to do this (Spark SQL already does).

We could log a warning, but then this change should be accompanied by auditing the code and making sure we never do this ourselves.

Contributor:

Yes, something along those lines ... and I agree, we should not be doing this ourselves either.


@@ -460,6 +478,10 @@ private[spark] class Executor(
runningTasks.remove(taskId)
}
}

private def hasFetchFailure: Boolean = {
task != null && task.context != null && task.context.fetchFailed.isDefined
}
}

/**
9 changes: 3 additions & 6 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -17,19 +17,14 @@

package org.apache.spark.scheduler

import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
import java.util.Properties

import scala.collection.mutable
import scala.collection.mutable.HashMap

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.config.APP_CALLER_CONTEXT
import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util._

/**
Expand Down Expand Up @@ -137,6 +132,8 @@ private[spark] abstract class Task[T](
memoryManager.synchronized { memoryManager.notifyAll() }
}
} finally {
// Though we unset the ThreadLocal here, the context member variable itself is still queried
// directly in the TaskRunner to check for FetchFailedExceptions.
TaskContext.unset()
}
}
@@ -156,7 +153,7 @@
var epoch: Long = -1

// Task context, to be initialized in run().
@transient protected var context: TaskContextImpl = _
@transient var context: TaskContextImpl = _

// The actual Thread on which the task is running, if any. Initialized in run().
@volatile @transient private var taskThread: Thread = _
core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -17,7 +17,7 @@

package org.apache.spark.shuffle

import org.apache.spark.{FetchFailed, TaskFailedReason}
import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason}
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils

@@ -26,6 +26,11 @@ import org.apache.spark.util.Utils
* back to DAGScheduler (through TaskEndReason) so we'd resubmit the previous stage.
*
* Note that bmAddress can be null.
*
* To prevent user code from hiding this fetch failure, in the constructor we call
* [[TaskContext.setFetchFailed()]]. This means that you *must* throw this exception immediately
* after creating it -- you cannot create it, check some condition, and then decide to ignore it
* (or risk triggering any other exceptions). See SPARK-19276.
*/
private[spark] class FetchFailedException(
bmAddress: BlockManagerId,
@@ -45,6 +50,12 @@
this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause)
}

// SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code
// which intercepts this exception (possibly wrapping it), the Executor can still tell there was
// a fetch failure, and send the correct error msg back to the driver. The TaskContext won't be
// defined if this is run on the driver (just in test cases) -- we can safely ignore then.
Contributor:

This last sentence is confusing. A task that runs locally on the driver can still hit fetch failures, right? Or are you saying the TaskContext will only be undefined in test cases?

Contributor Author:

Sorry, I've reworded this. The issue is that we have test cases where the TaskContext isn't defined, so we'd hit an NPE without the Option wrapper. But in general, the TaskContext should always be defined anytime we'd create a FetchFailure.

The alternative would be to track down the test cases without a TaskContext, and add one back.

Option(TaskContext.get()).map(_.setFetchFailed(this))
Contributor:

Since creating an exception does not necessarily mean it will get thrown, we must explicitly add this expectation to the documentation/contract of the FetchFailedException constructor -- indicating that we expect it to be created only in order to be thrown immediately. This should be fine since FetchFailedException is private[spark] right now.

Contributor Author:

Yes, good point. I added it to the docs -- does it look OK?

I also considered making the call to TaskContext.setFetchFailed live outside of the constructor, so that each site creating the exception would have to call it explicitly -- but I thought that seemed more dangerous.
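
A rough sketch of that alternative (hypothetical -- not what the PR does; the arguments follow the existing auxiliary constructor shown above):

  // every throw site would have to remember to pair the construction with the explicit call
  val ffe = new FetchFailedException(bmAddress, shuffleId, mapId, reduceId, cause)
  Option(TaskContext.get()).foreach(_.setFetchFailed(ffe))
  throw ffe

Recording the failure in the constructor keeps those steps from drifting apart, at the cost of the throw-immediately contract documented in the class comment.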


def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
Utils.exceptionString(this))
}