diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 085f32b8d83b..2c7f946f33ef 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -929,7 +929,7 @@ private[spark] class Executor( try { val response = heartbeatReceiverRef.askSync[HeartbeatResponse]( message, new RpcTimeout(HEARTBEAT_INTERVAL_MS.millis, EXECUTOR_HEARTBEAT_INTERVAL.key)) - if (response.reregisterBlockManager) { + if (!executorShutdown.get && response.reregisterBlockManager) { logInfo("Told to re-register on heartbeat") env.blockManager.reregister() } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 6b3df6d0c997..665f553d2981 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -270,6 +270,17 @@ class ExecutorSuite extends SparkFunSuite heartbeatZeroAccumulatorUpdateTest(false) } + private def withMockHeartbeatReceiverRef(executor: Executor) + (func: RpcEndpointRef => Unit): Unit = { + val executorClass = classOf[Executor] + val mockReceiverRef = mock[RpcEndpointRef] + val receiverRef = executorClass.getDeclaredField("heartbeatReceiverRef") + receiverRef.setAccessible(true) + receiverRef.set(executor, mockReceiverRef) + + func(mockReceiverRef) + } + private def withHeartbeatExecutor(confs: (String, String)*) (f: (Executor, ArrayBuffer[Heartbeat]) => Unit): Unit = { val conf = new SparkConf @@ -277,22 +288,18 @@ class ExecutorSuite extends SparkFunSuite val serializer = new JavaSerializer(conf) val env = createMockEnv(conf, serializer) withExecutor("id", "localhost", SparkEnv.get) { executor => - val executorClass = classOf[Executor] - - // Save all heartbeats sent into an ArrayBuffer for verification - val heartbeats = ArrayBuffer[Heartbeat]() - val mockReceiver = mock[RpcEndpointRef] - when(mockReceiver.askSync(any[Heartbeat], any[RpcTimeout])(any)) - .thenAnswer((invocation: InvocationOnMock) => { - val args = invocation.getArguments() - heartbeats += args(0).asInstanceOf[Heartbeat] - HeartbeatResponse(false) - }) - val receiverRef = executorClass.getDeclaredField("heartbeatReceiverRef") - receiverRef.setAccessible(true) - receiverRef.set(executor, mockReceiver) - - f(executor, heartbeats) + withMockHeartbeatReceiverRef(executor) { mockReceiverRef => + // Save all heartbeats sent into an ArrayBuffer for verification + val heartbeats = ArrayBuffer[Heartbeat]() + when(mockReceiverRef.askSync(any[Heartbeat], any[RpcTimeout])(any)) + .thenAnswer((invocation: InvocationOnMock) => { + val args = invocation.getArguments() + heartbeats += args(0).asInstanceOf[Heartbeat] + HeartbeatResponse(false) + }) + + f(executor, heartbeats) + } } } @@ -416,6 +423,35 @@ class ExecutorSuite extends SparkFunSuite assert(taskMetrics.getMetricValue("JVMHeapMemory") > 0) } + test("SPARK-34949: do not re-register BlockManager when executor is shutting down") { + val reregisterInvoked = new AtomicBoolean(false) + val mockBlockManager = mock[BlockManager] + when(mockBlockManager.reregister()).thenAnswer { (_: InvocationOnMock) => + reregisterInvoked.getAndSet(true) + } + val conf = new SparkConf(false).setAppName("test").setMaster("local[2]") + val mockEnv = createMockEnv(conf, new JavaSerializer(conf)) + when(mockEnv.blockManager).thenReturn(mockBlockManager) + + withExecutor("id", "localhost", mockEnv) { executor => + withMockHeartbeatReceiverRef(executor) { mockReceiverRef => + when(mockReceiverRef.askSync(any[Heartbeat], any[RpcTimeout])(any)).thenAnswer { + (_: InvocationOnMock) => HeartbeatResponse(reregisterBlockManager = true) + } + val reportHeartbeat = PrivateMethod[Unit](Symbol("reportHeartBeat")) + executor.invokePrivate(reportHeartbeat()) + assert(reregisterInvoked.get(), "BlockManager.reregister should be invoked " + + "on HeartbeatResponse(reregisterBlockManager = true) when executor is not shutting down") + + reregisterInvoked.getAndSet(false) + executor.stop() + executor.invokePrivate(reportHeartbeat()) + assert(!reregisterInvoked.get(), + "BlockManager.reregister should not be invoked when executor is shutting down") + } + } + } + private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = { val mockEnv = mock[SparkEnv] val mockRpcEnv = mock[RpcEnv]