diff --git a/core/src/main/scala/org/apache/spark/CacheRecoveryManager.scala b/core/src/main/scala/org/apache/spark/CacheRecoveryManager.scala new file mode 100644 index 000000000000..6ca871dcb0cd --- /dev/null +++ b/core/src/main/scala/org/apache/spark/CacheRecoveryManager.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, Future, Promise} +import scala.util.Failure + +import com.google.common.cache.CacheBuilder + +import org.apache.spark.CacheRecoveryManager.{DoneRecovering, KillReason, Timeout} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.DYN_ALLOCATION_CACHE_RECOVERY_TIMEOUT +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.util.ThreadUtils + +/** + * Responsible for asynchronously replicating all of an executor's cached blocks, and then shutting + * it down. + */ +private class CacheRecoveryManager( + blockManagerMasterEndpoint: RpcEndpointRef, + executorAllocationManager: ExecutorAllocationManager, + conf: SparkConf) + extends Logging { + + private val forceKillAfterS = conf.get(DYN_ALLOCATION_CACHE_RECOVERY_TIMEOUT) + private val threadPool = ThreadUtils.newDaemonCachedThreadPool("cache-recovery-manager-pool") + private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(threadPool) + private val scheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("cache-recovery-shutdown-timers") + private val recoveringExecutors = CacheBuilder.newBuilder() + .expireAfterWrite(forceKillAfterS * 2, TimeUnit.SECONDS) + .build[String, String]() + + /** + * Start the recover cache shutdown process for these executors + * + * @param execIds the executors to start shutting down + * @return a sequence of futures of Unit that will complete once the executor has been killed. 
+ */ + def startCacheRecovery(execIds: Seq[String]): Future[Seq[KillReason]] = { + logDebug(s"Recover cached data before shutting down executors ${execIds.mkString(", ")}.") + val canBeRecovered: Future[Seq[String]] = checkMem(execIds) + + canBeRecovered.flatMap { execIds => + execIds.foreach { execId => recoveringExecutors.put(execId, execId) } + Future.sequence(execIds.map { replicateUntilTimeoutThenKill }) + } + } + + def replicateUntilTimeoutThenKill(execId: String): Future[KillReason] = { + val timeoutFuture = returnAfterTimeout(Timeout, forceKillAfterS) + val replicationFuture = replicateUntilDone(execId) + + Future.firstCompletedOf(List(timeoutFuture, replicationFuture)).andThen { + case scala.util.Success(DoneRecovering) => + logTrace(s"Done recovering blocks on $execId, killing now") + case scala.util.Success(Timeout) => + logWarning(s"Couldn't recover cache on $execId before $forceKillAfterS second timeout") + case Failure(ex) => + logWarning(s"Error recovering cache on $execId", ex) + }.andThen { + case _ => + kill(execId) + } + } + + /** + * Given a list of executors that will be shut down, check if there is enough free memory on the + * rest of the cluster to hold their data. Return a list of just the executors for which there + * will be enough space. Executors are included smallest first. + * + * This is a best guess implementation and it is not guaranteed that all returned executors + * will succeed. For example a block might be too big to fit on any one specific executor. + * + * @param execIds executors which will be shut down + * @return a Seq of the executors we do have room for + */ + private def checkMem(execIds: Seq[String]): Future[Seq[String]] = { + val execsToShutDown = execIds.toSet + // Memory Status is a map of executor Id to a tuple of Max Memory and remaining memory on that + // executor. + val futureMemStatusByBlockManager = + blockManagerMasterEndpoint.ask[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) + + val futureMemStatusByExecutor = futureMemStatusByBlockManager.map { memStat => + memStat.map { case (blockManagerId, mem) => blockManagerId.executorId -> mem } + } + + futureMemStatusByExecutor.map { memStatusByExecutor => + val (expiringMemStatus, remainingMemStatus) = memStatusByExecutor.partition { + case (execId, _) => execsToShutDown.contains(execId) + } + val freeMemOnRemaining = remainingMemStatus.values.map(_._2).sum + + // The used mem on each executor sorted from least used mem to greatest + val executorAndUsedMem: Seq[(String, Long)] = + expiringMemStatus.map { case (execId, (maxMem, remainingMem)) => + val usedMem = maxMem - remainingMem + execId -> usedMem + }.toSeq.sortBy { case (_, usedMem) => usedMem } + + executorAndUsedMem + .scan(("start", freeMemOnRemaining)) { + case ((_, freeMem), (execId, usedMem)) => (execId, freeMem - usedMem) + } + .drop(1) + .filter { case (_, freeMem) => freeMem > 0 } + .map(_._1) + } + } + + /** + * Given a value and a timeout in seconds, complete the future with the value when time is up. 
+ * + * @param value The value to be returned after timeout period + * @param seconds the number of seconds to wait + * @return a future that will hold the value given after a timeout + */ + private def returnAfterTimeout[T](value: T, seconds: Long): Future[T] = { + val p = Promise[T]() + val runnable = new Runnable { + def run(): Unit = { p.success(value) } + } + scheduler.schedule(runnable, seconds, TimeUnit.SECONDS) + p.future + } + + /** + * Recover cached RDD blocks off of an executor until there are no more, or until + * there is an error + * + * @param execId the id of the executor to be killed + * @return a Future of Unit that will complete once all blocks have been replicated. + */ + private def replicateUntilDone(execId: String): Future[KillReason] = { + recoverLatestBlock(execId).flatMap { moreBlocks => + if (moreBlocks) replicateUntilDone(execId) else Future.successful(DoneRecovering) + } + } + + /** + * Ask the BlockManagerMaster to replicate the latest cached rdd block off of this executor on to + * a surviving executor, and then remove the block from this executor + * + * @param execId the executor to recover a block from + * @return A future that will hold true if a block was recovered, false otherwise. + */ + private def recoverLatestBlock(execId: String): Future[Boolean] = { + blockManagerMasterEndpoint + .ask[Boolean](RecoverLatestRDDBlock(execId, recoveringExecutors.asMap.keySet.asScala.toSeq)) + } + + /** + * Remove the executor from the list of currently recovering executors and then kill it. + * + * @param execId the id of the executor to be killed + */ + private def kill(execId: String): Unit = { + executorAllocationManager.killExecutors(Seq(execId)) + } + + /** + * Stops all thread pools + */ + def stop(): Unit = { + threadPool.shutdownNow() + scheduler.shutdownNow() + } +} + +private object CacheRecoveryManager { + def apply(eam: ExecutorAllocationManager, conf: SparkConf): CacheRecoveryManager = { + val bmme = SparkEnv.get.blockManager.master.driverEndpoint + new CacheRecoveryManager(bmme, eam, conf) + } + + sealed trait KillReason + case object Timeout extends KillReason + case object DoneRecovering extends KillReason +} diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index 63d87b4cd385..c93d110e0341 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -85,4 +85,10 @@ private[spark] trait ExecutorAllocationClient { countFailures = false) killedExecutors.nonEmpty && killedExecutors(0).equals(executorId) } + + /** + * Mark these executors as pending to be removed + * @param executorIds Executors that will be removed and should not accept new work. 
+ */ + def markPendingToRemove(executorIds: Seq[String]): Unit } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 189d91333c04..04a4039e944e 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -26,7 +26,7 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{DYN_ALLOCATION_MAX_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS} +import org.apache.spark.internal.config._ import org.apache.spark.metrics.source.Source import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMaster @@ -90,6 +90,8 @@ private[spark] class ExecutorAllocationManager( import ExecutorAllocationManager._ + private var cacheRecoveryManager: CacheRecoveryManager = _ + // Lower and upper bounds on the number of executors. private val minNumExecutors = conf.get(DYN_ALLOCATION_MIN_EXECUTORS) private val maxNumExecutors = conf.get(DYN_ALLOCATION_MAX_EXECUTORS) @@ -110,6 +112,9 @@ private[spark] class ExecutorAllocationManager( private val cachedExecutorIdleTimeoutS = conf.getTimeAsSeconds( "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${Integer.MAX_VALUE}s") + // whether or not to try and save cached data when executors are deallocated + private val recoverCachedData = conf.get(DYN_ALLOCATION_CACHE_RECOVERY) + // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) @@ -212,6 +217,11 @@ private[spark] class ExecutorAllocationManager( if (tasksPerExecutor == 0) { throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.") } + + if (recoverCachedData && cachedExecutorIdleTimeoutS == Integer.MAX_VALUE) { + throw new SparkException(s"spark.dynamicAllocation.cachedExecutorIdleTimeout must be set if" + + s"${DYN_ALLOCATION_CACHE_RECOVERY.key} is true.") + } } /** @@ -243,12 +253,19 @@ private[spark] class ExecutorAllocationManager( executor.scheduleWithFixedDelay(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) + + if (recoverCachedData) { + cacheRecoveryManager = CacheRecoveryManager(this, conf) + } } /** * Stop the allocation manager. */ def stop(): Unit = { + if (cacheRecoveryManager != null) { + cacheRecoveryManager.stop() + } executor.shutdown() executor.awaitTermination(10, TimeUnit.SECONDS) } @@ -432,68 +449,59 @@ private[spark] class ExecutorAllocationManager( /** * Request the cluster manager to remove the given executors. - * Returns the list of executors which are removed. 
+ * + * @param executors the ids of the executors to be removed + * + * @return Unit */ - private def removeExecutors(executors: Seq[String]): Seq[String] = synchronized { - val executorIdsToBeRemoved = new ArrayBuffer[String] - + private def removeExecutors(executors: Seq[String]): Unit = synchronized { logInfo("Request to remove executorIds: " + executors.mkString(", ")) - val numExistingExecutors = allocationManager.executorIds.size - executorsPendingToRemove.size - - var newExecutorTotal = numExistingExecutors - executors.foreach { executorIdToBeRemoved => - if (newExecutorTotal - 1 < minNumExecutors) { - logDebug(s"Not removing idle executor $executorIdToBeRemoved because there are only " + - s"$newExecutorTotal executor(s) left (minimum number of executor limit $minNumExecutors)") - } else if (newExecutorTotal - 1 < numExecutorsTarget) { - logDebug(s"Not removing idle executor $executorIdToBeRemoved because there are only " + - s"$newExecutorTotal executor(s) left (number of executor target $numExecutorsTarget)") - } else if (canBeKilled(executorIdToBeRemoved)) { - executorIdsToBeRemoved += executorIdToBeRemoved - newExecutorTotal -= 1 + val numExistingExecs = allocationManager.executorIds.size - executorsPendingToRemove.size + val execCountFloor = math.max(minNumExecutors, numExecutorsTarget) + val (executorIdsToBeRemoved, dontRemove) = executors + .filter(canBeKilled) + .splitAt(numExistingExecs - execCountFloor) + + if (log.isDebugEnabled()) { + dontRemove.foreach { execId => + logDebug(s"Not removing idle executor $execId because it " + + s"would put us below the minimum limit of $minNumExecutors executors" + + s"or number of target executors $numExecutorsTarget") } } if (executorIdsToBeRemoved.isEmpty) { - return Seq.empty[String] + Seq.empty[String] + } else if (testing) { + recordExecutorKill(executorIdsToBeRemoved) + } else if (recoverCachedData) { + client.markPendingToRemove(executorIdsToBeRemoved) + recordExecutorKill(executorIdsToBeRemoved) + cacheRecoveryManager.startCacheRecovery(executorIdsToBeRemoved) + } else { + val killed = killExecutors(executorIdsToBeRemoved) + recordExecutorKill(killed) } + } - // Send a request to the backend to kill this executor(s) - val executorsRemoved = if (testing) { - executorIdsToBeRemoved - } else { - // We don't want to change our target number of executors, because we already did that - // when the task backlog decreased. - client.killExecutors(executorIdsToBeRemoved, adjustTargetNumExecutors = false, + def killExecutors(executorIds: Seq[String]): Seq[String] = { + logDebug(s"Starting kill process for ${executorIds.mkString(", ")}") + val result = client.killExecutors(executorIds, adjustTargetNumExecutors = false, countFailures = false, force = false) - } - // [SPARK-21834] killExecutors api reduces the target number of executors. - // So we need to update the target with desired value. 
- client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) - // reset the newExecutorTotal to the existing number of executors - newExecutorTotal = numExistingExecutors - if (testing || executorsRemoved.nonEmpty) { - executorsRemoved.foreach { removedExecutorId => - newExecutorTotal -= 1 - logInfo(s"Removing executor $removedExecutorId because it has been idle for " + - s"$executorIdleTimeoutS seconds (new desired total will be $newExecutorTotal)") - executorsPendingToRemove.add(removedExecutorId) - } - executorsRemoved - } else { + if (result.isEmpty) { logWarning(s"Unable to reach the cluster manager to kill executor/s " + - s"${executorIdsToBeRemoved.mkString(",")} or no executor eligible to kill!") - Seq.empty[String] + s"${executorIds.mkString(",")} or no executor eligible to kill!") } + result } - /** - * Request the cluster manager to remove the given executor. - * Return whether the request is acknowledged. - */ - private def removeExecutor(executorId: String): Boolean = synchronized { - val executorsRemoved = removeExecutors(Seq(executorId)) - executorsRemoved.nonEmpty && executorsRemoved(0) == executorId + private def recordExecutorKill(executorsRemoved: Seq[String]): Unit = synchronized { + // [SPARK-21834] killExecutors api reduces the target number of executors. + // So we need to update the target with desired value. + client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) + executorsPendingToRemove ++= executorsRemoved + logInfo(s"Removing executors (${executorsRemoved.mkString(", ")}) because they have been idle" + + s"for $executorIdleTimeoutS seconds") } /** diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 407545aa4a47..3eb145264e93 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -130,6 +130,16 @@ package object config { .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("3s") + private[spark] val DYN_ALLOCATION_CACHE_RECOVERY = + ConfigBuilder("spark.dynamicAllocation.cacheRecovery.enabled") + .booleanConf + .createWithDefault(false) + + private[spark] val DYN_ALLOCATION_CACHE_RECOVERY_TIMEOUT = + ConfigBuilder("spark.dynamicAllocation.cacheRecovery.timeout") + .timeConf(TimeUnit.SECONDS) + .createWithDefault(120) + private[spark] val SHUFFLE_SERVICE_ENABLED = ConfigBuilder("spark.shuffle.service.enabled").booleanConf.createWithDefault(false) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5627a557a12f..2b05aa09ce98 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -21,6 +21,7 @@ import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.concurrent.Future @@ -80,9 +81,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Executors we have requested the cluster manager to kill that have not died yet; maps // the executor ID to whether it was explicitly killed by the 
driver (and thus shouldn't - // be considered an app-related failure). + // be considered an app-related failure), is draining resources before being killed, or died + // on its own. + private sealed trait PendingStatus + private case object KilledByDriver extends PendingStatus + private case object Died extends PendingStatus + private case object Draining extends PendingStatus @GuardedBy("CoarseGrainedSchedulerBackend.this") - private val executorsPendingToRemove = new HashMap[String, Boolean] + private val executorsPendingToRemove = new HashMap[String, PendingStatus] // A map to store hostname with its possible task number running on it @GuardedBy("CoarseGrainedSchedulerBackend.this") @@ -320,15 +326,16 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case Some(executorInfo) => // This must be synchronized because variables mutated // in this block are read when requesting executors - val killed = CoarseGrainedSchedulerBackend.this.synchronized { + val removeStatus = CoarseGrainedSchedulerBackend.this.synchronized { addressToExecutorId -= executorInfo.executorAddress executorDataMap -= executorId executorsPendingLossReason -= executorId - executorsPendingToRemove.remove(executorId).getOrElse(false) + executorsPendingToRemove.remove(executorId).getOrElse(Died) } totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) - scheduler.executorLost(executorId, if (killed) ExecutorKilled else reason) + scheduler.executorLost(executorId, + if (removeStatus == KilledByDriver) ExecutorKilled else reason) listenerBus.post( SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason.toString)) case None => @@ -602,16 +609,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val response = synchronized { val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains) - unknownExecutors.foreach { id => - logWarning(s"Executor to kill $id does not exist!") - } + unknownExecutors.foreach { id => logWarning(s"Executor to kill $id does not exist!") } // If an executor is already pending to be removed, do not kill it again (SPARK-9795) // If this executor is busy, do not kill it unless we are told to force kill it (SPARK-9552) val executorsToKill = knownExecutors - .filter { id => !executorsPendingToRemove.contains(id) } + .filter { id => + !executorsPendingToRemove.contains(id) || executorsPendingToRemove(id) == Draining + } .filter { id => force || !scheduler.isExecutorBusy(id) } - executorsToKill.foreach { id => executorsPendingToRemove(id) = !countFailures } + + executorsToKill.foreach { id => + executorsPendingToRemove(id) = if (!adjustTargetNumExecutors) Died else KilledByDriver + } logInfo(s"Actual list of executor(s) to be killed is ${executorsToKill.mkString(", ")}") @@ -637,18 +647,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp Future.successful(true) } - val killExecutors: Boolean => Future[Boolean] = - if (!executorsToKill.isEmpty) { - _ => doKillExecutors(executorsToKill) - } else { - _ => Future.successful(false) - } - - val killResponse = adjustTotalExecutors.flatMap(killExecutors)(ThreadUtils.sameThread) + val killResponse = if (executorsToKill.nonEmpty) { + adjustTotalExecutors.flatMap { _ => + doKillExecutors(executorsToKill) + }(ThreadUtils.sameThread) + } else { + Future.successful(false) + } - killResponse.flatMap(killSuccessful => - Future.successful (if (killSuccessful) executorsToKill else 
Seq.empty[String]) - )(ThreadUtils.sameThread) + killResponse.map { successful => + if (successful) executorsToKill else Seq.empty + }(ThreadUtils.sameThread) } defaultAskTimeout.awaitResult(response) @@ -678,6 +687,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } protected def fetchHadoopDelegationTokens(): Option[Array[Byte]] = { None } + + override def markPendingToRemove(executorIds: Seq[String]): Unit = synchronized { + logDebug(s"marking executors (${executorIds.mkString(", ")}) pending to remove") + executorIds.foreach { id => executorsPendingToRemove.getOrElseUpdate(id, Draining) } + } } private[spark] object CoarseGrainedSchedulerBackend { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e0276a4dc422..e5883fc41e41 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1272,11 +1272,13 @@ private[spark] class BlockManager( * * @param blockId blockId being replicate * @param existingReplicas existing block managers that have a replica + * @param excluding block managers that won't be replicated to * @param maxReplicas maximum replicas needed */ def replicateBlock( blockId: BlockId, existingReplicas: Set[BlockManagerId], + excluding: Set[BlockManagerId], maxReplicas: Int): Unit = { logInfo(s"Using $blockManagerId to pro-actively replicate $blockId") blockInfoManager.lockForReading(blockId).foreach { info => @@ -1291,7 +1293,7 @@ private[spark] class BlockManager( // this way, we won't try to replicate to a missing executor with a stale reference getPeers(forceFetch = true) try { - replicate(blockId, data, storageLevel, info.classTag, existingReplicas) + replicate(blockId, data, storageLevel, info.classTag, existingReplicas, excluding) } finally { logDebug(s"Releasing lock for $blockId") releaseLockAndDispose(blockId, data) @@ -1308,7 +1310,8 @@ private[spark] class BlockManager( data: BlockData, level: StorageLevel, classTag: ClassTag[_], - existingReplicas: Set[BlockManagerId] = Set.empty): Unit = { + existingReplicas: Set[BlockManagerId] = Set.empty, + excludedBlockManagers: Set[BlockManagerId] = Set.empty): Unit = { val maxReplicationFailures = conf.getInt("spark.storage.maxReplicationFailures", 1) val tLevel = StorageLevel( @@ -1325,7 +1328,9 @@ private[spark] class BlockManager( val peersFailedToReplicateTo = mutable.HashSet.empty[BlockManagerId] var numFailures = 0 - val initialPeers = getPeers(false).filterNot(existingReplicas.contains) + val initialPeers = getPeers(false) + .filterNot(existingReplicas.contains) + .filterNot(excludedBlockManagers.contains) var peersForReplication = blockReplicationPolicy.prioritize( blockManagerId, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 8e8f7d197c9e..5d8e089e21f2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -124,6 +124,9 @@ class BlockManagerMasterEndpoint( removeExecutor(execId) context.reply(true) + case RecoverLatestRDDBlock(execId, excluded) => + recoverLatestRDDBlock(execId, excluded, context) + case StopBlockManagerMaster => context.reply(true) stop() @@ -151,6 +154,7 @@ class BlockManagerMasterEndpoint( // Find all blocks for the given 
RDD, remove the block from both blockLocations and // the blockManagerInfo that is tracking the blocks. val blocks = blockLocations.asScala.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) + blocks.foreach { blockId => val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) @@ -236,7 +240,8 @@ class BlockManagerMasterEndpoint( val candidateBMId = blockLocations(i) blockManagerInfo.get(candidateBMId).foreach { bm => val remainingLocations = locations.toSeq.filter(bm => bm != candidateBMId) - val replicateMsg = ReplicateBlock(blockId, remainingLocations, maxReplicas) + val replicateMsg = + ReplicateBlock(blockId, remainingLocations, Seq.empty[BlockManagerId], maxReplicas) bm.slaveEndpoint.ask[Boolean](replicateMsg) } } @@ -252,6 +257,44 @@ class BlockManagerMasterEndpoint( blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) } + private def recoverLatestRDDBlock( + execId: String, + excludeExecutors: Seq[String], + context: RpcCallContext): Unit = { + logDebug(s"Replicating first cached block on $execId") + val excluded = excludeExecutors.flatMap(blockManagerIdByExecutor.get) + val response: Option[Future[Boolean]] = for { + blockManagerId <- blockManagerIdByExecutor.get(execId) + info <- blockManagerInfo.get(blockManagerId) + blocks = info.cachedBlocks.collect { case r: RDDBlockId => r } + // As a heuristic, prioritize replicating the latest rdd. If this succeeds, + // CacheRecoveryManager will try to replicate the remaining rdds. + firstBlock <- if (blocks.isEmpty) None else Some(blocks.maxBy(_.rddId)) + replicaSet <- blockLocations.asScala.get(firstBlock) + // Add 2 to force this block to be replicated to one new executor. + maxReps = replicaSet.size + 2 + isMem = info.getStatus(firstBlock).exists { _.storageLevel.useMemory } + } yield { + if (isMem) { + val msg = ReplicateBlock(firstBlock, replicaSet.toSeq, excluded, maxReps) + info.slaveEndpoint.ask[Boolean](msg) + .flatMap { _ => + logTrace(s"Replicated block $firstBlock on executor $execId") + replicaSet -= blockManagerId + updateBlockInfo(blockManagerId, firstBlock, StorageLevel.NONE, 0, 0) + info.slaveEndpoint.ask[Boolean](RemoveBlock(firstBlock)) + } + } else { + logTrace(s"Did not replicate block $firstBlock on executor $execId") + replicaSet -= blockManagerId + updateBlockInfo(blockManagerId, firstBlock, StorageLevel.NONE, 0, 0) + info.slaveEndpoint.ask[Boolean](RemoveBlock(firstBlock)) + } + } + + response.getOrElse(Future.successful(false)).foreach(context.reply) + } + /** * Return true if the driver knows about the given block manager. Otherwise, return false, * indicating that the block manager should re-register. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 1bbe7a5b3950..76f325f7ebcc 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -32,8 +32,11 @@ private[spark] object BlockManagerMessages { // blocks that the master knows about. 
case class RemoveBlock(blockId: BlockId) extends ToBlockManagerSlave - // Replicate blocks that were lost due to executor failure - case class ReplicateBlock(blockId: BlockId, replicas: Seq[BlockManagerId], maxReplicas: Int) + case class ReplicateBlock( + blockId: BlockId, + replicas: Seq[BlockManagerId], + excluding: Seq[BlockManagerId], + maxReplicas: Int) extends ToBlockManagerSlave // Remove all blocks belonging to a specific RDD. @@ -123,4 +126,7 @@ private[spark] object BlockManagerMessages { case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster case class HasCachedBlocks(executorId: String) extends ToBlockManagerMaster + + case class RecoverLatestRDDBlock(executorId: String, excludingExecs: Seq[String]) + extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 742cf4fe393f..de65121fd3bd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -75,9 +75,9 @@ class BlockManagerSlaveEndpoint( case TriggerThreadDump => context.reply(Utils.getThreadDump()) - case ReplicateBlock(blockId, replicas, maxReplicas) => - context.reply(blockManager.replicateBlock(blockId, replicas.toSet, maxReplicas)) - + case ReplicateBlock(blockId, replicas, excluding, maxReplicas) => + blockManager.replicateBlock(blockId, replicas.toSet, excluding.toSet, maxReplicas) + context.reply(true) } private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) { diff --git a/core/src/test/scala/org/apache/spark/CacheRecoveryManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheRecoveryManagerSuite.scala new file mode 100644 index 000000000000..07e123c834b6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/CacheRecoveryManagerSuite.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger + +import scala.concurrent.{Future, Promise} +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.duration.Duration +import scala.reflect.ClassTag + +import org.mockito.Mockito._ +import org.scalatest.Matchers +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark.CacheRecoveryManager._ +import org.apache.spark.internal.config.DYN_ALLOCATION_CACHE_RECOVERY_TIMEOUT +import org.apache.spark.network.util.ByteUnit +import org.apache.spark.rpc._ +import org.apache.spark.storage.{BlockId, BlockManagerId, RDDBlockId} +import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.util.ThreadUtils + +class CacheRecoveryManagerSuite + extends SparkFunSuite with MockitoSugar with Matchers { + + val oneGB: Long = ByteUnit.GiB.toBytes(1).toLong + + val plentyOfMem = Map( + BlockManagerId("1", "host", 12, None) -> ((oneGB, oneGB)), + BlockManagerId("2", "host", 12, None) -> ((oneGB, oneGB)), + BlockManagerId("3", "host", 12, None) -> ((oneGB, oneGB))) + + test("replicate blocks until empty and then kill executor") { + val conf = new SparkConf() + val eam = mock[ExecutorAllocationManager] + val blocks = Seq(RDDBlockId(1, 1), RDDBlockId(2, 1)) + val bmme = FakeBMM(blocks.iterator, plentyOfMem) + val bmmeRef = DummyRef(bmme) + val cacheRecoveryManager = new CacheRecoveryManager(bmmeRef, eam, conf) + when(eam.killExecutors(Seq("1"))).thenReturn(Seq("1")) + + try { + val future = cacheRecoveryManager.startCacheRecovery(Seq("1")) + val results = ThreadUtils.awaitResult(future, Duration(3, TimeUnit.SECONDS)) + results.head shouldBe DoneRecovering + verify(eam).killExecutors(Seq("1")) + bmme.replicated.get("1").get shouldBe 2 + } finally { + cacheRecoveryManager.stop() + } + } + + test("kill executor if it takes too long to replicate") { + val conf = new SparkConf().set(DYN_ALLOCATION_CACHE_RECOVERY_TIMEOUT.key, "1s") + val eam = mock[ExecutorAllocationManager] + val blocks = Set(RDDBlockId(1, 1), RDDBlockId(2, 1), RDDBlockId(3, 1), RDDBlockId(4, 1)) + val bmme = FakeBMM(blocks.iterator, plentyOfMem, pauseIndefinitely = true) + val bmmeRef = DummyRef(bmme) + val cacheRecoveryManager = new CacheRecoveryManager(bmmeRef, eam, conf) + + try { + val future = cacheRecoveryManager.startCacheRecovery(Seq("1")) + val results = ThreadUtils.awaitResult(future, Duration(3, TimeUnit.SECONDS)) + results.head shouldBe Timeout + verify(eam, times(1)).killExecutors(Seq("1")) + bmme.replicated.get("1").get shouldBe 1 + } finally { + cacheRecoveryManager.stop() + } + } + + test("shutdown timer will get cancelled if replication finishes") { + val conf = new SparkConf().set(DYN_ALLOCATION_CACHE_RECOVERY_TIMEOUT.key, "1s") + val eam = mock[ExecutorAllocationManager] + val blocks = Set(RDDBlockId(1, 1)) + val bmme = FakeBMM(blocks.iterator, plentyOfMem) + val bmmeRef = DummyRef(bmme) + val cacheRecoveryManager = new CacheRecoveryManager(bmmeRef, eam, conf) + + try { + val future = cacheRecoveryManager.startCacheRecovery(Seq("1")) + val results = ThreadUtils.awaitResult(future, Duration(3, TimeUnit.SECONDS)) + + results.head shouldBe DoneRecovering + verify(eam, times(1)).killExecutors(Seq("1")) + } finally { + cacheRecoveryManager.stop() + } + } + + test("blocks won't replicate if we are running out of space") { + val conf = new SparkConf() + val eam = mock[ExecutorAllocationManager] + val blocks = Seq(RDDBlockId(1, 1), 
RDDBlockId(1, 1), RDDBlockId(1, 1), RDDBlockId(1, 1)) + val memStatus = Map(BlockManagerId("1", "host", 12, None) -> ((2L, 1L)), + BlockManagerId("2", "host", 12, None) -> ((3L, 1L)), + BlockManagerId("3", "host", 12, None) -> ((4L, 1L)), + BlockManagerId("4", "host", 12, None) -> ((4L, 4L))) + val bmme = FakeBMM(blocks.iterator, memStatus) + val bmmeRef = DummyRef(bmme) + val cacheRecoveryManager = new CacheRecoveryManager(bmmeRef, eam, conf) + + try { + val future = cacheRecoveryManager.startCacheRecovery(Seq("1", "2", "3")) + val results = ThreadUtils.awaitResult(future, Duration(3, TimeUnit.SECONDS)) + results.foreach { _ shouldBe DoneRecovering } + bmme.replicated.size shouldBe 2 + } finally { + cacheRecoveryManager.stop() + } + } + + test("blocks won't replicate if we are stopping all executors") { + val conf = new SparkConf() + val eam = mock[ExecutorAllocationManager] + val blocks = Seq(RDDBlockId(1, 1), RDDBlockId(1, 1), RDDBlockId(1, 1), RDDBlockId(1, 1)) + val memStatus = Map(BlockManagerId("1", "host", 12, None) -> ((2L, 1L)), + BlockManagerId("2", "host", 12, None) -> ((2L, 1L)), + BlockManagerId("3", "host", 12, None) -> ((2L, 1L)), + BlockManagerId("4", "host", 12, None) -> ((2L, 1L))) + val bmme = FakeBMM(blocks.iterator, memStatus) + val bmmeRef = DummyRef(bmme) + val cacheRecoveryManager = new CacheRecoveryManager(bmmeRef, eam, conf) + + try { + val future = cacheRecoveryManager.startCacheRecovery(Seq("1", "2", "3", "4")) + val results = ThreadUtils.awaitResult(future, Duration(3, TimeUnit.SECONDS)) + results.foreach { _ shouldBe DoneRecovering } + bmme.replicated.size shouldBe 0 + } finally { + cacheRecoveryManager.stop() + } + } +} + +private case class FakeBMM( + blocks: Iterator[BlockId], + memStatus: Map[BlockManagerId, (Long, Long)], + sizeOfBlock: Long = 1, + pauseIndefinitely: Boolean = false + ) extends ThreadSafeRpcEndpoint { + + val rpcEnv: RpcEnv = null + val replicated: ConcurrentHashMap[String, AtomicInteger] = + new ConcurrentHashMap[String, AtomicInteger]() + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RecoverLatestRDDBlock(execId, _) => + val future = Future { + if (blocks.hasNext) { + blocks.next() + replicated.putIfAbsent(execId, new AtomicInteger(0)) + replicated.get(execId).incrementAndGet() + true + } else { + false + } + } + if (!pauseIndefinitely) { future.foreach(context.reply) } + case GetMemoryStatus => context.reply(memStatus) + } +} + +// Turns an RpcEndpoint into RpcEndpointRef by calling receive and reply directly +private case class DummyRef(endpoint: RpcEndpoint) extends RpcEndpointRef(new SparkConf()) { + def address: RpcAddress = null + def name: String = null + def send(message: Any): Unit = endpoint.receive(message) + def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { + val context = new DummyRpcCallContext[T] + endpoint.receiveAndReply(context)(message) + context.result + } +} + +// saves values you put in context.reply +private class DummyRpcCallContext[T] extends RpcCallContext { + val promise: Promise[T] = Promise[T]() + def result: Future[T] = promise.future + def reply(response: Any): Unit = promise.success(response.asInstanceOf[T]) + def sendFailure(e: Throwable): Unit = () + def senderAddress: RpcAddress = null +} + diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 9807d1269e3d..7ae7b2f3b518 100644 --- 
a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import scala.collection.mutable +import scala.concurrent.Future import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.{mock, never, verify, when} @@ -285,20 +286,20 @@ class ExecutorAllocationManagerSuite // Keep removing until the limit is reached assert(executorsPendingToRemove(manager).isEmpty) - assert(removeExecutor(manager, "1")) + removeExecutor(manager, "1") assert(executorsPendingToRemove(manager).size === 1) assert(executorsPendingToRemove(manager).contains("1")) - assert(removeExecutor(manager, "2")) - assert(removeExecutor(manager, "3")) + removeExecutor(manager, "2") + removeExecutor(manager, "3") assert(executorsPendingToRemove(manager).size === 3) assert(executorsPendingToRemove(manager).contains("2")) assert(executorsPendingToRemove(manager).contains("3")) - assert(!removeExecutor(manager, "100")) // remove non-existent executors - assert(!removeExecutor(manager, "101")) + removeExecutor(manager, "100") // remove non-existent executors + removeExecutor(manager, "101") assert(executorsPendingToRemove(manager).size === 3) - assert(removeExecutor(manager, "4")) - assert(removeExecutor(manager, "5")) - assert(!removeExecutor(manager, "6")) // reached the limit of 5 + removeExecutor(manager, "4") + removeExecutor(manager, "5") + removeExecutor(manager, "6") // reached the limit of 5 assert(executorsPendingToRemove(manager).size === 5) assert(executorsPendingToRemove(manager).contains("4")) assert(executorsPendingToRemove(manager).contains("5")) @@ -322,9 +323,9 @@ class ExecutorAllocationManagerSuite // Try removing again // This should still fail because the number pending + running is still at the limit - assert(!removeExecutor(manager, "7")) + removeExecutor(manager, "7") assert(executorsPendingToRemove(manager).isEmpty) - assert(!removeExecutor(manager, "8")) + removeExecutor(manager, "8") assert(executorsPendingToRemove(manager).isEmpty) } @@ -335,19 +336,19 @@ class ExecutorAllocationManagerSuite // Keep removing until the limit is reached assert(executorsPendingToRemove(manager).isEmpty) - assert(removeExecutors(manager, Seq("1")) === Seq("1")) + removeExecutors(manager, Seq("1")) assert(executorsPendingToRemove(manager).size === 1) assert(executorsPendingToRemove(manager).contains("1")) - assert(removeExecutors(manager, Seq("2", "3")) === Seq("2", "3")) + removeExecutors(manager, Seq("2", "3")) assert(executorsPendingToRemove(manager).size === 3) assert(executorsPendingToRemove(manager).contains("2")) assert(executorsPendingToRemove(manager).contains("3")) - assert(!removeExecutor(manager, "100")) // remove non-existent executors - assert(removeExecutors(manager, Seq("101", "102")) !== Seq("101", "102")) + removeExecutor(manager, "100") // remove non-existent executors + removeExecutors(manager, Seq("101", "102")) assert(executorsPendingToRemove(manager).size === 3) - assert(removeExecutor(manager, "4")) - assert(removeExecutors(manager, Seq("5")) === Seq("5")) - assert(!removeExecutor(manager, "6")) // reached the limit of 5 + removeExecutor(manager, "4") + removeExecutors(manager, Seq("5")) + removeExecutor(manager, "6") // reached the limit of 5 assert(executorsPendingToRemove(manager).size === 5) assert(executorsPendingToRemove(manager).contains("4")) assert(executorsPendingToRemove(manager).contains("5")) @@ -371,9 +372,9 @@ class 
ExecutorAllocationManagerSuite // Try removing again // This should still fail because the number pending + running is still at the limit - assert(!removeExecutor(manager, "7")) + removeExecutor(manager, "7") assert(executorsPendingToRemove(manager).isEmpty) - assert(removeExecutors(manager, Seq("8")) !== Seq("8")) + removeExecutors(manager, Seq("8")) assert(executorsPendingToRemove(manager).isEmpty) } @@ -391,7 +392,8 @@ class ExecutorAllocationManagerSuite assert(executorIds(manager).size === 8) assert(numExecutorsTarget(manager) === 8) assert(maxNumExecutorsNeeded(manager) == 8) - assert(!removeExecutor(manager, "1")) // won't work since numExecutorsTarget == numExecutors + removeExecutor(manager, "1") // won't work since numExecutorsTarget == numExecutors + assert(!executorsPendingToRemove(manager).contains("1")) // Remove executors when numExecutorsTarget is lower than current number of executors (1 to 3).map { i => createTaskInfo(i, i, s"$i") }.foreach { info => @@ -401,8 +403,11 @@ class ExecutorAllocationManagerSuite assert(executorIds(manager).size === 8) assert(numExecutorsTarget(manager) === 5) assert(maxNumExecutorsNeeded(manager) == 5) - assert(removeExecutor(manager, "1")) - assert(removeExecutors(manager, Seq("2", "3"))=== Seq("2", "3")) + removeExecutor(manager, "1") + assert(executorsPendingToRemove(manager).contains("1")) + removeExecutors(manager, Seq("2", "3")) + assert(executorsPendingToRemove(manager).contains("2")) + assert(executorsPendingToRemove(manager).contains("3")) onExecutorRemoved(manager, "1") onExecutorRemoved(manager, "2") onExecutorRemoved(manager, "3") @@ -413,7 +418,8 @@ class ExecutorAllocationManagerSuite assert(executorIds(manager).size === 5) assert(numExecutorsTarget(manager) === 5) assert(maxNumExecutorsNeeded(manager) == 4) - assert(!removeExecutor(manager, "4")) // lower limit + removeExecutor(manager, "4") // lower limit + assert(!executorsPendingToRemove(manager).contains("4")) assert(addExecutors(manager) === 0) // upper limit } @@ -436,10 +442,12 @@ class ExecutorAllocationManagerSuite assert(executorIds(manager).size === 8) assert(numExecutorsTarget(manager) === 8) - // Remove when numTargetExecutors is equal to the current number of executors - assert(!removeExecutor(manager, "1")) - assert(removeExecutors(manager, Seq("2", "3")) !== Seq("2", "3")) + removeExecutor(manager, "1") + assert(!executorsPendingToRemove(manager).contains("1")) + removeExecutors(manager, Seq("2", "3")) + assert(!executorsPendingToRemove(manager).contains("2")) + assert(!executorsPendingToRemove(manager).contains("3")) // Remove until limit onExecutorAdded(manager, "9") @@ -449,10 +457,17 @@ class ExecutorAllocationManagerSuite assert(executorIds(manager).size === 12) assert(numExecutorsTarget(manager) === 8) - assert(removeExecutor(manager, "1")) - assert(removeExecutors(manager, Seq("2", "3", "4")) === Seq("2", "3", "4")) - assert(!removeExecutor(manager, "5")) // lower limit reached - assert(!removeExecutor(manager, "6")) + removeExecutor(manager, "1") + assert(executorsPendingToRemove(manager).contains("1")) + removeExecutors(manager, Seq("2", "3", "4")) + assert(executorsPendingToRemove(manager).contains("2")) + assert(executorsPendingToRemove(manager).contains("3")) + assert(executorsPendingToRemove(manager).contains("4")) + removeExecutor(manager, "5") // lower limit reached + assert(!executorsPendingToRemove(manager).contains("5")) + removeExecutor(manager, "6") + assert(!executorsPendingToRemove(manager).contains("6")) + removeExecutor(manager, "5") // 
lower limit reached onExecutorRemoved(manager, "1") onExecutorRemoved(manager, "2") onExecutorRemoved(manager, "3") @@ -460,7 +475,8 @@ class ExecutorAllocationManagerSuite assert(executorIds(manager).size === 8) // Add until limit - assert(!removeExecutor(manager, "7")) // still at lower limit + removeExecutor(manager, "7") // still at lower limit + assert(!executorsPendingToRemove(manager).contains("7")) assert((manager, Seq("8")) !== Seq("8")) onExecutorAdded(manager, "13") onExecutorAdded(manager, "14") @@ -469,8 +485,12 @@ class ExecutorAllocationManagerSuite assert(executorIds(manager).size === 12) // Remove succeeds again, now that we are no longer at the lower limit - assert(removeExecutors(manager, Seq("5", "6", "7")) === Seq("5", "6", "7")) - assert(removeExecutor(manager, "8")) + removeExecutors(manager, Seq("5", "6", "7")) + assert(executorsPendingToRemove(manager).contains("5")) + assert(executorsPendingToRemove(manager).contains("6")) + assert(executorsPendingToRemove(manager).contains("7")) + removeExecutor(manager, "8") + assert(executorsPendingToRemove(manager).contains("8")) assert(executorIds(manager).size === 12) onExecutorRemoved(manager, "5") onExecutorRemoved(manager, "6") @@ -1238,11 +1258,11 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { manager invokePrivate _updateAndSyncNumExecutorsTarget(0L) } - private def removeExecutor(manager: ExecutorAllocationManager, id: String): Boolean = { - manager invokePrivate _removeExecutor(id) + private def removeExecutor(manager: ExecutorAllocationManager, id: String): Unit = { + manager invokePrivate _removeExecutors(Seq(id)) } - private def removeExecutors(manager: ExecutorAllocationManager, ids: Seq[String]): Seq[String] = { + private def removeExecutors(manager: ExecutorAllocationManager, ids: Seq[String]): Unit = { manager invokePrivate _removeExecutors(ids) } @@ -1346,4 +1366,6 @@ private class DummyLocalSchedulerBackend (sc: SparkContext, sb: SchedulerBackend override def killExecutorsOnHost(host: String): Boolean = { false } + + override def markPendingToRemove(executorIds: Seq[String]): Unit = Seq.empty } diff --git a/core/src/test/scala/org/apache/spark/scheduler/CacheRecoveryIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CacheRecoveryIntegrationSuite.scala new file mode 100644 index 000000000000..06393cba4347 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/CacheRecoveryIntegrationSuite.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import scala.util.Try + +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers} +import org.scalatest.concurrent.Eventually +import org.scalatest.time.{Seconds, Span} + +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite, TestUtils} +import org.apache.spark.internal.config._ +import org.apache.spark.network.TransportContext +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler +import org.apache.spark.rdd.RDD +import org.apache.spark.storage._ + +/** + * This is an integration test for the cache recovery feature using a local spark cluster. It + * extends the unit tests in CacheRecoveryManagerSuite which mocks a lot of cluster infrastructure. + */ +class CacheRecoveryIntegrationSuite extends SparkFunSuite + with Matchers + with BeforeAndAfterEach + with BeforeAndAfterAll + with Eventually { + + private var conf: SparkConf = makeBaseConf() + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 4) + private val rpcHandler = new ExternalShuffleBlockHandler(transportConf, null) + private val transportContext = new TransportContext(transportConf, rpcHandler) + private val shuffleService = transportContext.createServer() + private var sc: SparkContext = _ + + private def makeBaseConf() = new SparkConf() + .setAppName("test") + .setMaster("local-cluster[4, 1, 512]") + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.executorIdleTimeout", "1s") // always + .set("spark.dynamicAllocation.cachedExecutorIdleTimeout", "1s") + .set(EXECUTOR_MEMORY.key, "512m") + .set(SHUFFLE_SERVICE_ENABLED.key, "true") + .set(DYN_ALLOCATION_CACHE_RECOVERY.key, "true") + .set(DYN_ALLOCATION_CACHE_RECOVERY_TIMEOUT.key, "500s") + .set(EXECUTOR_INSTANCES.key, "1") + .set(DYN_ALLOCATION_INITIAL_EXECUTORS.key, "4") + .set(DYN_ALLOCATION_MIN_EXECUTORS.key, "3") + + override def beforeEach(): Unit = { + conf = makeBaseConf() + conf.set("spark.shuffle.service.port", shuffleService.getPort.toString) + } + + override def afterEach(): Unit = { + sc.stop() + conf = null + } + + override def afterAll(): Unit = { + shuffleService.close() + } + + private def getLocations( + sc: SparkContext, + rdd: RDD[_]): Map[BlockId, Map[BlockManagerId, BlockStatus]] = { + import scala.collection.breakOut + val blockIds: Array[BlockId] = rdd.partitions.map(p => RDDBlockId(rdd.id, p.index)) + blockIds.map { id => + id -> Try(sc.env.blockManager.master.getBlockStatus(id)).getOrElse(Map.empty) + }(breakOut) + } + + test("cached data is replicated before dynamic de-allocation") { + sc = new SparkContext(conf) + TestUtils.waitUntilExecutorsUp(sc, 4, 60000) + + val rdd = sc.parallelize(1 to 1000, 4).map(_ * 4).cache() + rdd.reduce(_ + _) shouldBe 2002000 + sc.getExecutorIds().size shouldBe 4 + getLocations(sc, rdd).forall { case (_, map) => map.nonEmpty } shouldBe true + + eventually(timeout(Span(5, Seconds)), interval(Span(1, Seconds))) { + sc.getExecutorIds().size shouldBe 3 + getLocations(sc, rdd).forall { case (_, map) => map.nonEmpty } shouldBe true + } + } + + test("dont fail if a bunch of executors are shut down at once") { + conf.set("spark.dynamicAllocation.minExecutors", "1") + sc = new SparkContext(conf) + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) + + val rdd = sc.parallelize(1 to 1000, 4).map(_ * 4).cache() + rdd.reduce(_ + _) shouldBe 2002000 + sc.getExecutorIds().size shouldBe 4 + getLocations(sc, 
rdd).forall { case (_, map) => map.nonEmpty } shouldBe true + + eventually(timeout(Span(5, Seconds)), interval(Span(1, Seconds))) { + sc.getExecutorIds().size shouldBe 1 + getLocations(sc, rdd).forall { case (_, map) => map.nonEmpty } shouldBe true + } + } + + test("executors should not accept new work while replicating away data before deallocation") { + conf.set("spark.dynamicAllocation.minExecutors", "1") + + sc = new SparkContext(conf) + TestUtils.waitUntilExecutorsUp(sc, 4, 60000) + + val rdd = sc.parallelize(1 to 100000, 4).map(_ * 4L).cache() // cache on all 4 executors + rdd.reduce(_ + _) shouldBe 20000200000L // realize the cache + + // wait until executors are de-allocated + eventually(timeout(Span(4, Seconds)), interval(Span(1, Seconds))) { + sc.getExecutorIds().size shouldBe 1 + } + + val rdd2 = sc.parallelize(1 to 100000, 4).map(_ * 4L).cache() // should be created on 1 exe + rdd2.reduce(_ + _) shouldBe 20000200000L + + val executorIds = for { + maps <- getLocations(sc, rdd2).values + blockManagerId <- maps.keys + } yield blockManagerId.executorId + + executorIds.toSet.size shouldBe 1 + } + + test("throw error if cache recovery is enabled and cachedExecutor timeout is not set") { + conf.remove("spark.dynamicAllocation.cachedExecutorIdleTimeout") + a [SparkException] should be thrownBy new SparkContext(conf) + } +} diff --git a/docs/configuration.md b/docs/configuration.md index 2eb6a77434ea..b4db994fe96c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1815,6 +1815,26 @@ Apart from these, the following properties are also available, and may be useful description. +
+<tr>
+  <td><code>spark.dynamicAllocation.cacheRecovery.enabled</code></td>
+  <td>false</td>
+  <td>
+    If dynamic allocation is enabled and <code>spark.dynamicAllocation.cachedExecutorIdleTimeout</code> is set, then idle executors with
+    cached data will attempt to replicate their cached data to other remaining executors before they
+    are shut down. If there is not enough memory on the cluster then the executor will be immediately
+    shut down.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.dynamicAllocation.cacheRecovery.timeout</code></td>
+  <td>120s</td>
+  <td>
+    How long an executor may spend replicating its cached blocks before it is shut down anyway.
+    Only used if <code>spark.dynamicAllocation.cacheRecovery.enabled</code> is true.
+  </td>
+</tr>
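
The admission check in `CacheRecoveryManager.checkMem` is the least obvious piece of the patch: executors are admitted smallest-first until the free memory on the surviving executors is used up. A simplified, self-contained sketch of the same scan (the function name and plain `Map`/`String` types are stand-ins for illustration, not the patch's API):

```scala
// Simplified model of checkMem: given (execId -> (maxMem, freeMem)) for every executor,
// return the subset of executors to drain whose cached data should fit on the survivors.
def executorsThatFit(
    memStatus: Map[String, (Long, Long)],
    toShutDown: Set[String]): Seq[String] = {
  val (expiring, remaining) = memStatus.partition { case (id, _) => toShutDown.contains(id) }
  val freeOnRemaining: Long = remaining.values.map(_._2).sum

  expiring.toSeq
    .map { case (id, (maxMem, freeMem)) => id -> (maxMem - freeMem) }      // used memory per executor
    .sortBy { case (_, used) => used }                                     // smallest users first
    .scanLeft(("start", freeOnRemaining)) {
      case ((_, free), (id, used)) => (id, free - used)                    // charge each executor's data
    }
    .drop(1)                                                               // drop the seed element
    .filter { case (_, free) => free > 0 }                                 // keep only those that still fit
    .map(_._1)
}

// e.g. executorsThatFit(Map("1" -> ((10L, 2L)), "2" -> ((10L, 9L))), Set("1")) == Seq("1")
```

As the scaladoc says, this is a best-guess heuristic: a single block can still be too large for any one surviving executor even when the totals work out.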
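`replicateUntilTimeoutThenKill` is a race between the replication future and a scheduled timer, with the kill happening whichever side wins. The same shape in isolation, with hypothetical names (`Outcome`, `afterTimeout`, `raceWithTimeout`) standing in for the patch's `KillReason`, `returnAfterTimeout`, and `replicateUntilTimeoutThenKill`:

```scala
import java.util.concurrent.{Executors, TimeUnit}
import scala.concurrent.{ExecutionContext, Future, Promise}

sealed trait Outcome
case object Finished extends Outcome
case object TimedOut extends Outcome

val scheduler = Executors.newSingleThreadScheduledExecutor()
implicit val ec: ExecutionContext = ExecutionContext.global

// Complete with `value` after `seconds`, mirroring returnAfterTimeout.
def afterTimeout[T](value: T, seconds: Long): Future[T] = {
  val p = Promise[T]()
  scheduler.schedule(new Runnable { def run(): Unit = p.success(value) }, seconds, TimeUnit.SECONDS)
  p.future
}

// Race the real work against the timer; the cleanup (kill(execId) in the real code)
// runs whichever side completes first, and the caller still sees which one won.
def raceWithTimeout(work: Future[Outcome], seconds: Long): Future[Outcome] =
  Future.firstCompletedOf(Seq(work, afterTimeout(TimedOut, seconds)))
    .andThen { case _ => /* kill the executor here */ }
```

Because `andThen` passes the original result through, the caller can still log whether the executor finished recovering or hit the timeout, as the patch does.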
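Similarly, `replicateUntilDone` drains an executor one block per round trip by re-issuing `RecoverLatestRDDBlock` until the reply says nothing is left. The loop shape, reduced to a generic sketch with a hypothetical `step` callback:

```scala
import scala.concurrent.{ExecutionContext, Future}

// Keep asking until the remote side reports there is nothing left to replicate.
// Each `step()` replicates one block and answers whether more blocks remain.
def drain(step: () => Future[Boolean])(implicit ec: ExecutionContext): Future[Unit] =
  step().flatMap { moreLeft =>
    if (moreLeft) drain(step) else Future.successful(())
  }
```

Each recursive call is scheduled through the execution context, so the chain of `flatMap`s does not grow the stack no matter how many blocks an executor holds.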
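Finally, a minimal configuration sketch for trying the feature out. The two `spark.dynamicAllocation.cacheRecovery.*` keys come from this patch; everything else is the usual dynamic-allocation setup, `spark.dynamicAllocation.cachedExecutorIdleTimeout` must be set or `validateSettings` throws, and the specific values are illustrative, not recommendations:

```scala
import org.apache.spark.SparkConf

// Illustrative only: the cluster master and external shuffle service are assumed to be
// configured elsewhere, as for any dynamic-allocation deployment.
val conf = new SparkConf()
  .set("spark.dynamicAllocation.enabled", "true")
  .set("spark.shuffle.service.enabled", "true")
  // Required when cache recovery is on, otherwise ExecutorAllocationManager fails validation.
  .set("spark.dynamicAllocation.cachedExecutorIdleTimeout", "60s")
  // Keys added by this patch.
  .set("spark.dynamicAllocation.cacheRecovery.enabled", "true")
  .set("spark.dynamicAllocation.cacheRecovery.timeout", "120s")
```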