core/src/main/scala/org/apache/spark/executor/Executor.scala (2 additions, 0 deletions)

@@ -148,6 +148,8 @@ private[spark] class Executor(
 
   startDriverHeartbeater()
 
+  private[executor] def numRunningTasks: Int = runningTasks.size()
+
   def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
     val tr = new TaskRunner(context, taskDescription)
     runningTasks.put(taskDescription.taskId, tr)
core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala (106 additions, 33 deletions)

@@ -17,16 +17,21 @@
 
 package org.apache.spark.executor
 
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.nio.ByteBuffer
 import java.util.Properties
-import java.util.concurrent.CountDownLatch
+import java.util.concurrent.{CountDownLatch, TimeUnit}
 
 import scala.collection.mutable.Map
+import scala.concurrent.duration._
 
-import org.mockito.Matchers._
-import org.mockito.Mockito.{mock, when}
+import org.mockito.ArgumentCaptor
+import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.Mockito.{inOrder, when}
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
+import org.scalatest.concurrent.Eventually
+import org.scalatest.mock.MockitoSugar
 
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
@@ -36,35 +41,15 @@ import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.{FakeTask, TaskDescription}
 import org.apache.spark.serializer.JavaSerializer
 
-class ExecutorSuite extends SparkFunSuite {
+class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually {
 
   test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") {
     // mock some objects to make Executor.launchTask() happy
     val conf = new SparkConf
     val serializer = new JavaSerializer(conf)
-    val mockEnv = mock(classOf[SparkEnv])
-    val mockRpcEnv = mock(classOf[RpcEnv])
-    val mockMetricsSystem = mock(classOf[MetricsSystem])
-    val mockMemoryManager = mock(classOf[MemoryManager])
-    when(mockEnv.conf).thenReturn(conf)
-    when(mockEnv.serializer).thenReturn(serializer)
-    when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
-    when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem)
-    when(mockEnv.memoryManager).thenReturn(mockMemoryManager)
-    when(mockEnv.closureSerializer).thenReturn(serializer)
-    val fakeTaskMetrics = serializer.newInstance().serialize(TaskMetrics.registered).array()
-    val serializedTask = serializer.newInstance().serialize(
-      new FakeTask(0, 0, Nil, fakeTaskMetrics))
-    val taskDescription = new TaskDescription(
-      taskId = 0,
-      attemptNumber = 0,
-      executorId = "",
-      name = "",
-      index = 0,
-      addedFiles = Map[String, Long](),
-      addedJars = Map[String, Long](),
-      properties = new Properties,
-      serializedTask)
+    val env = createMockEnv(conf, serializer)
+    val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0))
+    val taskDescription = createFakeTaskDescription(serializedTask)
 
     // we use latches to force the program to run in this order:
     // +-----------------------------+---------------------------------------+
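The comment above introduces a latch handshake that pins down the interleaving between the test's main thread and the executor's worker thread (the rest of the table sits in unchanged lines outside this hunk). As a minimal standalone sketch of that pattern, with illustrative names that are not part of this diff:

    import java.util.concurrent.CountDownLatch

    object LatchOrderingSketch {
      def main(args: Array[String]): Unit = {
        val stepOneDone = new CountDownLatch(1)
        val mayProceed = new CountDownLatch(1)
        val worker = new Thread(new Runnable {
          override def run(): Unit = {
            println("worker: step 1")
            stepOneDone.countDown() // let the main thread observe step 1
            mayProceed.await()      // block until the main thread says go
            println("worker: step 2")
          }
        })
        worker.start()
        stepOneDone.await()    // runs strictly after "worker: step 1"
        println("main: runs between the worker's two steps")
        mayProceed.countDown() // unblock the worker
        worker.join()
      }
    }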
@@ -86,7 +71,7 @@ class ExecutorSuite extends SparkFunSuite {
 
     val executorSuiteHelper = new ExecutorSuiteHelper
 
-    val mockExecutorBackend = mock(classOf[ExecutorBackend])
+    val mockExecutorBackend = mock[ExecutorBackend]
     when(mockExecutorBackend.statusUpdate(any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         var firstTime = true
@@ -102,8 +87,8 @@
             val taskState = invocationOnMock.getArguments()(1).asInstanceOf[TaskState]
             executorSuiteHelper.taskState = taskState
             val taskEndReason = invocationOnMock.getArguments()(2).asInstanceOf[ByteBuffer]
-            executorSuiteHelper.testFailedReason
-              = serializer.newInstance().deserialize(taskEndReason)
+            executorSuiteHelper.testFailedReason =
+              serializer.newInstance().deserialize(taskEndReason)
             // let the main test thread check `taskState` and `testFailedReason`
             executorSuiteHelper.latch3.countDown()
           }
@@ -112,16 +97,20 @@
 
     var executor: Executor = null
     try {
-      executor = new Executor("id", "localhost", mockEnv, userClassPath = Nil, isLocal = true)
+      executor = new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true)
       // the task will be launched in a dedicated worker thread
       executor.launchTask(mockExecutorBackend, taskDescription)
 
-      executorSuiteHelper.latch1.await()
+      if (!executorSuiteHelper.latch1.await(5, TimeUnit.SECONDS)) {
+        fail("executor did not send first status update in time")
+      }
       // we know the task will be started, but not yet deserialized, because of the latches we
       // use in mockExecutorBackend.
       executor.killAllTasks(true)
       executorSuiteHelper.latch2.countDown()
-      executorSuiteHelper.latch3.await()
+      if (!executorSuiteHelper.latch3.await(5, TimeUnit.SECONDS)) {
+        fail("executor did not send second status update in time")
+      }
 
       // `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED`
       assert(executorSuiteHelper.testFailedReason === TaskKilled)
@@ -133,6 +122,79 @@
       }
     }
   }
+
+  test("Gracefully handle error in task deserialization") {
+    val conf = new SparkConf
+    val serializer = new JavaSerializer(conf)
+    val env = createMockEnv(conf, serializer)
+    val serializedTask = serializer.newInstance().serialize(new NonDeserializableTask)

Review comment from a Contributor on the line above: Can you take the error message as a param to NonDeserializableTask, so the magic string can be defined once and used below on line 137? From a quick read I thought the one on line 137 was coming from the Spark core code (not from the test code).
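A sketch of the suggested refactor (illustrative, not part of this diff): the message could be passed as a constructor param, but because Java deserializes an Externalizable instance through its no-arg constructor before calling readExternal (and writeExternal here writes nothing), the string still has to live in one shared place, which is why the diff keeps it in the companion object:

    private class NonDeserializableTask(errorMsg: String) extends FakeTask(0, 0) with Externalizable {
      // Externalizable requires a no-arg constructor; deserialization always comes
      // through here, so the companion object stays the single source of the message.
      def this() = this(NonDeserializableTask.errorMsg)
      def writeExternal(out: ObjectOutput): Unit = {}
      def readExternal(in: ObjectInput): Unit = {
        throw new RuntimeException(errorMsg)
      }
    }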

+    val taskDescription = createFakeTaskDescription(serializedTask)
+
+    val failReason = runTaskAndGetFailReason(taskDescription)
+    failReason match {
+      case ef: ExceptionFailure =>
+        assert(ef.exception.isDefined)
+        assert(ef.exception.get.getMessage() === NonDeserializableTask.errorMsg)
+      case _ =>
+        fail(s"unexpected failure type: $failReason")
+    }
+  }
+
+  private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
+    val mockEnv = mock[SparkEnv]
+    val mockRpcEnv = mock[RpcEnv]
+    val mockMetricsSystem = mock[MetricsSystem]
+    val mockMemoryManager = mock[MemoryManager]
+    when(mockEnv.conf).thenReturn(conf)
+    when(mockEnv.serializer).thenReturn(serializer)
+    when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
+    when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem)
+    when(mockEnv.memoryManager).thenReturn(mockMemoryManager)
+    when(mockEnv.closureSerializer).thenReturn(serializer)
+    SparkEnv.set(mockEnv)
+    mockEnv
+  }
+
+  private def createFakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = {
+    new TaskDescription(
+      taskId = 0,
+      attemptNumber = 0,
+      executorId = "",
+      name = "",
+      index = 0,
+      addedFiles = Map[String, Long](),
+      addedJars = Map[String, Long](),
+      properties = new Properties,
+      serializedTask)
+  }
+
+  private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
+    val mockBackend = mock[ExecutorBackend]
+    var executor: Executor = null
+    try {
+      executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
+      // the task will be launched in a dedicated worker thread
+      executor.launchTask(mockBackend, taskDescription)
+      eventually(timeout(5 seconds), interval(10 milliseconds)) {
+        assert(executor.numRunningTasks === 0)
+      }
+    } finally {
+      if (executor != null) {
+        executor.stop()
+      }
+    }
+    val orderedMock = inOrder(mockBackend)
+    val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
+    orderedMock.verify(mockBackend)
+      .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
+    orderedMock.verify(mockBackend)
+      .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
+    // first statusUpdate for RUNNING has empty data
+    assert(statusCaptor.getAllValues().get(0).remaining() === 0)
+    // second update is more interesting
+    val failureData = statusCaptor.getAllValues.get(1)
+    SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData)
+  }
 }
 
 // Helps to test("SPARK-15963")
@@ -145,3 +207,14 @@ private class ExecutorSuiteHelper {
   @volatile var taskState: TaskState = _
   @volatile var testFailedReason: TaskFailedReason = _
 }
+
+private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
+  def writeExternal(out: ObjectOutput): Unit = {}
+  def readExternal(in: ObjectInput): Unit = {
+    throw new RuntimeException(NonDeserializableTask.errorMsg)
+  }
+}
+
+private object NonDeserializableTask {
+  val errorMsg = "failure in deserialization"
+}