4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -181,7 +181,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock)
// Asynchronously kill the executor to avoid blocking the current thread
killExecutorThread.submit(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
- sc.killExecutor(executorId)
// Note: we want to get an executor back after expiring this one,
// so do not simply call `sc.killExecutor` here (SPARK-8119)
sc.killAndReplaceExecutor(executorId)
}
})
}
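The comment above is the crux of SPARK-8119. A minimal sketch of the behavioral difference, assuming a dynamic-allocation target of two executors (illustrative numbers, not part of this patch):

// Old behavior: expiring a dead executor also shrank the application's target.
sc.killExecutor("executor-1")            // target: 2 -> 1; no replacement arrives
// New behavior: the target is left untouched, so the cluster manager may
// eventually launch a substitute for the expired executor.
sc.killAndReplaceExecutor("executor-1")  // target stays 2 (replacement not guaranteed)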
40 changes: 38 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1419,6 +1419,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationClient
/**
* :: DeveloperApi ::
* Request that the cluster manager kill the specified executors.
*
* Note: This is an indication to the cluster manager that the application wishes to adjust
* its resource usage downwards. If the application wishes to replace the executors it kills
* through this method with new ones, it should follow up explicitly with a call to
* {{SparkContext#requestExecutors}}.
*
* This is currently only supported in YARN mode. Return whether the request is received.
*/
@DeveloperApi
@@ -1436,12 +1442,42 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationClient

/**
* :: DeveloperApi ::
- * Request that cluster manager the kill the specified executor.
- * This is currently only supported in Yarn mode. Return whether the request is received.
* Request that the cluster manager kill the specified executor.
*
* Note: This is an indication to the cluster manager that the application wishes to adjust
* its resource usage downwards. If the application wishes to replace the executor it kills
* through this method with a new one, it should follow up explicitly with a call to
* {{SparkContext#requestExecutors}}.
*
* This is currently only supported in YARN mode. Return whether the request is received.
*/
@DeveloperApi
override def killExecutor(executorId: String): Boolean = super.killExecutor(executorId)

/**
* Request that the cluster manager kill the specified executor without adjusting the
* application resource requirements.
*
* The effect is that a new executor will be launched in place of the one killed by
* this request. This assumes the cluster manager will automatically and eventually
* fulfill all missing application resource requests.
*
* Note: The replace is by no means guaranteed; another application on the same cluster
* can steal the window of opportunity and acquire this application's resources in the
* meantime.
*
* This is currently only supported in YARN mode. Return whether the request is received.
*/
private[spark] def killAndReplaceExecutor(executorId: String): Boolean = {
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
b.killExecutors(Seq(executorId), replace = true)
case _ =>
logWarning("Killing executors is only supported in coarse-grained mode")
false
}
}

/** The version of Spark on which this application is running. */
def version: String = SPARK_VERSION

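The Note added above documents the manual alternative to killAndReplaceExecutor. A hypothetical caller-side sketch of that pattern using the public API (the executor id is made up):

// killExecutor lowers the application's target, so getting capacity back
// requires an explicit follow-up request for a new executor.
if (sc.killExecutor("7")) {
  sc.requestExecutors(1)
}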
core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -371,26 +371,36 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv)

/**
* Request that the cluster manager kill the specified executors.
- * Return whether the kill request is acknowledged.
* @return whether the kill request is acknowledged.
*/
final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized {
killExecutors(executorIds, replace = false)
}

/**
* Request that the cluster manager kill the specified executors.
*
* @param executorIds identifiers of executors to kill
* @param replace whether to replace the killed executors with new ones
* @return whether the kill request is acknowledged.
*/
final def killExecutors(executorIds: Seq[String], replace: Boolean): Boolean = synchronized {
logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}")
- val filteredExecutorIds = new ArrayBuffer[String]
- executorIds.foreach { id =>
- if (executorDataMap.contains(id)) {
- filteredExecutorIds += id
- } else {
- logWarning(s"Executor to kill $id does not exist!")
- }
val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains)
unknownExecutors.foreach { id =>
logWarning(s"Executor to kill $id does not exist!")
}

// If we do not wish to replace the executors we kill, sync the target number of executors
// with the cluster manager to avoid allocating new ones. When computing the new target,
// take into account executors that are pending to be added or removed.
if (!replace) {
doRequestTotalExecutors(numExistingExecutors + numPendingExecutors
- executorsPendingToRemove.size - knownExecutors.size)
}
- // Killing executors means effectively that we want less executors than before, so also update
- // the target number of executors to avoid having the backend allocate new ones.
- val newTotal = (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size
- - filteredExecutorIds.size)
- doRequestTotalExecutors(newTotal)

- executorsPendingToRemove ++= filteredExecutorIds
- doKillExecutors(filteredExecutorIds)
executorsPendingToRemove ++= knownExecutors
doKillExecutors(knownExecutors)
}
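A worked example of the target arithmetic above, with assumed counts: 10 registered executors, 2 pending additions, 1 executor already pending removal, and a request to kill 3 known executors.

// replace = false: sync the target so the kills are not backfilled:
//   newTarget = 10 + 2 - 1 - 3 = 8
// replace = true: the sync is skipped, so the target stays at the level the
// application last requested and the cluster manager may backfill the 3 kills.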

/**
147 changes: 128 additions & 19 deletions core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
@@ -17,6 +17,9 @@

package org.apache.spark

import java.util.concurrent.{ExecutorService, TimeUnit}

import scala.collection.mutable
import scala.language.postfixOps

import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
@@ -25,11 +28,16 @@ import org.mockito.Matchers
import org.mockito.Matchers._

import org.apache.spark.executor.TaskMetrics
- import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv, RpcEndpointRef}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.ManualClock

/**
* A test suite for the heartbeating behavior between the driver and the executors.
*/
class HeartbeatReceiverSuite
extends SparkFunSuite
with BeforeAndAfterEach
@@ -40,23 +48,40 @@ class HeartbeatReceiverSuite
private val executorId2 = "executor-2"

// Shared state that must be reset before and after each test
- private var scheduler: TaskScheduler = null
private var scheduler: TaskSchedulerImpl = null
private var heartbeatReceiver: HeartbeatReceiver = null
private var heartbeatReceiverRef: RpcEndpointRef = null
private var heartbeatReceiverClock: ManualClock = null

// Helper private method accessors for HeartbeatReceiver
private val _executorLastSeen = PrivateMethod[collection.Map[String, Long]]('executorLastSeen)
private val _executorTimeoutMs = PrivateMethod[Long]('executorTimeoutMs)
private val _killExecutorThread = PrivateMethod[ExecutorService]('killExecutorThread)

/**
* Before each test, set up the SparkContext and a custom [[HeartbeatReceiver]]
* that uses a manual clock.
*/
override def beforeEach(): Unit = {
sc = spy(new SparkContext("local[2]", "test"))
scheduler = mock(classOf[TaskScheduler])
val conf = new SparkConf()
.setMaster("local[2]")
.setAppName("test")
.set("spark.dynamicAllocation.testing", "true")
sc = spy(new SparkContext(conf))
scheduler = mock(classOf[TaskSchedulerImpl])
when(sc.taskScheduler).thenReturn(scheduler)
when(scheduler.sc).thenReturn(sc)
heartbeatReceiverClock = new ManualClock
heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock)
heartbeatReceiverRef = sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver)
when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true)
}

/**
* After each test, clean up all state and stop the [[SparkContext]].
*/
override def afterEach(): Unit = {
resetSparkContext()
super.afterEach()
scheduler = null
heartbeatReceiver = null
heartbeatReceiverRef = null
@@ -75,23 +100,23 @@
heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null))
triggerHeartbeat(executorId1, executorShouldReregister = false)
triggerHeartbeat(executorId2, executorShouldReregister = false)
- val trackedExecutors = executorLastSeen(heartbeatReceiver)
val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen())
assert(trackedExecutors.size === 2)
assert(trackedExecutors.contains(executorId1))
assert(trackedExecutors.contains(executorId2))
}

test("reregister if scheduler is not ready yet") {
heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
- // Task scheduler not set in HeartbeatReceiver
// Task scheduler is not set yet in HeartbeatReceiver, so executors should reregister
triggerHeartbeat(executorId1, executorShouldReregister = true)
}

test("reregister if heartbeat from unregistered executor") {
heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
- // Received heartbeat from unknown receiver, so we ask it to re-register
// Received heartbeat from unknown executor, so we ask it to re-register
triggerHeartbeat(executorId1, executorShouldReregister = true)
- assert(executorLastSeen(heartbeatReceiver).isEmpty)
assert(heartbeatReceiver.invokePrivate(_executorLastSeen()).isEmpty)
}

test("reregister if heartbeat from removed executor") {
@@ -104,14 +129,14 @@
// A heartbeat from the second executor should require reregistering
triggerHeartbeat(executorId1, executorShouldReregister = false)
triggerHeartbeat(executorId2, executorShouldReregister = true)
- val trackedExecutors = executorLastSeen(heartbeatReceiver)
val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen())
assert(trackedExecutors.size === 1)
assert(trackedExecutors.contains(executorId1))
assert(!trackedExecutors.contains(executorId2))
}

test("expire dead hosts") {
- val executorTimeout = executorTimeoutMs(heartbeatReceiver)
val executorTimeout = heartbeatReceiver.invokePrivate(_executorTimeoutMs())
heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null))
@@ -124,12 +149,61 @@
heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts)
// Only the second executor should be expired as a dead host
verify(scheduler).executorLost(Matchers.eq(executorId2), any())
- val trackedExecutors = executorLastSeen(heartbeatReceiver)
val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen())
assert(trackedExecutors.size === 1)
assert(trackedExecutors.contains(executorId1))
assert(!trackedExecutors.contains(executorId2))
}

test("expire dead hosts should kill executors with replacement (SPARK-8119)") {
// Set up a fake backend and cluster manager to simulate killing executors
val rpcEnv = sc.env.rpcEnv
val fakeClusterManager = new FakeClusterManager(rpcEnv)
val fakeClusterManagerRef = rpcEnv.setupEndpoint("fake-cm", fakeClusterManager)
val fakeSchedulerBackend = new FakeSchedulerBackend(scheduler, rpcEnv, fakeClusterManagerRef)
when(sc.schedulerBackend).thenReturn(fakeSchedulerBackend)

// Register fake executors with our fake scheduler backend
// This is necessary because the backend refuses to kill executors it does not know about
fakeSchedulerBackend.start()
val dummyExecutorEndpoint1 = new FakeExecutorEndpoint(rpcEnv)
val dummyExecutorEndpoint2 = new FakeExecutorEndpoint(rpcEnv)
val dummyExecutorEndpointRef1 = rpcEnv.setupEndpoint("fake-executor-1", dummyExecutorEndpoint1)
val dummyExecutorEndpointRef2 = rpcEnv.setupEndpoint("fake-executor-2", dummyExecutorEndpoint2)
fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisteredExecutor.type](
RegisterExecutor(executorId1, dummyExecutorEndpointRef1, "dummy:4040", 0, Map.empty))
fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisteredExecutor.type](
RegisterExecutor(executorId2, dummyExecutorEndpointRef2, "dummy:4040", 0, Map.empty))
heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null))
triggerHeartbeat(executorId1, executorShouldReregister = false)
triggerHeartbeat(executorId2, executorShouldReregister = false)

// Adjust the target number of executors on the cluster manager side
assert(fakeClusterManager.getTargetNumExecutors === 0)
sc.requestTotalExecutors(2)
assert(fakeClusterManager.getTargetNumExecutors === 2)
assert(fakeClusterManager.getExecutorIdsToKill.isEmpty)

// Expire the executors. This should trigger our fake backend to kill the executors.
// Since the kill request is sent to the cluster manager asynchronously, we need to block
// on the kill thread to ensure that the cluster manager actually received our requests.
// Here we use a timeout of O(seconds), but in practice this whole test takes O(10ms).
val executorTimeout = heartbeatReceiver.invokePrivate(_executorTimeoutMs())
heartbeatReceiverClock.advance(executorTimeout * 2)
heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts)
val killThread = heartbeatReceiver.invokePrivate(_killExecutorThread())
killThread.shutdown() // needed for awaitTermination
killThread.awaitTermination(10L, TimeUnit.SECONDS)

// The target number of executors should not change! Otherwise, having an expired
// executor means we permanently adjust the target number downwards until we
// explicitly request new executors. For more detail, see SPARK-8119.
assert(fakeClusterManager.getTargetNumExecutors === 2)
assert(fakeClusterManager.getExecutorIdsToKill === Set(executorId1, executorId2))
}

/** Manually send a heartbeat and return the response. */
private def triggerHeartbeat(
executorId: String,
@@ -148,14 +222,49 @@
}
}

- // Helper methods to access private fields in HeartbeatReceiver
- private val _executorLastSeen = PrivateMethod[collection.Map[String, Long]]('executorLastSeen)
- private val _executorTimeoutMs = PrivateMethod[Long]('executorTimeoutMs)
- private def executorLastSeen(receiver: HeartbeatReceiver): collection.Map[String, Long] = {
- receiver invokePrivate _executorLastSeen()
- }

// TODO: use these classes to add end-to-end tests for dynamic allocation!

/**
* Dummy RPC endpoint to simulate executors.
*/
private class FakeExecutorEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint

/**
* Dummy scheduler backend to simulate executor allocation requests to the cluster manager.
*/
private class FakeSchedulerBackend(
scheduler: TaskSchedulerImpl,
rpcEnv: RpcEnv,
clusterManagerEndpoint: RpcEndpointRef)
extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) {

protected override def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
clusterManagerEndpoint.askWithRetry[Boolean](RequestExecutors(requestedTotal))
}
- private def executorTimeoutMs(receiver: HeartbeatReceiver): Long = {
- receiver invokePrivate _executorTimeoutMs()

protected override def doKillExecutors(executorIds: Seq[String]): Boolean = {
clusterManagerEndpoint.askWithRetry[Boolean](KillExecutors(executorIds))
}
}

/**
* Dummy cluster manager to simulate responses to executor allocation requests.
*/
private class FakeClusterManager(override val rpcEnv: RpcEnv) extends RpcEndpoint {
private var targetNumExecutors = 0
private val executorIdsToKill = new mutable.HashSet[String]

def getTargetNumExecutors: Int = targetNumExecutors
def getExecutorIdsToKill: Set[String] = executorIdsToKill.toSet

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case RequestExecutors(requestedTotal) =>
targetNumExecutors = requestedTotal
context.reply(true)
case KillExecutors(executorIds) =>
executorIdsToKill ++= executorIds
context.reply(true)
}
}
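A rough sketch of the end-to-end test the TODO above asks for, placed inside HeartbeatReceiverSuite and reusing these fakes. The wiring mirrors the SPARK-8119 regression test; the test name, endpoint name, and executor id are made up:

test("fake cluster manager tracks allocation requests (sketch)") {
  val rpcEnv = sc.env.rpcEnv
  val fakeClusterManager = new FakeClusterManager(rpcEnv)
  val fakeClusterManagerRef = rpcEnv.setupEndpoint("fake-cm-e2e", fakeClusterManager)
  val fakeSchedulerBackend = new FakeSchedulerBackend(scheduler, rpcEnv, fakeClusterManagerRef)
  when(sc.schedulerBackend).thenReturn(fakeSchedulerBackend)
  fakeSchedulerBackend.start()

  // Scale up, then kill an unknown executor with replacement.
  sc.requestTotalExecutors(2)
  assert(fakeClusterManager.getTargetNumExecutors === 2)
  fakeSchedulerBackend.killExecutors(Seq("unknown-executor"), replace = true)
  assert(fakeClusterManager.getExecutorIdsToKill.isEmpty)  // unknown ids are filtered out
  assert(fakeClusterManager.getTargetNumExecutors === 2)   // replace = true leaves the target alone
}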