2 changes: 2 additions & 0 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -148,6 +148,8 @@ private[spark] class Executor(

startDriverHeartbeater()

private[executor] def numRunningTasks: Int = runningTasks.size()

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
val tr = new TaskRunner(context, taskDescription)
runningTasks.put(taskDescription.taskId, tr)
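The new accessor is scoped `private[executor]`, so a test suite in the same package can observe the executor's running-task count directly. A minimal sketch of the intended use (hypothetical test body; ExecutorSuite below does exactly this kind of polling):

    // from a test in org.apache.spark.executor: wait until the launched task finishes
    eventually(timeout(5 seconds), interval(10 milliseconds)) {
      assert(executor.numRunningTasks === 0)
    }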
135 changes: 102 additions & 33 deletions core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -17,16 +17,21 @@

package org.apache.spark.executor

import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.nio.ByteBuffer
import java.util.Properties
import java.util.concurrent.CountDownLatch
import java.util.concurrent.{CountDownLatch, TimeUnit}

import scala.collection.mutable.Map
import scala.concurrent.duration._

import org.mockito.Matchers._
import org.mockito.Mockito.{mock, when}
import org.mockito.ArgumentCaptor
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.{inOrder, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.concurrent.Eventually
import org.scalatest.mock.MockitoSugar

import org.apache.spark._
import org.apache.spark.TaskState.TaskState
@@ -36,35 +41,15 @@ import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.{FakeTask, TaskDescription}
import org.apache.spark.serializer.JavaSerializer

class ExecutorSuite extends SparkFunSuite {
class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually {

test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") {
// mock some objects to make Executor.launchTask() happy
val conf = new SparkConf
val serializer = new JavaSerializer(conf)
val mockEnv = mock(classOf[SparkEnv])
val mockRpcEnv = mock(classOf[RpcEnv])
val mockMetricsSystem = mock(classOf[MetricsSystem])
val mockMemoryManager = mock(classOf[MemoryManager])
when(mockEnv.conf).thenReturn(conf)
when(mockEnv.serializer).thenReturn(serializer)
when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem)
when(mockEnv.memoryManager).thenReturn(mockMemoryManager)
when(mockEnv.closureSerializer).thenReturn(serializer)
val fakeTaskMetrics = serializer.newInstance().serialize(TaskMetrics.registered).array()
val serializedTask = serializer.newInstance().serialize(
new FakeTask(0, 0, Nil, fakeTaskMetrics))
val taskDescription = new TaskDescription(
taskId = 0,
attemptNumber = 0,
executorId = "",
name = "",
index = 0,
addedFiles = Map[String, Long](),
addedJars = Map[String, Long](),
properties = new Properties,
serializedTask)
val env = createMockEnv(conf, serializer)
val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0))
val taskDescription = createFakeTaskDescription(serializedTask)

// we use latches to force the program to run in this order:
// +-----------------------------+---------------------------------------+
@@ -86,7 +71,7 @@
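    // (the full ordering table is collapsed in this hunk; the interleaving,
    //  inferred from the latch usage below, is roughly:
    //  worker thread: 1st statusUpdate -> latch1.countDown -> latch2.await
    //                 -> deserialize task -> TaskKilledException
    //                 -> 2nd statusUpdate -> latch3.countDown
    //  main thread:   launchTask -> latch1.await -> killAllTasks(true)
    //                 -> latch2.countDown -> latch3.await -> assertions)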

val executorSuiteHelper = new ExecutorSuiteHelper

val mockExecutorBackend = mock(classOf[ExecutorBackend])
val mockExecutorBackend = mock[ExecutorBackend]
when(mockExecutorBackend.statusUpdate(any(), any(), any()))
.thenAnswer(new Answer[Unit] {
var firstTime = true
@@ -102,8 +87,8 @@
val taskState = invocationOnMock.getArguments()(1).asInstanceOf[TaskState]
executorSuiteHelper.taskState = taskState
val taskEndReason = invocationOnMock.getArguments()(2).asInstanceOf[ByteBuffer]
executorSuiteHelper.testFailedReason
= serializer.newInstance().deserialize(taskEndReason)
executorSuiteHelper.testFailedReason =
serializer.newInstance().deserialize(taskEndReason)
// let the main test thread check `taskState` and `testFailedReason`
executorSuiteHelper.latch3.countDown()
}
@@ -112,16 +97,20 @@

var executor: Executor = null
try {
executor = new Executor("id", "localhost", mockEnv, userClassPath = Nil, isLocal = true)
executor = new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true)
// the task will be launched in a dedicated worker thread
executor.launchTask(mockExecutorBackend, taskDescription)

executorSuiteHelper.latch1.await()
if (!executorSuiteHelper.latch1.await(5, TimeUnit.SECONDS)) {
fail("executor did not send first status update in time")
}
// we know the task will be started, but not yet deserialized, because of the latches we
// use in mockExecutorBackend.
executor.killAllTasks(true)
executorSuiteHelper.latch2.countDown()
executorSuiteHelper.latch3.await()
if (!executorSuiteHelper.latch3.await(5, TimeUnit.SECONDS)) {
fail("executor did not send second status update in time")
}

// `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED`
assert(executorSuiteHelper.testFailedReason === TaskKilled)
@@ -133,6 +122,79 @@
}
}
}
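For context on what this test pins down: `killAllTasks(true)` interrupts the worker thread while it is still blocked before deserialization, the task path throws `TaskKilledException`, and `TaskRunner.run()` reports the kill back through the backend instead of marking the task FAILED. Roughly (a paraphrase of the relevant catch clause, not the verbatim Spark source):

    case _: TaskKilledException =>
      // this is the second statusUpdate the test observes: state KILLED, reason TaskKilled
      execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))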

test("Gracefully handle error in task deserialization") {
val conf = new SparkConf
val serializer = new JavaSerializer(conf)
val env = createMockEnv(conf, serializer)
val serializedTask = serializer.newInstance().serialize(new NonDeserializableTask)
Review comment (Contributor): can you take the error message as a param to NonDeserializableTask, so the magic string can be defined once and used below on line 137? From a quick read I thought the one on line 137 was coming from the Spark core code (not from the test code)

val taskDescription = createFakeTaskDescription(serializedTask)

val failReason = runTaskAndGetFailReason(taskDescription)
failReason match {
case ef: ExceptionFailure =>
assert(ef.exception.isDefined)
assert(ef.exception.get.getMessage() === "failure in deserialization")
case _ =>
fail(s"unexpected failure type: $failReason")
}
}

private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
val mockEnv = mock[SparkEnv]
val mockRpcEnv = mock[RpcEnv]
val mockMetricsSystem = mock[MetricsSystem]
val mockMemoryManager = mock[MemoryManager]
when(mockEnv.conf).thenReturn(conf)
when(mockEnv.serializer).thenReturn(serializer)
when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem)
when(mockEnv.memoryManager).thenReturn(mockMemoryManager)
when(mockEnv.closureSerializer).thenReturn(serializer)
SparkEnv.set(mockEnv)
mockEnv
}

private def createFakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = {
new TaskDescription(
taskId = 0,
attemptNumber = 0,
executorId = "",
name = "",
index = 0,
addedFiles = Map[String, Long](),
addedJars = Map[String, Long](),
properties = new Properties,
serializedTask)
}

private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
val mockBackend = mock[ExecutorBackend]
var executor: Executor = null
try {
executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
// the task will be launched in a dedicated worker thread
executor.launchTask(mockBackend, taskDescription)
eventually(timeout(5 seconds), interval(10 milliseconds)) {
assert(executor.numRunningTasks === 0)
}
} finally {
if (executor != null) {
executor.stop()
}
}
val orderedMock = inOrder(mockBackend)
val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
orderedMock.verify(mockBackend)
.statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
orderedMock.verify(mockBackend)
.statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
// first statusUpdate for RUNNING has empty data
assert(statusCaptor.getAllValues().get(0).remaining() === 0)
// second update is more interesting
val failureData = statusCaptor.getAllValues.get(1)
SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData)
}
}

// Helper for test("SPARK-15963")
@@ -145,3 +207,10 @@ private class ExecutorSuiteHelper {
@volatile var taskState: TaskState = _
@volatile var testFailedReason: TaskFailedReason = _
}

private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
def writeExternal(out: ObjectOutput): Unit = {}
def readExternal(in: ObjectInput): Unit = {
throw new RuntimeException("failure in deserialization")
}
}
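A sketch of the refactor the review comment above asks for, using a companion-object constant rather than a constructor param (Externalizable requires a public no-arg constructor, so a param would be bypassed during deserialization anyway; names here are hypothetical):

    private object NonDeserializableTask {
      val errorMsg = "failure in deserialization"
    }

    private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
      def writeExternal(out: ObjectOutput): Unit = {}
      def readExternal(in: ObjectInput): Unit = {
        // same constant the test asserts against, so the magic string is defined once
        throw new RuntimeException(NonDeserializableTask.errorMsg)
      }
    }

The assertion in "Gracefully handle error in task deserialization" would then become `assert(ef.exception.get.getMessage() === NonDeserializableTask.errorMsg)`.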