diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 7cf7a81a7613..97ffb36062db 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.executor import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.lang.Thread.UncaughtExceptionHandler +import java.net.URL import java.nio.ByteBuffer import java.util.Properties import java.util.concurrent.{ConcurrentHashMap, CountDownLatch, TimeUnit} @@ -53,7 +54,7 @@ import org.apache.spark.scheduler.{DirectTaskResult, FakeTask, ResultTask, Task, import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockManager, BlockManagerId} -import org.apache.spark.util.{LongAccumulator, ThreadUtils, UninterruptibleThread} +import org.apache.spark.util.{LongAccumulator, SparkUncaughtExceptionHandler, ThreadUtils, UninterruptibleThread} class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually with PrivateMethodTester { @@ -64,6 +65,33 @@ class ExecutorSuite extends SparkFunSuite super.afterEach() } + /** + * Creates an Executor with the provided arguments, is then passed to `f` + * and will be stopped after `f` returns. + */ + def withExecutor( + executorId: String, + executorHostname: String, + env: SparkEnv, + userClassPath: Seq[URL] = Nil, + isLocal: Boolean = true, + uncaughtExceptionHandler: UncaughtExceptionHandler + = new SparkUncaughtExceptionHandler, + resources: immutable.Map[String, ResourceInformation] + = immutable.Map.empty[String, ResourceInformation])(f: Executor => Unit): Unit = { + var executor: Executor = null + try { + executor = new Executor(executorId, executorHostname, env, userClassPath, isLocal, + uncaughtExceptionHandler, resources) + + f(executor) + } finally { + if (executor != null) { + executor.stop() + } + } + } + test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") { // mock some objects to make Executor.launchTask() happy val conf = new SparkConf @@ -116,10 +144,8 @@ class ExecutorSuite extends SparkFunSuite } }) - var executor: Executor = null - try { - executor = new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true, - resources = immutable.Map.empty[String, ResourceInformation]) + withExecutor("id", "localhost", env) { executor => + // the task will be launched in a dedicated worker thread executor.launchTask(mockExecutorBackend, taskDescription) @@ -139,11 +165,6 @@ class ExecutorSuite extends SparkFunSuite assert(executorSuiteHelper.testFailedReason.toErrorString === "TaskKilled (test)") assert(executorSuiteHelper.taskState === TaskState.KILLED) } - finally { - if (executor != null) { - executor.stop() - } - } } test("SPARK-19276: Handle FetchFailedExceptions that are hidden by user exceptions") { @@ -255,25 +276,24 @@ class ExecutorSuite extends SparkFunSuite confs.foreach { case (k, v) => conf.set(k, v) } val serializer = new JavaSerializer(conf) val env = createMockEnv(conf, serializer) - val executor = - new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true, - resources = immutable.Map.empty[String, ResourceInformation]) - val executorClass = classOf[Executor] - - // Save all heartbeats sent into an ArrayBuffer for verification - val heartbeats = ArrayBuffer[Heartbeat]() - val mockReceiver = mock[RpcEndpointRef] - when(mockReceiver.askSync(any[Heartbeat], any[RpcTimeout])(any)) - .thenAnswer((invocation: InvocationOnMock) => { - val args = invocation.getArguments() - heartbeats += args(0).asInstanceOf[Heartbeat] - HeartbeatResponse(false) - }) - val receiverRef = executorClass.getDeclaredField("heartbeatReceiverRef") - receiverRef.setAccessible(true) - receiverRef.set(executor, mockReceiver) + withExecutor("id", "localhost", SparkEnv.get) { executor => + val executorClass = classOf[Executor] + + // Save all heartbeats sent into an ArrayBuffer for verification + val heartbeats = ArrayBuffer[Heartbeat]() + val mockReceiver = mock[RpcEndpointRef] + when(mockReceiver.askSync(any[Heartbeat], any[RpcTimeout])(any)) + .thenAnswer((invocation: InvocationOnMock) => { + val args = invocation.getArguments() + heartbeats += args(0).asInstanceOf[Heartbeat] + HeartbeatResponse(false) + }) + val receiverRef = executorClass.getDeclaredField("heartbeatReceiverRef") + receiverRef.setAccessible(true) + receiverRef.set(executor, mockReceiver) - f(executor, heartbeats) + f(executor, heartbeats) + } } private def heartbeatZeroAccumulatorUpdateTest(dropZeroMetrics: Boolean): Unit = { @@ -354,10 +374,7 @@ class ExecutorSuite extends SparkFunSuite val taskDescription = createResultTaskDescription(serializer, taskBinary, rdd, 0) val mockBackend = mock[ExecutorBackend] - var executor: Executor = null - try { - executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true, - resources = immutable.Map.empty[String, ResourceInformation]) + withExecutor("id", "localhost", SparkEnv.get) { executor => executor.launchTask(mockBackend, taskDescription) // Ensure that the executor's metricsPoller is polled so that values are recorded for @@ -368,10 +385,6 @@ class ExecutorSuite extends SparkFunSuite eventually(timeout(5.seconds), interval(10.milliseconds)) { assert(executor.numRunningTasks === 0) } - } finally { - if (executor != null) { - executor.stop() - } } // Verify that peak values for task metrics get sent in the TaskResult @@ -535,12 +548,11 @@ class ExecutorSuite extends SparkFunSuite poll: Boolean = false): (TaskFailedReason, UncaughtExceptionHandler) = { val mockBackend = mock[ExecutorBackend] val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler] - var executor: Executor = null val timedOut = new AtomicBoolean(false) - try { - executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true, - uncaughtExceptionHandler = mockUncaughtExceptionHandler, - resources = immutable.Map.empty[String, ResourceInformation]) + + withExecutor("id", "localhost", SparkEnv.get, + uncaughtExceptionHandler = mockUncaughtExceptionHandler) { executor => + // the task will be launched in a dedicated worker thread executor.launchTask(mockBackend, taskDescription) if (killTask) { @@ -573,11 +585,8 @@ class ExecutorSuite extends SparkFunSuite assert(executor.numRunningTasks === 0) } assert(!timedOut.get(), "timed out waiting to be ready to kill tasks") - } finally { - if (executor != null) { - executor.stop() - } } + val orderedMock = inOrder(mockBackend) val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) orderedMock.verify(mockBackend)