@@ -19,6 +19,7 @@ package org.apache.spark.executor
1919
2020import java .io .{Externalizable , ObjectInput , ObjectOutput }
2121import java .lang .Thread .UncaughtExceptionHandler
22+ import java .net .URL
2223import java .nio .ByteBuffer
2324import java .util .Properties
2425import java .util .concurrent .{ConcurrentHashMap , CountDownLatch , TimeUnit }
@@ -53,7 +54,7 @@ import org.apache.spark.scheduler.{DirectTaskResult, FakeTask, ResultTask, Task,
5354import org .apache .spark .serializer .{JavaSerializer , SerializerInstance , SerializerManager }
5455import org .apache .spark .shuffle .FetchFailedException
5556import org .apache .spark .storage .{BlockManager , BlockManagerId }
56- import org .apache .spark .util .{LongAccumulator , ThreadUtils , UninterruptibleThread }
57+ import org .apache .spark .util .{LongAccumulator , SparkUncaughtExceptionHandler , ThreadUtils , UninterruptibleThread }
5758
5859class ExecutorSuite extends SparkFunSuite
5960 with LocalSparkContext with MockitoSugar with Eventually with PrivateMethodTester {
@@ -64,6 +65,33 @@ class ExecutorSuite extends SparkFunSuite
6465 super .afterEach()
6566 }
6667
68+ /**
69+ * Creates an Executor with the provided arguments, is then passed to `f`
70+ * and will be stopped after `f` returns.
71+ */
72+ def withExecutor (
73+ executorId : String ,
74+ executorHostname : String ,
75+ env : SparkEnv ,
76+ userClassPath : Seq [URL ] = Nil ,
77+ isLocal : Boolean = true ,
78+ uncaughtExceptionHandler : UncaughtExceptionHandler
79+ = new SparkUncaughtExceptionHandler ,
80+ resources : immutable.Map [String , ResourceInformation ]
81+ = immutable.Map .empty[String , ResourceInformation ])(f : Executor => Unit ): Unit = {
82+ var executor : Executor = null
83+ try {
84+ executor = new Executor (executorId, executorHostname, env, userClassPath, isLocal,
85+ uncaughtExceptionHandler, resources)
86+
87+ f(executor)
88+ } finally {
89+ if (executor != null ) {
90+ executor.stop()
91+ }
92+ }
93+ }
94+
6795 test(" SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner" ) {
6896 // mock some objects to make Executor.launchTask() happy
6997 val conf = new SparkConf
@@ -116,10 +144,8 @@ class ExecutorSuite extends SparkFunSuite
116144 }
117145 })
118146
119- var executor : Executor = null
120- try {
121- executor = new Executor (" id" , " localhost" , env, userClassPath = Nil , isLocal = true ,
122- resources = immutable.Map .empty[String , ResourceInformation ])
147+ withExecutor(" id" , " localhost" , env) { executor =>
148+
123149 // the task will be launched in a dedicated worker thread
124150 executor.launchTask(mockExecutorBackend, taskDescription)
125151
@@ -139,11 +165,6 @@ class ExecutorSuite extends SparkFunSuite
139165 assert(executorSuiteHelper.testFailedReason.toErrorString === " TaskKilled (test)" )
140166 assert(executorSuiteHelper.taskState === TaskState .KILLED )
141167 }
142- finally {
143- if (executor != null ) {
144- executor.stop()
145- }
146- }
147168 }
148169
149170 test(" SPARK-19276: Handle FetchFailedExceptions that are hidden by user exceptions" ) {
@@ -255,25 +276,24 @@ class ExecutorSuite extends SparkFunSuite
255276 confs.foreach { case (k, v) => conf.set(k, v) }
256277 val serializer = new JavaSerializer (conf)
257278 val env = createMockEnv(conf, serializer)
258- val executor =
259- new Executor (" id" , " localhost" , SparkEnv .get, userClassPath = Nil , isLocal = true ,
260- resources = immutable.Map .empty[String , ResourceInformation ])
261- val executorClass = classOf [Executor ]
262-
263- // Save all heartbeats sent into an ArrayBuffer for verification
264- val heartbeats = ArrayBuffer [Heartbeat ]()
265- val mockReceiver = mock[RpcEndpointRef ]
266- when(mockReceiver.askSync(any[Heartbeat ], any[RpcTimeout ])(any))
267- .thenAnswer((invocation : InvocationOnMock ) => {
268- val args = invocation.getArguments()
269- heartbeats += args(0 ).asInstanceOf [Heartbeat ]
270- HeartbeatResponse (false )
271- })
272- val receiverRef = executorClass.getDeclaredField(" heartbeatReceiverRef" )
273- receiverRef.setAccessible(true )
274- receiverRef.set(executor, mockReceiver)
279+ withExecutor(" id" , " localhost" , SparkEnv .get) { executor =>
280+ val executorClass = classOf [Executor ]
281+
282+ // Save all heartbeats sent into an ArrayBuffer for verification
283+ val heartbeats = ArrayBuffer [Heartbeat ]()
284+ val mockReceiver = mock[RpcEndpointRef ]
285+ when(mockReceiver.askSync(any[Heartbeat ], any[RpcTimeout ])(any))
286+ .thenAnswer((invocation : InvocationOnMock ) => {
287+ val args = invocation.getArguments()
288+ heartbeats += args(0 ).asInstanceOf [Heartbeat ]
289+ HeartbeatResponse (false )
290+ })
291+ val receiverRef = executorClass.getDeclaredField(" heartbeatReceiverRef" )
292+ receiverRef.setAccessible(true )
293+ receiverRef.set(executor, mockReceiver)
275294
276- f(executor, heartbeats)
295+ f(executor, heartbeats)
296+ }
277297 }
278298
279299 private def heartbeatZeroAccumulatorUpdateTest (dropZeroMetrics : Boolean ): Unit = {
@@ -354,10 +374,7 @@ class ExecutorSuite extends SparkFunSuite
354374 val taskDescription = createResultTaskDescription(serializer, taskBinary, rdd, 0 )
355375
356376 val mockBackend = mock[ExecutorBackend ]
357- var executor : Executor = null
358- try {
359- executor = new Executor (" id" , " localhost" , SparkEnv .get, userClassPath = Nil , isLocal = true ,
360- resources = immutable.Map .empty[String , ResourceInformation ])
377+ withExecutor(" id" , " localhost" , SparkEnv .get) { executor =>
361378 executor.launchTask(mockBackend, taskDescription)
362379
363380 // Ensure that the executor's metricsPoller is polled so that values are recorded for
@@ -368,10 +385,6 @@ class ExecutorSuite extends SparkFunSuite
368385 eventually(timeout(5 .seconds), interval(10 .milliseconds)) {
369386 assert(executor.numRunningTasks === 0 )
370387 }
371- } finally {
372- if (executor != null ) {
373- executor.stop()
374- }
375388 }
376389
377390 // Verify that peak values for task metrics get sent in the TaskResult
@@ -535,12 +548,11 @@ class ExecutorSuite extends SparkFunSuite
535548 poll : Boolean = false ): (TaskFailedReason , UncaughtExceptionHandler ) = {
536549 val mockBackend = mock[ExecutorBackend ]
537550 val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler ]
538- var executor : Executor = null
539551 val timedOut = new AtomicBoolean (false )
540- try {
541- executor = new Executor (" id" , " localhost" , SparkEnv .get, userClassPath = Nil , isLocal = true ,
542- uncaughtExceptionHandler = mockUncaughtExceptionHandler,
543- resources = immutable. Map .empty[ String , ResourceInformation ])
552+
553+ withExecutor (" id" , " localhost" , SparkEnv .get,
554+ uncaughtExceptionHandler = mockUncaughtExceptionHandler) { executor =>
555+
544556 // the task will be launched in a dedicated worker thread
545557 executor.launchTask(mockBackend, taskDescription)
546558 if (killTask) {
@@ -573,11 +585,8 @@ class ExecutorSuite extends SparkFunSuite
573585 assert(executor.numRunningTasks === 0 )
574586 }
575587 assert(! timedOut.get(), " timed out waiting to be ready to kill tasks" )
576- } finally {
577- if (executor != null ) {
578- executor.stop()
579- }
580588 }
589+
581590 val orderedMock = inOrder(mockBackend)
582591 val statusCaptor = ArgumentCaptor .forClass(classOf [ByteBuffer ])
583592 orderedMock.verify(mockBackend)
0 commit comments