Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,7 @@ private[spark] class Executor(
try {
val response = heartbeatReceiverRef.askSync[HeartbeatResponse](
message, new RpcTimeout(HEARTBEAT_INTERVAL_MS.millis, EXECUTOR_HEARTBEAT_INTERVAL.key))
if (response.reregisterBlockManager) {
if (!executorShutdown.get && response.reregisterBlockManager) {
logInfo("Told to re-register on heartbeat")
env.blockManager.reregister()
}
Expand Down
66 changes: 51 additions & 15 deletions core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -270,29 +270,36 @@ class ExecutorSuite extends SparkFunSuite
heartbeatZeroAccumulatorUpdateTest(false)
}

private def withMockHeartbeatReceiverRef(executor: Executor)
(func: RpcEndpointRef => Unit): Unit = {
val executorClass = classOf[Executor]
val mockReceiverRef = mock[RpcEndpointRef]
val receiverRef = executorClass.getDeclaredField("heartbeatReceiverRef")
receiverRef.setAccessible(true)
receiverRef.set(executor, mockReceiverRef)

func(mockReceiverRef)
}

private def withHeartbeatExecutor(confs: (String, String)*)
(f: (Executor, ArrayBuffer[Heartbeat]) => Unit): Unit = {
val conf = new SparkConf
confs.foreach { case (k, v) => conf.set(k, v) }
val serializer = new JavaSerializer(conf)
val env = createMockEnv(conf, serializer)
withExecutor("id", "localhost", SparkEnv.get) { executor =>
val executorClass = classOf[Executor]

// Save all heartbeats sent into an ArrayBuffer for verification
val heartbeats = ArrayBuffer[Heartbeat]()
val mockReceiver = mock[RpcEndpointRef]
when(mockReceiver.askSync(any[Heartbeat], any[RpcTimeout])(any))
.thenAnswer((invocation: InvocationOnMock) => {
val args = invocation.getArguments()
heartbeats += args(0).asInstanceOf[Heartbeat]
HeartbeatResponse(false)
})
val receiverRef = executorClass.getDeclaredField("heartbeatReceiverRef")
receiverRef.setAccessible(true)
receiverRef.set(executor, mockReceiver)
withMockHeartbeatReceiverRef(executor) { mockReceiverRef =>
// Save all heartbeats sent into an ArrayBuffer for verification
val heartbeats = ArrayBuffer[Heartbeat]()
when(mockReceiverRef.askSync(any[Heartbeat], any[RpcTimeout])(any))
.thenAnswer((invocation: InvocationOnMock) => {
val args = invocation.getArguments()
heartbeats += args(0).asInstanceOf[Heartbeat]
HeartbeatResponse(false)
})

f(executor, heartbeats)
f(executor, heartbeats)
}
}
}

Expand Down Expand Up @@ -416,6 +423,35 @@ class ExecutorSuite extends SparkFunSuite
assert(taskMetrics.getMetricValue("JVMHeapMemory") > 0)
}

test("SPARK-34949: do not re-register BlockManager when executor is shutting down") {
val reregisterInvoked = new AtomicBoolean(false)
val mockBlockManager = mock[BlockManager]
when(mockBlockManager.reregister()).thenAnswer { (_: InvocationOnMock) =>
reregisterInvoked.getAndSet(true)
}
val conf = new SparkConf(false).setAppName("test").setMaster("local[2]")
val mockEnv = createMockEnv(conf, new JavaSerializer(conf))
when(mockEnv.blockManager).thenReturn(mockBlockManager)

withExecutor("id", "localhost", mockEnv) { executor =>
withMockHeartbeatReceiverRef(executor) { mockReceiverRef =>
when(mockReceiverRef.askSync(any[Heartbeat], any[RpcTimeout])(any)).thenAnswer {
(_: InvocationOnMock) => HeartbeatResponse(reregisterBlockManager = true)
}
val reportHeartbeat = PrivateMethod[Unit](Symbol("reportHeartBeat"))
executor.invokePrivate(reportHeartbeat())
assert(reregisterInvoked.get(), "BlockManager.reregister should be invoked " +
"on HeartbeatResponse(reregisterBlockManager = true) when executor is not shutting down")

reregisterInvoked.getAndSet(false)
executor.stop()
executor.invokePrivate(reportHeartbeat())
assert(!reregisterInvoked.get(),
"BlockManager.reregister should not be invoked when executor is shutting down")
}
}
}

test("SPARK-33587: isFatalError") {
def errorInThreadPool(e: => Throwable): Throwable = {
intercept[Throwable] {
Expand Down