99 changes: 54 additions & 45 deletions core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.executor

import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.lang.Thread.UncaughtExceptionHandler
import java.net.URL
import java.nio.ByteBuffer
import java.util.Properties
import java.util.concurrent.{ConcurrentHashMap, CountDownLatch, TimeUnit}
@@ -53,7 +54,7 @@ import org.apache.spark.scheduler.{DirectTaskResult, FakeTask, ResultTask, Task,
import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockManager, BlockManagerId}
import org.apache.spark.util.{LongAccumulator, ThreadUtils, UninterruptibleThread}
import org.apache.spark.util.{LongAccumulator, SparkUncaughtExceptionHandler, ThreadUtils, UninterruptibleThread}

class ExecutorSuite extends SparkFunSuite
with LocalSparkContext with MockitoSugar with Eventually with PrivateMethodTester {
@@ -64,6 +65,33 @@ class ExecutorSuite extends SparkFunSuite
super.afterEach()
}

/**
* Creates an Executor with the provided arguments, passes it to `f`, and
* ensures it is stopped after `f` returns.
*/
def withExecutor(
executorId: String,
executorHostname: String,
env: SparkEnv,
userClassPath: Seq[URL] = Nil,
isLocal: Boolean = true,
uncaughtExceptionHandler: UncaughtExceptionHandler
= new SparkUncaughtExceptionHandler,
resources: immutable.Map[String, ResourceInformation]
= immutable.Map.empty[String, ResourceInformation])(f: Executor => Unit): Unit = {
var executor: Executor = null
try {
executor = new Executor(executorId, executorHostname, env, userClassPath, isLocal,
uncaughtExceptionHandler, resources)

f(executor)
} finally {
if (executor != null) {
executor.stop()
}
}
}
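// Usage sketch (illustrative only, not part of the change): `withExecutor`
// follows the loan pattern, so a test supplies just the body and never calls
// stop() itself. `mockBackend` and `taskDescription` below stand in for
// whatever the individual test sets up:
//
//   withExecutor("id", "localhost", env) { executor =>
//     executor.launchTask(mockBackend, taskDescription)
//     // ... assertions on the executor's behavior ...
//   } // executor.stop() runs here, even if the body throws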

test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") {
// mock some objects to make Executor.launchTask() happy
val conf = new SparkConf
@@ -116,10 +144,8 @@ class ExecutorSuite extends SparkFunSuite
}
})

var executor: Executor = null
try {
executor = new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true,
resources = immutable.Map.empty[String, ResourceInformation])
withExecutor("id", "localhost", env) { executor =>

// the task will be launched in a dedicated worker thread
executor.launchTask(mockExecutorBackend, taskDescription)

@@ -139,11 +165,6 @@ class ExecutorSuite extends SparkFunSuite
assert(executorSuiteHelper.testFailedReason.toErrorString === "TaskKilled (test)")
assert(executorSuiteHelper.taskState === TaskState.KILLED)
}
finally {
if (executor != null) {
executor.stop()
}
}
}

test("SPARK-19276: Handle FetchFailedExceptions that are hidden by user exceptions") {
@@ -255,25 +276,24 @@ class ExecutorSuite extends SparkFunSuite
confs.foreach { case (k, v) => conf.set(k, v) }
val serializer = new JavaSerializer(conf)
val env = createMockEnv(conf, serializer)
val executor =
new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
resources = immutable.Map.empty[String, ResourceInformation])
val executorClass = classOf[Executor]

// Save all heartbeats sent into an ArrayBuffer for verification
val heartbeats = ArrayBuffer[Heartbeat]()
val mockReceiver = mock[RpcEndpointRef]
when(mockReceiver.askSync(any[Heartbeat], any[RpcTimeout])(any))
.thenAnswer((invocation: InvocationOnMock) => {
val args = invocation.getArguments()
heartbeats += args(0).asInstanceOf[Heartbeat]
HeartbeatResponse(false)
})
val receiverRef = executorClass.getDeclaredField("heartbeatReceiverRef")
receiverRef.setAccessible(true)
receiverRef.set(executor, mockReceiver)
withExecutor("id", "localhost", SparkEnv.get) { executor =>
val executorClass = classOf[Executor]

// Save all heartbeats sent into an ArrayBuffer for verification
val heartbeats = ArrayBuffer[Heartbeat]()
val mockReceiver = mock[RpcEndpointRef]
when(mockReceiver.askSync(any[Heartbeat], any[RpcTimeout])(any))
.thenAnswer((invocation: InvocationOnMock) => {
val args = invocation.getArguments()
heartbeats += args(0).asInstanceOf[Heartbeat]
HeartbeatResponse(false)
})
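// Inject the mock receiver via reflection: the Executor wires up its
// heartbeat receiver internally and does not expose a setter for it.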
val receiverRef = executorClass.getDeclaredField("heartbeatReceiverRef")
receiverRef.setAccessible(true)
receiverRef.set(executor, mockReceiver)

f(executor, heartbeats)
f(executor, heartbeats)
}
}
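// Illustrative call pattern for this helper (its signature sits above the
// collapsed hunk; the name and parameter shape are assumptions inferred from
// the `f(executor, heartbeats)` invocation above):
//
//   withHeartbeatExecutor(/* conf overrides */) { (executor, heartbeats) =>
//     // run a task, then assert on the captured `heartbeats`
//   }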

private def heartbeatZeroAccumulatorUpdateTest(dropZeroMetrics: Boolean): Unit = {
@@ -354,10 +374,7 @@ class ExecutorSuite extends SparkFunSuite
val taskDescription = createResultTaskDescription(serializer, taskBinary, rdd, 0)

val mockBackend = mock[ExecutorBackend]
var executor: Executor = null
try {
executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
resources = immutable.Map.empty[String, ResourceInformation])
withExecutor("id", "localhost", SparkEnv.get) { executor =>
executor.launchTask(mockBackend, taskDescription)

// Ensure that the executor's metricsPoller is polled so that values are recorded for
@@ -368,10 +385,6 @@ class ExecutorSuite extends SparkFunSuite
eventually(timeout(5.seconds), interval(10.milliseconds)) {
assert(executor.numRunningTasks === 0)
}
} finally {
if (executor != null) {
executor.stop()
}
}

// Verify that peak values for task metrics get sent in the TaskResult
@@ -535,12 +548,11 @@ class ExecutorSuite extends SparkFunSuite
poll: Boolean = false): (TaskFailedReason, UncaughtExceptionHandler) = {
val mockBackend = mock[ExecutorBackend]
val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
var executor: Executor = null
val timedOut = new AtomicBoolean(false)
try {
executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
uncaughtExceptionHandler = mockUncaughtExceptionHandler,
resources = immutable.Map.empty[String, ResourceInformation])

withExecutor("id", "localhost", SparkEnv.get,
uncaughtExceptionHandler = mockUncaughtExceptionHandler) { executor =>

// the task will be launched in a dedicated worker thread
executor.launchTask(mockBackend, taskDescription)
if (killTask) {
@@ -573,11 +585,8 @@
assert(executor.numRunningTasks === 0)
}
assert(!timedOut.get(), "timed out waiting to be ready to kill tasks")
} finally {
if (executor != null) {
executor.stop()
}
}

val orderedMock = inOrder(mockBackend)
val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
orderedMock.verify(mockBackend)