diff --git a/core/src/main/scala/org/apache/spark/HostState.scala b/core/src/main/scala/org/apache/spark/HostState.scala new file mode 100644 index 000000000000..17b374c3fac2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/HostState.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.hadoop.yarn.api.records.NodeState + +private[spark] object HostState extends Enumeration { + + type HostState = Value + + val New, Running, Unhealthy, Decommissioning, Decommissioned, Lost, Rebooted = Value + + def fromYarnState(state: String): Option[HostState] = { + HostState.values.find(_.toString.toUpperCase == state) + } + + def toYarnState(state: HostState): Option[String] = { + NodeState.values.find(_.name == state.toString.toUpperCase).map(_.name) + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 9495cd2835f9..84edcff707d4 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -154,6 +154,16 @@ package object config { ConfigBuilder("spark.blacklist.application.fetchFailure.enabled") .booleanConf .createWithDefault(false) + + private[spark] val BLACKLIST_DECOMMISSIONING_ENABLED = + ConfigBuilder("spark.blacklist.decommissioning.enabled") + .booleanConf + .createWithDefault(false) + + private[spark] val BLACKLIST_DECOMMISSIONING_TIMEOUT_CONF = + ConfigBuilder("spark.blacklist.decommissioning.timeout") + .timeConf(TimeUnit.MILLISECONDS) + .createOptional // End blacklist confs private[spark] val UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE = diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index cd8e61d6d020..7bc3db8ce1bb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -61,7 +61,13 @@ private[scheduler] class BlacklistTracker ( private val MAX_FAILURES_PER_EXEC = conf.get(config.MAX_FAILURES_PER_EXEC) private val MAX_FAILED_EXEC_PER_NODE = conf.get(config.MAX_FAILED_EXEC_PER_NODE) val BLACKLIST_TIMEOUT_MILLIS = BlacklistTracker.getBlacklistTimeout(conf) - private val BLACKLIST_FETCH_FAILURE_ENABLED = conf.get(config.BLACKLIST_FETCH_FAILURE_ENABLED) + val BLACKLIST_DECOMMISSIONING_TIMEOUT_MILLIS = + BlacklistTracker.getBlacklistDecommissioningTimeout(conf) + private val TASK_BLACKLISTING_ENABLED = BlacklistTracker.isTaskExecutionBlacklistingEnabled(conf) + private val DECOMMISSIONING_BLACKLISTING_ENABLED = + 
BlacklistTracker.isDecommissioningBlacklistingEnabled(conf) + private val BLACKLIST_FETCH_FAILURE_ENABLED = + BlacklistTracker.isFetchFailureBlacklistingEnabled(conf) /** * A map from executorId to information on task failures. Tracks the time of each task failure, @@ -89,13 +95,13 @@ private[scheduler] class BlacklistTracker ( * successive blacklisted executors on one node. Nonetheless, it will not grow too large because * there cannot be many blacklisted executors on one node, before we stop requesting more * executors on that node, and we clean up the list of blacklisted executors once an executor has - * been blacklisted for BLACKLIST_TIMEOUT_MILLIS. + * been blacklisted for its configured blacklisting timeout. */ val nodeToBlacklistedExecs = new HashMap[String, HashSet[String]]() /** - * Un-blacklists executors and nodes that have been blacklisted for at least - * BLACKLIST_TIMEOUT_MILLIS + * Un-blacklists executors and nodes that have been blacklisted for at least their configured + * blacklisting timeouts */ def applyBlacklistTimeout(): Unit = { val now = clock.getTimeMillis() @@ -118,16 +124,9 @@ private[scheduler] class BlacklistTracker ( } } val nodesToUnblacklist = nodeIdToBlacklistExpiryTime.filter(_._2 < now).keys - if (nodesToUnblacklist.nonEmpty) { - // Un-blacklist any nodes that have been blacklisted longer than the blacklist timeout. - logInfo(s"Removing nodes $nodesToUnblacklist from blacklist because the blacklist " + - s"has timed out") - nodesToUnblacklist.foreach { node => - nodeIdToBlacklistExpiryTime.remove(node) - listenerBus.post(SparkListenerNodeUnblacklisted(now, node)) - } - _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) - } + .map(node => (node, BlacklistTimedOut, Some(now))) + // Un-blacklist any nodes that have been blacklisted longer than the blacklist timeout. + removeNodesFromBlacklist(nodesToUnblacklist) updateNextExpiryTime() } } @@ -190,14 +189,8 @@ private[scheduler] class BlacklistTracker ( val expiryTimeForNewBlacklists = now + BLACKLIST_TIMEOUT_MILLIS if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) { - if (!nodeIdToBlacklistExpiryTime.contains(host)) { - logInfo(s"blacklisting node $host due to fetch failure of external shuffle service") - - nodeIdToBlacklistExpiryTime.put(host, expiryTimeForNewBlacklists) - listenerBus.post(SparkListenerNodeBlacklisted(now, host, 1)) - _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) + if (addNodeToBlacklist(host, FetchFailure(host), now)) { killExecutorsOnBlacklistedNode(host) - updateNextExpiryTime() } } else if (!executorIdToBlacklistStatus.contains(exec)) { logInfo(s"Blacklisting executor $exec due to fetch failure") @@ -249,21 +242,93 @@ private[scheduler] class BlacklistTracker ( // node, and potentially put the entire node into a blacklist as well. val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(node, HashSet[String]()) blacklistedExecsOnNode += exec - // If the node is already in the blacklist, we avoid adding it again with a later expiry - // time.
- if (blacklistedExecsOnNode.size >= MAX_FAILED_EXEC_PER_NODE && - !nodeIdToBlacklistExpiryTime.contains(node)) { - logInfo(s"Blacklisting node $node because it has ${blacklistedExecsOnNode.size} " + - s"executors blacklisted: ${blacklistedExecsOnNode}") - nodeIdToBlacklistExpiryTime.put(node, expiryTimeForNewBlacklists) - listenerBus.post(SparkListenerNodeBlacklisted(now, node, blacklistedExecsOnNode.size)) - _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) - killExecutorsOnBlacklistedNode(node) + if (blacklistedExecsOnNode.size >= MAX_FAILED_EXEC_PER_NODE) { + val blacklistSucceeded = addNodeToBlacklist(node, + ExecutorFailures(Set(blacklistedExecsOnNode.toList: _*)), now) + if (blacklistSucceeded) { + killExecutorsOnBlacklistedNode(node) + } } } } } + /** + * Adds a node to the blacklist, with a timeout that depends upon the reason. If the node is + * already in the blacklist, it is not added again. + * @param node Node to be blacklisted + * @param reason Reason for blacklisting the node + * @param time Start time from which the blacklist expiry time is computed (defaults to now) + * @return true if the node was added to the blacklist, false otherwise + */ + def addNodeToBlacklist(node: String, reason: NodeBlacklistReason, + time: Long = clock.getTimeMillis()): Boolean = { + // If the node is already in the blacklist, we avoid adding it again with a later expiry time. + if (!isNodeBlacklisted(node)) { + val blacklistExpiryTimeOpt = reason match { + case NodeDecommissioning if DECOMMISSIONING_BLACKLISTING_ENABLED => + val expiryTime = time + BLACKLIST_DECOMMISSIONING_TIMEOUT_MILLIS + logInfo(s"Blacklisting node $node with timeout $expiryTime ms because ${reason.message}") + Some(expiryTime) + + case ExecutorFailures(blacklistedExecutors) if TASK_BLACKLISTING_ENABLED => + val expiryTime = time + BLACKLIST_TIMEOUT_MILLIS + logInfo(s"Blacklisting node $node with timeout $expiryTime ms because it " + + s"has ${blacklistedExecutors.size} executors blacklisted: ${blacklistedExecutors}") + Some(expiryTime) + + case FetchFailure(host) if BLACKLIST_FETCH_FAILURE_ENABLED => + val expiryTime = time + BLACKLIST_TIMEOUT_MILLIS + logInfo(s"Blacklisting node $host due to fetch failure of external shuffle service") + Some(expiryTime) + + case _ => None + } + + blacklistExpiryTimeOpt.fold(false) { blacklistExpiryTime => + blacklistNodeHelper(node, blacklistExpiryTime) + listenerBus.post(SparkListenerNodeBlacklisted(time, node, reason)) + updateNextExpiryTime() + true + } + } else { + false + } + } + + private def blacklistNodeHelper(node: String, blacklistExpiryTimeout: Long): Unit = { + nodeIdToBlacklistExpiryTime.put(node, blacklistExpiryTimeout) + _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) + } + + private def unblacklistNodesHelper(nodes: Iterable[String]): Unit = { + nodeIdToBlacklistExpiryTime --= nodes + _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) + } + + /** + * @param nodesToRemove List of nodes to unblacklist, with their reason for unblacklisting + * and an optional time to be passed to the Spark listener indicating the + * time of unblacklisting.
+ */ + def removeNodesFromBlacklist(nodesToRemove: Iterable[(String, NodeUnblacklistReason, + Option[Long])]): Unit = { + if (nodesToRemove.nonEmpty) { + val blacklistNodesToRemove = nodesToRemove.filter { case (node, reason, _) => + (reason == BlacklistTimedOut || + (reason == NodeRunning && DECOMMISSIONING_BLACKLISTING_ENABLED)) && + isNodeBlacklisted(node) + } + unblacklistNodesHelper(blacklistNodesToRemove.map(_._1)) + blacklistNodesToRemove.foreach { case (node, reason, timeOpt) => + logInfo(s"Removing node $node from blacklist because ${reason.message}") + listenerBus.post(SparkListenerNodeUnblacklisted( + timeOpt.getOrElse(clock.getTimeMillis()), node, reason)) + } + } + } + def isExecutorBlacklisted(executorId: String): Boolean = { executorIdToBlacklistStatus.contains(executorId) } @@ -373,15 +438,39 @@ private[scheduler] class BlacklistTracker ( private[scheduler] object BlacklistTracker extends Logging { private val DEFAULT_TIMEOUT = "1h" + private val DEFAULT_DECOMMISSIONING_TIMEOUT = "1h" /** - * Returns true if the blacklist is enabled, based on checking the configuration in the following - * order: + * Returns true if task execution blacklisting, fetch failure blacklisting, + * or decommissioning blacklisting is enabled. + */ + def isBlacklistEnabled(conf: SparkConf): Boolean = { + isFetchFailureBlacklistingEnabled(conf) || isDecommissioningBlacklistingEnabled(conf) || + isTaskExecutionBlacklistingEnabled(conf) + } + + /** + * Returns true if fetch failure blacklisting is enabled. + */ + def isFetchFailureBlacklistingEnabled(conf: SparkConf): Boolean = { + conf.get(config.BLACKLIST_FETCH_FAILURE_ENABLED) + } + + /** + * Returns true if decommissioning blacklisting is enabled. + */ + def isDecommissioningBlacklistingEnabled(conf: SparkConf): Boolean = { + conf.get(config.BLACKLIST_DECOMMISSIONING_ENABLED) + } + + /** + * Returns true if the task execution blacklist is enabled, based on checking the configuration + * in the following order: * 1. Is it specifically enabled or disabled? * 2. Is it enabled via the legacy timeout conf? * 3. Default is off */ - def isBlacklistEnabled(conf: SparkConf): Boolean = { + def isTaskExecutionBlacklistingEnabled(conf: SparkConf): Boolean = { conf.get(config.BLACKLIST_ENABLED) match { case Some(enabled) => enabled @@ -409,6 +498,11 @@ private[scheduler] object BlacklistTracker extends Logging { } } + def getBlacklistDecommissioningTimeout(conf: SparkConf): Long = { + conf.get(config.BLACKLIST_DECOMMISSIONING_TIMEOUT_CONF) + .getOrElse(Utils.timeStringAsMs(DEFAULT_DECOMMISSIONING_TIMEOUT)) + } + /** * Verify that blacklist configurations are consistent; if not, throw an exception. Should only * be called if blacklisting is enabled. @@ -449,6 +543,12 @@ private[scheduler] object BlacklistTracker extends Logging { } } + val blacklistDecommissioningTimeout = getBlacklistDecommissioningTimeout(conf) + if (blacklistDecommissioningTimeout <= 0) { + mustBePos(config.BLACKLIST_DECOMMISSIONING_TIMEOUT_CONF.key, + blacklistDecommissioningTimeout.toString) + } + val maxTaskFailures = conf.get(config.MAX_TASK_FAILURES) val maxNodeAttempts = conf.get(config.MAX_TASK_ATTEMPTS_PER_NODE) @@ -458,7 +558,9 @@ private[scheduler] object BlacklistTracker extends Logging { s"( = ${maxTaskFailures} ). Though blacklisting is enabled, with this configuration, " + s"Spark will not be robust to one bad node. 
Decrease " + s"${config.MAX_TASK_ATTEMPTS_PER_NODE.key}, increase ${config.MAX_TASK_FAILURES.key}, " + - s"or disable blacklisting with ${config.BLACKLIST_ENABLED.key}") + s"or disable blacklisting with ${config.BLACKLIST_ENABLED.key}, " + + s"${config.BLACKLIST_DECOMMISSIONING_ENABLED.key} " + + s"and ${config.BLACKLIST_FETCH_FAILURE_ENABLED.key}") } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/NodeBlacklistReason.scala b/core/src/main/scala/org/apache/spark/scheduler/NodeBlacklistReason.scala new file mode 100644 index 000000000000..53fba245e0ee --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/NodeBlacklistReason.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.annotation.DeveloperApi + +/** + * Represents an explanation for a Node being blacklisted for task scheduling + */ +@DeveloperApi +private[spark] sealed trait NodeBlacklistReason extends Serializable { + def message: String +} + +@DeveloperApi +private[spark] case class ExecutorFailures(blacklistedExecutors: Set[String]) + extends NodeBlacklistReason { + override def message: String = "Maximum number of executor failures allowed on Node exceeded." +} + +@DeveloperApi +private[spark] case object NodeDecommissioning extends NodeBlacklistReason { + override def message: String = "Node is being decommissioned by Cluster Manager." +} + +@DeveloperApi +private[spark] case class FetchFailure(host: String) extends NodeBlacklistReason { + override def message: String = s"Fetch failure for host $host" +} + diff --git a/core/src/main/scala/org/apache/spark/scheduler/NodeUnblacklistReason.scala b/core/src/main/scala/org/apache/spark/scheduler/NodeUnblacklistReason.scala new file mode 100644 index 000000000000..b388ff997fa3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/NodeUnblacklistReason.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import org.apache.spark.annotation.DeveloperApi + +/** + * Represents an explanation for a Node being unblacklisted for task scheduling. + */ +@DeveloperApi +private[spark] sealed trait NodeUnblacklistReason extends Serializable { + def message: String +} + +@DeveloperApi +private[spark] object BlacklistTimedOut extends NodeUnblacklistReason { + override def message: String = "Blacklist timeout has been reached." +} + +@DeveloperApi +private[spark] object NodeRunning extends NodeUnblacklistReason { + override def message: String = "Node is active and back in the Running state." +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 59f89a82a1da..5fe9cd76de38 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -125,11 +125,11 @@ case class SparkListenerExecutorUnblacklisted(time: Long, executorId: String) case class SparkListenerNodeBlacklisted( time: Long, hostId: String, - executorFailures: Int) + reason: NodeBlacklistReason) extends SparkListenerEvent @DeveloperApi -case class SparkListenerNodeUnblacklisted(time: Long, hostId: String) +case class SparkListenerNodeUnblacklisted(time: Long, hostId: String, reason: NodeUnblacklistReason) extends SparkListenerEvent @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 737b38363114..8c23f7cd7769 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -656,6 +656,14 @@ private[spark] class TaskSchedulerImpl( blacklistTrackerOpt.map(_.nodeBlacklist()).getOrElse(scala.collection.immutable.Set()) } + def blacklistExecutorsOnHost(host: String, reason: NodeBlacklistReason): Unit = synchronized { + blacklistTrackerOpt.foreach(_.addNodeToBlacklist(host, reason)) + } + + def unblacklistExecutorsOnHost(host: String, reason: NodeUnblacklistReason): Unit = synchronized { + blacklistTrackerOpt.foreach(_.removeNodesFromBlacklist(List((host, reason, None)))) + } + // By default, rack is unknown def getRackForHost(value: String): Option[String] = None diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index c2f817858473..0e75032b3de5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -96,8 +96,10 @@ private[spark] class TaskSetManager( private var calculatedTasks = 0 private[scheduler] val taskSetBlacklistHelperOpt: Option[TaskSetBlacklist] = { - blacklistTracker.map { _ => - new TaskSetBlacklist(conf, stageId, clock) + if (blacklistTracker.isDefined && BlacklistTracker.isTaskExecutionBlacklistingEnabled(conf)) { + Some(new TaskSetBlacklist(conf, stageId, clock)) + } else { + None } } @@ -519,7 +521,7 @@ private[spark] class TaskSetManager( private def maybeFinishTaskSet() { if (isZombie && runningTasks == 0) { sched.taskSetFinished(this) - if (tasksSuccessful == numTasks) { + if (taskSetBlacklistHelperOpt.isDefined && tasksSuccessful == numTasks) { blacklistTracker.foreach(_.updateBlacklistForSuccessfulTaskSet( taskSet.stageId, taskSet.stageAttemptId, diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 5d65731dfc30..79fabd380288 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer +import org.apache.spark.HostState.HostState import org.apache.spark.TaskState.TaskState import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.ExecutorLossReason @@ -110,6 +111,9 @@ private[spark] object CoarseGrainedClusterMessages { nodeBlacklist: Set[String]) extends CoarseGrainedClusterMessage + case class HostStatusUpdate(host: String, hostState: HostState) + extends CoarseGrainedClusterMessage + // Check if an executor was force-killed but for a reason unrelated to the running tasks. // This could be the case if the executor is preempted, for instance. case class GetExecutorLossReason(executorId: String) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index a0ef20977930..d8857a19f9c8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -26,7 +26,7 @@ import scala.concurrent.Future import org.apache.hadoop.security.UserGroupInformation -import org.apache.spark.{ExecutorAllocationClient, SparkEnv, SparkException, TaskState} +import org.apache.spark.{ExecutorAllocationClient, HostState, SparkEnv, SparkException, TaskState} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.Logging @@ -482,6 +482,22 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp }(ThreadUtils.sameThread) } + private[scheduler] def handleUpdatedHostState(host: String, + hostState: HostState.HostState): Unit = { + hostState match { + case HostState.Decommissioning => + scheduler.blacklistExecutorsOnHost(host, NodeDecommissioning) + + case HostState.Running => + scheduler.unblacklistExecutorsOnHost(host, NodeRunning) + + case HostState.Decommissioned | HostState.Lost => + // TODO: Take action when a node is Decommissioned or Lost + + case _ => + } + } + def sufficientResourcesRegistered(): Boolean = true override def isReady(): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 8406826a228d..9925ad3c18ab 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -100,6 +100,10 @@ private[spark] object JsonProtocol { executorMetricsUpdateToJson(metricsUpdate) case blockUpdated: SparkListenerBlockUpdated => throw new MatchError(blockUpdated) // TODO(ekl) implement this + case nodeBlacklisted: SparkListenerNodeBlacklisted => + nodeBlacklistedToJson(nodeBlacklisted) + case nodeUnblacklisted: SparkListenerNodeUnblacklisted => + nodeUnblacklistedToJson(nodeUnblacklisted) case _ => parse(mapper.writeValueAsString(event)) } } @@ -246,6 +250,20 @@ private[spark] object 
JsonProtocol { }) } + def nodeBlacklistedToJson(nodeBlacklisted: SparkListenerNodeBlacklisted): JValue = { + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.nodeBlacklisted) ~ + ("hostId" -> nodeBlacklisted.hostId) ~ + ("time" -> nodeBlacklisted.time) ~ + ("blacklistReason" -> nodeBlacklistReasonToJson(nodeBlacklisted.reason)) + } + + def nodeUnblacklistedToJson(nodeUnblacklisted: SparkListenerNodeUnblacklisted): JValue = { + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.nodeUnblacklisted) ~ + ("hostId" -> nodeUnblacklisted.hostId) ~ + ("time" -> nodeUnblacklisted.time) ~ + ("unblacklistReason" -> nodeUnblacklistReasonToJson(nodeUnblacklisted.reason)) + } + /** ------------------------------------------------------------------- * * JSON serialization methods for classes SparkListenerEvents depend on | * -------------------------------------------------------------------- */ @@ -407,6 +425,24 @@ private[spark] object JsonProtocol { ("Reason" -> reason) ~ json } + def nodeBlacklistReasonToJson(nodeBlacklistReason: NodeBlacklistReason): JValue = { + val reason = Utils.getFormattedClassName(nodeBlacklistReason) + val json: JObject = nodeBlacklistReason match { + case ExecutorFailures(blacklistedExecutors) => + ("blacklistedExecutors" -> blacklistedExecutors) + case NodeDecommissioning => + Utils.emptyJson + case FetchFailure(host) => + ("host" -> host) + } + ("reason" -> reason) ~ json + } + + def nodeUnblacklistReasonToJson(nodeUnblacklistReason: NodeUnblacklistReason): JValue = { + val reason = Utils.getFormattedClassName(nodeUnblacklistReason) + "reason" -> reason + } + def blockManagerIdToJson(blockManagerId: BlockManagerId): JValue = { ("Executor ID" -> blockManagerId.executorId) ~ ("Host" -> blockManagerId.host) ~ @@ -515,6 +551,8 @@ private[spark] object JsonProtocol { val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved) val logStart = Utils.getFormattedClassName(SparkListenerLogStart) val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate) + val nodeBlacklisted = Utils.getFormattedClassName(SparkListenerNodeBlacklisted) + val nodeUnblacklisted = Utils.getFormattedClassName(SparkListenerNodeUnblacklisted) } def sparkEventFromJson(json: JValue): SparkListenerEvent = { @@ -538,6 +576,8 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) + case `nodeBlacklisted` => nodeBlacklistedFromJson(json) + case `nodeUnblacklisted` => nodeUnBlacklistedFromJson(json) case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) .asInstanceOf[SparkListenerEvent] } @@ -676,6 +716,20 @@ private[spark] object JsonProtocol { SparkListenerExecutorMetricsUpdate(execInfo, accumUpdates) } + def nodeBlacklistedFromJson(json: JValue): SparkListenerNodeBlacklisted = { + val host = (json \ "hostId").extract[String] + val time = (json \ "time").extract[Long] + val reason = nodeBlacklistReasonFromJson(json \ "blacklistReason") + SparkListenerNodeBlacklisted(time, host, reason) + } + + def nodeUnBlacklistedFromJson(json: JValue): SparkListenerNodeUnblacklisted = { + val host = (json \ "hostId").extract[String] + val time = (json \ "time").extract[Long] + val reason = nodeUnblacklistReasonFromJson(json \ "unblacklistReason") + SparkListenerNodeUnblacklisted(time, host, reason) + } + /** --------------------------------------------------------------------- * * JSON deserialization 
methods for classes SparkListenerEvents depend on | * ---------------------------------------------------------------------- */ @@ -917,6 +971,42 @@ private[spark] object JsonProtocol { } } + private object NODE_BLACKLIST_REASON_FORMATTED_CLASS_NAMES { + val executorFailures = Utils.getFormattedClassName(ExecutorFailures) + val nodeDecommissioning = Utils.getFormattedClassName(NodeDecommissioning) + val fetchFailure = Utils.getFormattedClassName(FetchFailure) + } + + def nodeBlacklistReasonFromJson(json: JValue): NodeBlacklistReason = { + import NODE_BLACKLIST_REASON_FORMATTED_CLASS_NAMES._ + + (json \ "reason").extract[String] match { + case `executorFailures` => + val blacklistedExecutors = (json \ "blacklistedExecutors").extract[List[String]] + new ExecutorFailures(Set(blacklistedExecutors: _*)) + + case `nodeDecommissioning` => NodeDecommissioning + + case `fetchFailure` => + val host = (json \ "host").extract[String] + new FetchFailure(host) + } + } + + private object NODE_UNBLACKLIST_REASON_FORMATTED_CLASS_NAMES { + val blacklistTimedOut = Utils.getFormattedClassName(BlacklistTimedOut) + val nodeRunning = Utils.getFormattedClassName(NodeRunning) + } + + def nodeUnblacklistReasonFromJson(json: JValue): NodeUnblacklistReason = { + import NODE_UNBLACKLIST_REASON_FORMATTED_CLASS_NAMES._ + + (json \ "reason").extract[String] match { + case `blacklistTimedOut` => BlacklistTimedOut + case `nodeRunning` => NodeRunning + } + } + def blockManagerIdFromJson(json: JValue): BlockManagerId = { // On metadata fetch fail, block manager ID can be null (SPARK-4471) if (json == JNothing) { diff --git a/core/src/test/resources/spark-events/app-20161115172038-0000 b/core/src/test/resources/spark-events/app-20161115172038-0000 index 3af0451d0c39..f0e919b364b6 100755 --- a/core/src/test/resources/spark-events/app-20161115172038-0000 +++ b/core/src/test/resources/spark-events/app-20161115172038-0000 @@ -68,8 +68,8 @@ {"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1479252044931,"Job Result":{"Result":"JobSucceeded"}} {"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklisted","time":1479252044930,"executorId":"2","taskFailures":4} {"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklisted","time":1479252044930,"executorId":"0","taskFailures":4} -{"Event":"org.apache.spark.scheduler.SparkListenerNodeBlacklisted","time":1479252044930,"hostId":"172.22.0.111","executorFailures":2} +{"Event":"SparkListenerNodeBlacklisted","time":1479252044930,"hostId":"172.22.0.111","blacklistReason":{"reason":"ExecutorFailures","blacklistedExecutors":["exec1","exec2","exec3"]}} {"Event":"org.apache.spark.scheduler.SparkListenerExecutorUnblacklisted","time":1479252055635,"executorId":"2"} {"Event":"org.apache.spark.scheduler.SparkListenerExecutorUnblacklisted","time":1479252055635,"executorId":"0"} -{"Event":"org.apache.spark.scheduler.SparkListenerNodeUnblacklisted","time":1479252055635,"hostId":"172.22.0.111"} +{"Event":"SparkListenerNodeUnblacklisted","time":1479252055635,"hostId":"172.22.0.111","unblacklistReason":{"reason":"BlacklistTimedOut"}} {"Event":"SparkListenerApplicationEnd","Timestamp":1479252138874} diff --git a/core/src/test/resources/spark-events/app-20161116163331-0000 b/core/src/test/resources/spark-events/app-20161116163331-0000 index 57cfc5b97312..44b6554c6ce6 100755 --- a/core/src/test/resources/spark-events/app-20161116163331-0000 +++ b/core/src/test/resources/spark-events/app-20161116163331-0000 @@ -64,5 +64,5 @@ {"Event":"SparkListenerJobEnd","Job ID":0,"Completion 
Time":1479335617480,"Job Result":{"Result":"JobSucceeded"}} {"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklisted","time":1479335617478,"executorId":"2","taskFailures":4} {"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklisted","time":1479335617478,"executorId":"0","taskFailures":4} -{"Event":"org.apache.spark.scheduler.SparkListenerNodeBlacklisted","time":1479335617478,"hostId":"172.22.0.167","executorFailures":2} +{"Event":"SparkListenerNodeBlacklisted","time":1479335617478,"hostId":"172.22.0.167","blacklistReason":{"reason":"ExecutorFailures","blacklistedExecutors":["exec1","exec2","exec3"]}} {"Event":"SparkListenerApplicationEnd","Timestamp":1479335620587} diff --git a/core/src/test/scala/org/apache/spark/HostStateSuite.scala b/core/src/test/scala/org/apache/spark/HostStateSuite.scala new file mode 100644 index 000000000000..95862ce127e7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/HostStateSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import org.scalatest.Matchers +import org.scalatest.prop.TableDrivenPropertyChecks.{forAll, Table} + +import org.apache.spark.HostState.HostState + +class HostStateSuite extends SparkFunSuite with Matchers { + + test("Contract for the conversion between YARN NodeState and HostState") { + val mappings = + Table( + ("yarnNodeState", "hostState"), + (HostState.toYarnState(HostState.New), HostState.New), + (HostState.toYarnState(HostState.Running), HostState.Running), + (HostState.toYarnState(HostState.Decommissioned), HostState.Decommissioned), + (HostState.toYarnState(HostState.Decommissioning), HostState.Decommissioning), + (HostState.toYarnState(HostState.Unhealthy), HostState.Unhealthy), + (HostState.toYarnState(HostState.Rebooted), HostState.Rebooted)) + + forAll (mappings) { (yarnNodeState: Option[String], hostState: HostState) => + assert(yarnNodeState.isDefined) + val hostStateOpt = HostState.fromYarnState(yarnNodeState.get) + assert(hostStateOpt.isDefined) + hostStateOpt.get should be (hostState) + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala index f6015cd51c2b..07136606f545 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala @@ -44,7 +44,8 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM // according to locality preferences, and so the job fails testScheduler("If preferred node is bad, without blacklist job will fail", extraConfs = Seq( - config.BLACKLIST_ENABLED.key -> "false" + config.BLACKLIST_ENABLED.key -> "false", + config.BLACKLIST_DECOMMISSIONING_ENABLED.key -> "true" )) { val rdd = new MockRDDWithLocalityPrefs(sc, 10, Nil, badHost) withBackend(badHostBackend _) { @@ -58,6 +59,7 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM "With default settings, job can succeed despite multiple bad executors on node", extraConfs = Seq( config.BLACKLIST_ENABLED.key -> "true", + config.BLACKLIST_DECOMMISSIONING_ENABLED.key -> "false", config.MAX_TASK_FAILURES.key -> "4", "spark.testing.nHosts" -> "2", "spark.testing.nExecutorsPerHost" -> "5", @@ -84,6 +86,7 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM "Bad node with multiple executors, job will still succeed with the right confs", extraConfs = Seq( config.BLACKLIST_ENABLED.key -> "true", + config.BLACKLIST_DECOMMISSIONING_ENABLED.key -> "false", // just to avoid this test taking too long "spark.locality.wait" -> "10ms" ) @@ -103,6 +106,7 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM "SPARK-15865 Progress with fewer executors than maxTaskFailures", extraConfs = Seq( config.BLACKLIST_ENABLED.key -> "true", + config.BLACKLIST_DECOMMISSIONING_ENABLED.key -> "false", "spark.testing.nHosts" -> "2", "spark.testing.nExecutorsPerHost" -> "1", "spark.testing.nCoresPerExecutor" -> "1" diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index 520d85a29892..00139f5e1d57 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -41,6 +41,7 @@ class BlacklistTrackerSuite extends SparkFunSuite 
with BeforeAndAfterEach with M override def beforeEach(): Unit = { conf = new SparkConf().setAppName("test").setMaster("local") .set(config.BLACKLIST_ENABLED.key, "true") + .set(config.BLACKLIST_DECOMMISSIONING_ENABLED.key, "false") scheduler = mockTaskSchedWithConf(conf) clock.setTime(0) @@ -188,7 +189,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures) assert(blacklist.nodeBlacklist() === Set("hostA")) assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set("hostA")) - verify(listenerBusMock).post(SparkListenerNodeBlacklisted(0, "hostA", 2)) + verify(listenerBusMock).post(SparkListenerNodeBlacklisted(0, "hostA", + ExecutorFailures(Set("1", "2")))) assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2")) verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "2", 4)) @@ -202,7 +204,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) verify(listenerBusMock).post(SparkListenerExecutorUnblacklisted(timeout, "2")) verify(listenerBusMock).post(SparkListenerExecutorUnblacklisted(timeout, "1")) - verify(listenerBusMock).post(SparkListenerNodeUnblacklisted(timeout, "hostA")) + verify(listenerBusMock).post(SparkListenerNodeUnblacklisted(timeout, "hostA", + BlacklistTimedOut)) // Fail one more task, but executor isn't put back into blacklist since the count of failures // on that executor should have been reset to 0. @@ -248,7 +251,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M assert(blacklist.isExecutorBlacklisted("2")) verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(t1, "2", 4)) assert(blacklist.isNodeBlacklisted("hostA")) - verify(listenerBusMock).post(SparkListenerNodeBlacklisted(t1, "hostA", 2)) + verify(listenerBusMock).post(SparkListenerNodeBlacklisted(t1, "hostA", + ExecutorFailures(Set("1", "2")))) // Advance the clock so that executor 1 should no longer be explicitly blacklisted, but // everything else should still be blacklisted. @@ -266,7 +270,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M clock.advance(t1) blacklist.applyBlacklistTimeout() assert(!blacklist.nodeIdToBlacklistExpiryTime.contains("hostA")) - verify(listenerBusMock).post(SparkListenerNodeUnblacklisted(t1 + t2 + t1, "hostA")) + verify(listenerBusMock).post(SparkListenerNodeUnblacklisted(t1 + t2 + t1, "hostA", + BlacklistTimedOut)) // Even though unblacklisting a node implicitly unblacklists all of its executors, // there will be no SparkListenerExecutorUnblacklisted sent here. 
} @@ -401,14 +406,15 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2", "3")) verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "3", 2)) assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set("hostA")) - verify(listenerBusMock).post(SparkListenerNodeBlacklisted(0, "hostA", 2)) + verify(listenerBusMock).post(SparkListenerNodeBlacklisted(0, "hostA", + ExecutorFailures(Set("1", "3")))) } test("blacklist still respects legacy configs") { val conf = new SparkConf().setMaster("local") - assert(!BlacklistTracker.isBlacklistEnabled(conf)) + assert(!BlacklistTracker.isTaskExecutionBlacklistingEnabled(conf)) conf.set(config.BLACKLIST_LEGACY_TIMEOUT_CONF, 5000L) - assert(BlacklistTracker.isBlacklistEnabled(conf)) + assert(BlacklistTracker.isTaskExecutionBlacklistingEnabled(conf)) assert(5000 === BlacklistTracker.getBlacklistTimeout(conf)) // the new conf takes precedence, though conf.set(config.BLACKLIST_TIMEOUT_CONF, 1000L) @@ -416,10 +422,10 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // if you explicitly set the legacy conf to 0, that also would disable blacklisting conf.set(config.BLACKLIST_LEGACY_TIMEOUT_CONF, 0L) - assert(!BlacklistTracker.isBlacklistEnabled(conf)) + assert(!BlacklistTracker.isTaskExecutionBlacklistingEnabled(conf)) // but again, the new conf takes precedence conf.set(config.BLACKLIST_ENABLED, true) - assert(BlacklistTracker.isBlacklistEnabled(conf)) + assert(BlacklistTracker.isTaskExecutionBlacklistingEnabled(conf)) assert(1000 === BlacklistTracker.getBlacklistTimeout(conf)) } @@ -439,7 +445,9 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M s"( = ${maxTaskFailures} ). Though blacklisting is enabled, with this configuration, " + s"Spark will not be robust to one bad node. 
Decrease " + s"${config.MAX_TASK_ATTEMPTS_PER_NODE.key}, increase ${config.MAX_TASK_FAILURES.key}, " + - s"or disable blacklisting with ${config.BLACKLIST_ENABLED.key}") + s"or disable blacklisting with ${config.BLACKLIST_ENABLED.key}, " + + s"${config.BLACKLIST_DECOMMISSIONING_ENABLED.key} " + + s"and ${config.BLACKLIST_FETCH_FAILURE_ENABLED.key}") } conf.remove(config.MAX_TASK_FAILURES) @@ -452,7 +460,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M config.MAX_FAILED_EXEC_PER_NODE_STAGE, config.MAX_FAILURES_PER_EXEC, config.MAX_FAILED_EXEC_PER_NODE, - config.BLACKLIST_TIMEOUT_CONF + config.BLACKLIST_TIMEOUT_CONF, + config.BLACKLIST_DECOMMISSIONING_TIMEOUT_CONF ).foreach { config => conf.set(config.key, "0") val excMsg = intercept[IllegalArgumentException] { @@ -585,4 +594,72 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M 2000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) } + + test("node is blacklisted with NodeDecommissioning reason and gets recovered with time") { + conf.set(config.BLACKLIST_DECOMMISSIONING_ENABLED.key, "true") + blacklist = new BlacklistTracker(listenerBusMock, conf, None, clock) + blacklist.addNodeToBlacklist("hostA", NodeDecommissioning) + assert(blacklist.nodeBlacklist() === Set("hostA")) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set("hostA")) + verify(listenerBusMock).post(SparkListenerNodeBlacklisted(0, "hostA", NodeDecommissioning)) + + val timeout = blacklist.BLACKLIST_DECOMMISSIONING_TIMEOUT_MILLIS + 1 + clock.advance(timeout) + blacklist.applyBlacklistTimeout() + assert(blacklist.nodeBlacklist() === Set()) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) + verify(listenerBusMock).post(SparkListenerNodeUnblacklisted(timeout, "hostA", + BlacklistTimedOut)) + } + + test("node is unblacklisted with NodeRunning reason") { + conf.set(config.BLACKLIST_DECOMMISSIONING_ENABLED.key, "true") + blacklist = new BlacklistTracker(listenerBusMock, conf, None, clock) + val now = clock.getTimeMillis() + blacklist.addNodeToBlacklist("hostA", NodeDecommissioning) + blacklist.addNodeToBlacklist("hostB", ExecutorFailures(Set())) + assert(blacklist.nodeBlacklist() === Set("hostA", "hostB")) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set("hostA", "hostB")) + + blacklist.removeNodesFromBlacklist(List(("hostA", NodeRunning, Some(now)))) + assert(blacklist.nodeBlacklist() === Set("hostB")) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set("hostB")) + verify(listenerBusMock).post(SparkListenerNodeUnblacklisted(now, "hostA", NodeRunning)) + } + + (for { + taskExecutionBlacklistingEnabled <- Seq(true, false) + decommissioningBlacklistingEnabled <- Seq(true, false) + } yield (taskExecutionBlacklistingEnabled, decommissioningBlacklistingEnabled)).foreach { + case (taskExecutionBlacklistingEnabled, decommissioningBlacklistingEnabled) => + val blacklistStatusMsgDict = Map(true -> "enforced", false -> "ignored") + + test(s"task execution blacklisting is " + + s"${blacklistStatusMsgDict(taskExecutionBlacklistingEnabled)} due to " + + s"${config.BLACKLIST_ENABLED.key}=$taskExecutionBlacklistingEnabled, while " + + s"decommissioning blacklisting is " + + s"${blacklistStatusMsgDict(decommissioningBlacklistingEnabled)} " + + s"due to " + + s"${config.BLACKLIST_DECOMMISSIONING_ENABLED.key}=$decommissioningBlacklistingEnabled") { + conf = conf.set(config.BLACKLIST_ENABLED.key, 
taskExecutionBlacklistingEnabled.toString). + set(config.BLACKLIST_DECOMMISSIONING_ENABLED.key, + decommissioningBlacklistingEnabled.toString) + blacklist = new BlacklistTracker(listenerBusMock, conf, None, clock) + + val (failingHost, decommissioningHost) = ("hostFailing", "hostDecommissioning") + blacklist.addNodeToBlacklist(failingHost, ExecutorFailures(Set())) + blacklist.addNodeToBlacklist(decommissioningHost, NodeDecommissioning) + val blacklistedHosts = (if (taskExecutionBlacklistingEnabled) { + Set(failingHost) + } else { + Set[String]() + }) ++ (if (decommissioningBlacklistingEnabled) { + Set(decommissioningHost) + } else { + Set[String]() + }) + assert(blacklist.nodeBlacklist() === blacklistedHosts) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), blacklistedHosts) + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index 04cccc67e328..effc788d454b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -17,13 +17,46 @@ package org.apache.spark.scheduler -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} +import org.mockito.Mockito.{verify, when} +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.{HostState, LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.{RpcUtils, SerializableBuffer} -class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext { +class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with MockitoSugar + with LocalSparkContext { + + private var conf: SparkConf = _ + private var scheduler: TaskSchedulerImpl = _ + private var schedulerBackend: CoarseGrainedSchedulerBackend = _ + + override def beforeEach(): Unit = { + super.beforeEach() + conf = new SparkConf + } + + override def afterEach(): Unit = { + super.afterEach() + if (scheduler != null) { + scheduler.stop() + scheduler = null + } + if (schedulerBackend != null) { + schedulerBackend.stop() + schedulerBackend = null + } + } + + private def setupSchedulerBackend(): Unit = { + sc = new SparkContext("local", "test", conf) + scheduler = mock[TaskSchedulerImpl] + when(scheduler.sc).thenReturn(sc) + schedulerBackend = new CoarseGrainedSchedulerBackend(scheduler, mock[RpcEnv]) + } test("serialized task larger than max RPC message size") { - val conf = new SparkConf conf.set("spark.rpc.message.maxSize", "1") conf.set("spark.default.parallelism", "1") sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf) @@ -38,4 +71,13 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo assert(smaller.size === 4) } + test("handle updated node status received") { + setupSchedulerBackend() + schedulerBackend.handleUpdatedHostState("host1", HostState.Decommissioning) + verify(scheduler).blacklistExecutorsOnHost("host1", NodeDecommissioning) + + schedulerBackend.handleUpdatedHostState("host1", HostState.Running) + verify(scheduler).unblacklistExecutorsOnHost("host1", NodeRunning) + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index ab67a393e2ac..8c631d1d466b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -85,6 +85,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B blacklist = mock[BlacklistTracker] val conf = new SparkConf().setMaster("local").setAppName("TaskSchedulerImplSuite") conf.set(config.BLACKLIST_ENABLED, true) + .set(config.BLACKLIST_DECOMMISSIONING_ENABLED, false) sc = new SparkContext(conf) taskScheduler = new TaskSchedulerImpl(sc, sc.conf.getInt("spark.task.maxFailures", 4)) { @@ -621,7 +622,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // schedulable on another executor. However, that executor may fail later on, leaving the // first task with no place to run. val taskScheduler = setupScheduler( - config.BLACKLIST_ENABLED.key -> "true" + config.BLACKLIST_ENABLED.key -> "true", + config.BLACKLIST_DECOMMISSIONING_ENABLED.key -> "false" ) val taskSet = FakeTask.createTaskSet(2) @@ -672,7 +674,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // available and not bail on the job val taskScheduler = setupScheduler( - config.BLACKLIST_ENABLED.key -> "true" + config.BLACKLIST_ENABLED.key -> "true", + config.BLACKLIST_DECOMMISSIONING_ENABLED.key -> "false" ) val taskSet = FakeTask.createTaskSet(2, (0 until 2).map { _ => Seq(TaskLocation("host0")) }: _*) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala index f1392e9db6bf..44a7d6092925 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala @@ -25,6 +25,7 @@ class TaskSetBlacklistSuite extends SparkFunSuite { test("Blacklisting tasks, executors, and nodes") { val conf = new SparkConf().setAppName("test").setMaster("local") .set(config.BLACKLIST_ENABLED.key, "true") + .set(config.BLACKLIST_DECOMMISSIONING_ENABLED.key, "false") val clock = new ManualClock val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, clock = clock) @@ -146,6 +147,7 @@ class TaskSetBlacklistSuite extends SparkFunSuite { // lead to any node blacklisting val conf = new SparkConf().setAppName("test").setMaster("local") .set(config.BLACKLIST_ENABLED.key, "true") + .set(config.BLACKLIST_DECOMMISSIONING_ENABLED.key, "false") val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 6f1663b21096..425db1159892 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -417,102 +417,108 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg } } - test("executors should be blacklisted after task failure, in spite of locality preferences") { - val rescheduleDelay = 300L - val conf = new SparkConf(). - set(config.BLACKLIST_ENABLED, true). 
- set(config.BLACKLIST_TIMEOUT_CONF, rescheduleDelay). - // don't wait to jump locality levels in this test - set("spark.locality.wait", "0") - - sc = new SparkContext("local", "test", conf) - // two executors on same host, one on different. - sched = new FakeTaskScheduler(sc, ("exec1", "host1"), - ("exec1.1", "host1"), ("exec2", "host2")) - // affinity to exec1 on host1 - which we will fail. - val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "exec1"))) - val clock = new ManualClock - clock.advance(1) - // We don't directly use the application blacklist, but its presence triggers blacklisting - // within the taskset. - val mockListenerBus = mock(classOf[LiveListenerBus]) - val blacklistTrackerOpt = Some(new BlacklistTracker(mockListenerBus, conf, None, clock)) - val manager = new TaskSetManager(sched, taskSet, 4, blacklistTrackerOpt, clock) - - { - val offerResult = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) - assert(offerResult.isDefined, "Expect resource offer to return a task") + List(true, false).foreach { decommissioningBlacklistingEnabled => + val blacklistStatusMsgDict = Map(true -> "enabled", false -> "disabled") + test("executors should be blacklisted after task failure, in spite of locality preferences " + + s"and decommissioning blacklisting " + + s"being ${blacklistStatusMsgDict(decommissioningBlacklistingEnabled)}") { + val rescheduleDelay = 300L + val conf = new SparkConf(). + set(config.BLACKLIST_ENABLED, true). + set(config.BLACKLIST_DECOMMISSIONING_ENABLED, decommissioningBlacklistingEnabled). + set(config.BLACKLIST_TIMEOUT_CONF, rescheduleDelay). + // don't wait to jump locality levels in this test + set("spark.locality.wait", "0") + + sc = new SparkContext("local", "test", conf) + // two executors on same host, one on different. + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), + ("exec1.1", "host1"), ("exec2", "host2")) + // affinity to exec1 on host1 - which we will fail. + val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "exec1"))) + val clock = new ManualClock + clock.advance(1) + // We don't directly use the application blacklist, but its presence triggers blacklisting + // within the taskset. 
+ val mockListenerBus = mock(classOf[LiveListenerBus]) + val blacklistTrackerOpt = Some(new BlacklistTracker(mockListenerBus, conf, None, clock)) + val manager = new TaskSetManager(sched, taskSet, 4, blacklistTrackerOpt, clock) + + { + val offerResult = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) + assert(offerResult.isDefined, "Expect resource offer to return a task") + + assert(offerResult.get.index === 0) + assert(offerResult.get.executorId === "exec1") + + // Cause exec1 to fail : failure 1 + manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost) + assert(!sched.taskSetsFailed.contains(taskSet.id)) - assert(offerResult.get.index === 0) - assert(offerResult.get.executorId === "exec1") + // Ensure scheduling on exec1 fails after failure 1 due to blacklist + assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).isEmpty) + assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).isEmpty) + assert(manager.resourceOffer("exec1", "host1", RACK_LOCAL).isEmpty) + assert(manager.resourceOffer("exec1", "host1", ANY).isEmpty) + } - // Cause exec1 to fail : failure 1 - manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost) - assert(!sched.taskSetsFailed.contains(taskSet.id)) + // Run the task on exec1.1 - should work, and then fail it on exec1.1 + { + val offerResult = manager.resourceOffer("exec1.1", "host1", NODE_LOCAL) + assert(offerResult.isDefined, + "Expect resource offer to return a task for exec1.1, offerResult = " + offerResult) - // Ensure scheduling on exec1 fails after failure 1 due to blacklist - assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).isEmpty) - assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).isEmpty) - assert(manager.resourceOffer("exec1", "host1", RACK_LOCAL).isEmpty) - assert(manager.resourceOffer("exec1", "host1", ANY).isEmpty) - } + assert(offerResult.get.index === 0) + assert(offerResult.get.executorId === "exec1.1") - // Run the task on exec1.1 - should work, and then fail it on exec1.1 - { - val offerResult = manager.resourceOffer("exec1.1", "host1", NODE_LOCAL) - assert(offerResult.isDefined, - "Expect resource offer to return a task for exec1.1, offerResult = " + offerResult) + // Cause exec1.1 to fail : failure 2 + manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost) + assert(!sched.taskSetsFailed.contains(taskSet.id)) - assert(offerResult.get.index === 0) - assert(offerResult.get.executorId === "exec1.1") + // Ensure scheduling on exec1.1 fails after failure 2 due to blacklist + assert(manager.resourceOffer("exec1.1", "host1", NODE_LOCAL).isEmpty) + } - // Cause exec1.1 to fail : failure 2 - manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost) - assert(!sched.taskSetsFailed.contains(taskSet.id)) + // Run the task on exec2 - should work, and then fail it on exec2 + { + val offerResult = manager.resourceOffer("exec2", "host2", ANY) + assert(offerResult.isDefined, "Expect resource offer to return a task") - // Ensure scheduling on exec1.1 fails after failure 2 due to blacklist - assert(manager.resourceOffer("exec1.1", "host1", NODE_LOCAL).isEmpty) - } + assert(offerResult.get.index === 0) + assert(offerResult.get.executorId === "exec2") - // Run the task on exec2 - should work, and then fail it on exec2 - { - val offerResult = manager.resourceOffer("exec2", "host2", ANY) - assert(offerResult.isDefined, "Expect resource offer to return a task") + // Cause exec2 to fail : failure 3 + 
manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost) + assert(!sched.taskSetsFailed.contains(taskSet.id)) - assert(offerResult.get.index === 0) - assert(offerResult.get.executorId === "exec2") + // Ensure scheduling on exec2 fails after failure 3 due to blacklist + assert(manager.resourceOffer("exec2", "host2", ANY).isEmpty) + } - // Cause exec2 to fail : failure 3 - manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost) - assert(!sched.taskSetsFailed.contains(taskSet.id)) + // Despite advancing beyond the time for expiring executors from within the blacklist, + // we *never* expire from *within* the stage blacklist + clock.advance(rescheduleDelay) - // Ensure scheduling on exec2 fails after failure 3 due to blacklist - assert(manager.resourceOffer("exec2", "host2", ANY).isEmpty) - } + { + val offerResult = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) + assert(offerResult.isEmpty) + } - // Despite advancing beyond the time for expiring executors from within the blacklist, - // we *never* expire from *within* the stage blacklist - clock.advance(rescheduleDelay) + { + val offerResult = manager.resourceOffer("exec3", "host3", ANY) + assert(offerResult.isDefined) + assert(offerResult.get.index === 0) + assert(offerResult.get.executorId === "exec3") - { - val offerResult = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) - assert(offerResult.isEmpty) - } + assert(manager.resourceOffer("exec3", "host3", ANY).isEmpty) - { - val offerResult = manager.resourceOffer("exec3", "host3", ANY) - assert(offerResult.isDefined) - assert(offerResult.get.index === 0) - assert(offerResult.get.executorId === "exec3") - - assert(manager.resourceOffer("exec3", "host3", ANY).isEmpty) + // Cause exec3 to fail : failure 4 + manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost) + } - // Cause exec3 to fail : failure 4 - manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost) + // we have failed the same task 4 times now : task id should now be in taskSetsFailed + assert(sched.taskSetsFailed.contains(taskSet.id)) } - - // we have failed the same task 4 times now : task id should now be in taskSetsFailed - assert(sched.taskSetsFailed.contains(taskSet.id)) } test("new executors get added and lost") { @@ -1100,44 +1106,85 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager3.name === "TaskSet_1.1") } - test("don't update blacklist for shuffle-fetch failures, preemption, denied commits, " + - "or killed tasks") { - // Setup a taskset, and fail some tasks for a fetch failure, preemption, denied commit, - // and killed task. + List(true, false).foreach { decommissioningBlacklistingEnabled => + val blacklistStatusMsgDict = Map(true -> "enabled", false -> "disabled") + test("don't update blacklist for shuffle-fetch failures, preemption, denied commits, " + + "or killed tasks, in spite of decommissioning blacklisting " + + s"being ${blacklistStatusMsgDict(decommissioningBlacklistingEnabled)}") { + // Setup a taskset, and fail some tasks for a fetch failure, preemption, denied commit, + // and killed task. + val conf = new SparkConf(). + set(config.BLACKLIST_ENABLED, true). 
+ set(config.BLACKLIST_DECOMMISSIONING_ENABLED, decommissioningBlacklistingEnabled) + sc = new SparkContext("local", "test", conf) + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = FakeTask.createTaskSet(4) + val tsm = new TaskSetManager(sched, taskSet, 4) + // we need a spy so we can attach our mock blacklist + val tsmSpy = spy(tsm) + val blacklist = mock(classOf[TaskSetBlacklist]) + when(tsmSpy.taskSetBlacklistHelperOpt).thenReturn(Some(blacklist)) + + // make some offers to our taskset, to get tasks we will fail + val taskDescs = Seq( + "exec1" -> "host1", + "exec2" -> "host1" + ).flatMap { case (exec, host) => + // offer each executor twice (simulating 2 cores per executor) + (0 until 2).flatMap{ _ => tsmSpy.resourceOffer(exec, host, TaskLocality.ANY)} + } + assert(taskDescs.size === 4) + + // now fail those tasks + tsmSpy.handleFailedTask(taskDescs(0).taskId, TaskState.FAILED, + FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0, 0, "ignored")) + tsmSpy.handleFailedTask(taskDescs(1).taskId, TaskState.FAILED, + ExecutorLostFailure(taskDescs(1).executorId, exitCausedByApp = false, reason = None)) + tsmSpy.handleFailedTask(taskDescs(2).taskId, TaskState.FAILED, + TaskCommitDenied(0, 2, 0)) + tsmSpy.handleFailedTask(taskDescs(3).taskId, TaskState.KILLED, TaskKilled("test")) + + // Make sure that the blacklist ignored all of the task failures above, since they aren't + // the fault of the executor where the task was running. + verify(blacklist, never()) + .updateBlacklistForFailedTask(anyString(), anyString(), anyInt()) + } + } + + test("don't update blacklist for successful task sets when task execution blacklisting is " + + "disabled, in spite of having decommissioning blacklisting enabled") { val conf = new SparkConf(). - set(config.BLACKLIST_ENABLED, true) + set(config.BLACKLIST_ENABLED, false). 
+ set(config.BLACKLIST_DECOMMISSIONING_ENABLED, true) + sc = new SparkContext("local", "test", conf) sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) - val taskSet = FakeTask.createTaskSet(4) - val tsm = new TaskSetManager(sched, taskSet, 4) - // we need a spy so we can attach our mock blacklist - val tsmSpy = spy(tsm) - val blacklist = mock(classOf[TaskSetBlacklist]) - when(tsmSpy.taskSetBlacklistHelperOpt).thenReturn(Some(blacklist)) + val taskSet = FakeTask.createTaskSet(1) + val clock = new ManualClock + clock.advance(1) + val mockListenerBus = mock(classOf[LiveListenerBus]) + // to simulate BLACKLIST_DECOMMISSIONING_ENABLED=true + val blacklistTrackerOpt = Some(spy(new BlacklistTracker(mockListenerBus, conf, None, clock))) + val tsm = new TaskSetManager(sched, taskSet, 4, blacklistTrackerOpt, clock) - // make some offers to our taskset, to get tasks we will fail + assert(tsm.taskSetBlacklistHelperOpt.isEmpty) + // make some offers to our taskset val taskDescs = Seq( "exec1" -> "host1", "exec2" -> "host1" ).flatMap { case (exec, host) => // offer each executor twice (simulating 2 cores per executor) - (0 until 2).flatMap{ _ => tsmSpy.resourceOffer(exec, host, TaskLocality.ANY)} + (0 until 2).flatMap{ _ => tsm.resourceOffer(exec, host, TaskLocality.ANY)} } - assert(taskDescs.size === 4) - // now fail those tasks - tsmSpy.handleFailedTask(taskDescs(0).taskId, TaskState.FAILED, - FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0, 0, "ignored")) - tsmSpy.handleFailedTask(taskDescs(1).taskId, TaskState.FAILED, - ExecutorLostFailure(taskDescs(1).executorId, exitCausedByApp = false, reason = None)) - tsmSpy.handleFailedTask(taskDescs(2).taskId, TaskState.FAILED, - TaskCommitDenied(0, 2, 0)) - tsmSpy.handleFailedTask(taskDescs(3).taskId, TaskState.KILLED, TaskKilled("test")) - - // Make sure that the blacklist ignored all of the task failures above, since they aren't - // the fault of the executor where the task was running. 
- verify(blacklist, never()) - .updateBlacklistForFailedTask(anyString(), anyString(), anyInt()) + val directTaskResult = new DirectTaskResult[String](null, Seq()) { + override def value(resultSer: SerializerInstance): String = "" + } + tsm.handleSuccessfulTask(taskDescs(0).taskId, directTaskResult) + tsm.abort("test") + + verify(blacklistTrackerOpt.get, never()) + .updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), any()) } test("update application blacklist for shuffle-fetch") { diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index 46aa9c37986c..f9be45426218 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -32,6 +32,7 @@ class KryoSerializerDistributedSuite extends SparkFunSuite with LocalSparkContex .set("spark.kryo.registrator", classOf[AppJarRegistrator].getName) .set(config.MAX_TASK_FAILURES, 1) .set(config.BLACKLIST_ENABLED, false) + .set(config.BLACKLIST_DECOMMISSIONING_ENABLED, false) val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) conf.setJars(List(jar.getPath)) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index a1a858765a7d..908d9325d157 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -21,6 +21,7 @@ import java.util.Properties import scala.collection.JavaConverters._ import scala.collection.Map +import scala.collection.immutable.SortedSet import org.json4s.JsonAST.{JArray, JInt, JString, JValue} import org.json4s.JsonDSL._ @@ -85,9 +86,10 @@ class JsonProtocolSuite extends SparkFunSuite { val executorBlacklisted = SparkListenerExecutorBlacklisted(executorBlacklistedTime, "exec1", 22) val executorUnblacklisted = SparkListenerExecutorUnblacklisted(executorUnblacklistedTime, "exec1") - val nodeBlacklisted = SparkListenerNodeBlacklisted(nodeBlacklistedTime, "node1", 33) + val nodeBlacklisted = SparkListenerNodeBlacklisted(nodeBlacklistedTime, "host1", + ExecutorFailures(SortedSet("exec1", "exec2", "exec3"))) val nodeUnblacklisted = - SparkListenerNodeUnblacklisted(nodeUnblacklistedTime, "node1") + SparkListenerNodeUnblacklisted(nodeUnblacklistedTime, "host1", BlacklistTimedOut) val executorMetricsUpdate = { // Use custom accum ID for determinism val accumUpdates = @@ -169,6 +171,14 @@ class JsonProtocolSuite extends SparkFunSuite { testTaskEndReason(ExecutorLostFailure("100", true, Some("Induced failure"))) testTaskEndReason(UnknownReason) + // NodeBlacklistReason + testNodeBlacklistReason(ExecutorFailures(SortedSet("exec1", "exec2", "exec3"))) + testNodeBlacklistReason(NodeDecommissioning) + + // NodeUnblacklistReason + testNodeUnblacklistReason(BlacklistTimedOut) + testNodeUnblacklistReason(NodeRunning) + // BlockId testBlockId(RDDBlockId(1, 2)) testBlockId(ShuffleBlockId(1, 2, 3)) @@ -494,6 +504,18 @@ private[spark] object JsonProtocolSuite extends Assertions { assertEquals(reason, newReason) } + private def testNodeBlacklistReason(reason: NodeBlacklistReason) { + val newReason = JsonProtocol.nodeBlacklistReasonFromJson( + JsonProtocol.nodeBlacklistReasonToJson(reason)) + assertEquals(reason, newReason) + } + + private def testNodeUnblacklistReason(reason: 
NodeUnblacklistReason) { + val newReason = JsonProtocol.nodeUnblacklistReasonFromJson( + JsonProtocol.nodeUnblacklistReasonToJson(reason)) + assertEquals(reason, newReason) + } + private def testBlockId(blockId: BlockId) { val newBlockId = BlockId(blockId.toString) assert(blockId === newBlockId) @@ -548,6 +570,14 @@ private[spark] object JsonProtocolSuite extends Assertions { assertEquals(e1.executorInfo, e2.executorInfo) case (e1: SparkListenerExecutorRemoved, e2: SparkListenerExecutorRemoved) => assert(e1.executorId === e1.executorId) + case (e1: SparkListenerNodeBlacklisted, e2: SparkListenerNodeBlacklisted) => + assert(e1.hostId === e2.hostId) + assert(e1.time === e2.time) + assertEquals(e1.reason, e2.reason) + case (e1: SparkListenerNodeUnblacklisted, e2: SparkListenerNodeUnblacklisted) => + assert(e1.hostId === e2.hostId) + assert(e1.time === e2.time) + assert(e1.reason === e2.reason) case (e1: SparkListenerExecutorMetricsUpdate, e2: SparkListenerExecutorMetricsUpdate) => assert(e1.execId === e2.execId) assertSeqEquals[(Long, Int, Int, Seq[AccumulableInfo])]( @@ -693,6 +723,25 @@ private[spark] object JsonProtocolSuite extends Assertions { } } + private def assertEquals(reason1: NodeBlacklistReason, reason2: NodeBlacklistReason) { + (reason1, reason2) match { + case (NodeDecommissioning, NodeDecommissioning) => + case (ExecutorFailures(blacklistedExecutors1), ExecutorFailures(blacklistedExecutors2)) => + assert(blacklistedExecutors1 === blacklistedExecutors2) + case (FetchFailure(host1), FetchFailure(host2)) => + assert(host1 === host2) + case _ => fail("Node blacklist reasons don't match in types!") + } + } + + private def assertEquals(reason1: NodeUnblacklistReason, reason2: NodeUnblacklistReason) { + (reason1, reason2) match { + case (NodeRunning, NodeRunning) => + case (BlacklistTimedOut, BlacklistTimedOut) => + case _ => fail("Node unblacklist reasons don't match in types!") + } + } + private def assertEquals( details1: Map[String, Seq[(String, String)]], details2: Map[String, Seq[(String, String)]]) { @@ -2027,18 +2076,24 @@ private[spark] object JsonProtocolSuite extends Assertions { private val nodeBlacklistedJsonString = s""" |{ - | "Event" : "org.apache.spark.scheduler.SparkListenerNodeBlacklisted", + | "Event" : "SparkListenerNodeBlacklisted", + | "hostId" : "host1", | "time" : ${nodeBlacklistedTime}, - | "hostId" : "node1", - | "executorFailures" : 33 + | "blacklistReason" : { + | "reason" : "ExecutorFailures", + | "blacklistedExecutors" : [ "exec1", "exec2", "exec3" ] + | } |} """.stripMargin private val nodeUnblacklistedJsonString = s""" |{ - | "Event" : "org.apache.spark.scheduler.SparkListenerNodeUnblacklisted", + | "Event" : "SparkListenerNodeUnblacklisted", + | "hostId" : "host1", | "time" : ${nodeUnblacklistedTime}, - | "hostId" : "node1" + | "unblacklistReason" : { + | "reason" : "BlacklistTimedOut" + | } |} """.stripMargin } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 7052fb347106..620f08d5cdef 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -32,15 +32,15 @@ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration -import 
org.apache.spark.{SecurityManager, SparkConf, SparkException}
+import org.apache.spark.{HostState, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef}
import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason}
-import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor
-import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RetrieveLastAllocatedExecutorId
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{HostStatusUpdate,
+RemoveExecutor, RetrieveLastAllocatedExecutorId}
import org.apache.spark.util.{Clock, SystemClock, ThreadUtils}
/**
@@ -266,6 +266,23 @@ private[yarn] class YarnAllocator(
// requests.
val allocateResponse = amClient.allocate(progressIndicator)
+ val updatedNodeReports = allocateResponse.getUpdatedNodes
+
+ updatedNodeReports.asScala.foreach(nodeReport => {
+ logInfo("Yarn node state updated for host %s to %s"
+ .format(nodeReport.getNodeId.getHost, nodeReport.getNodeState.name))
+
+ val hostState = HostState.fromYarnState(nodeReport.getNodeState.name)
+ hostState match {
+ case Some(state) =>
+ driverRef.send(HostStatusUpdate(nodeReport.getNodeId.getHost, state))
+
+ case None =>
+ logWarning("Cannot find Host state corresponding to YARN node state %s"
+ .format(nodeReport.getNodeState.name))
+ }
+ })
+
val allocatedContainers = allocateResponse.getAllocatedContainers()
if (allocatedContainers.size > 0) {
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 415a29fd887e..6acdb3d3aad3 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -264,6 +264,10 @@ private[spark] abstract class YarnSchedulerBackend(
case AddWebUIFilter(filterName, filterParams, proxyBase) =>
addWebUIFilter(filterName, filterParams, proxyBase)
+ case HostStatusUpdate(host, hostState) =>
+ logDebug("Received updated state %s for host %s".format(hostState, host))
+ handleUpdatedHostState(host, hostState)
+
case r @ RemoveExecutor(executorId, reason) =>
logWarning(reason.toString)
driverEndpoint.ask[Boolean](r).onFailure {
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
index 9c3b18e4ec5f..972849c9b8a7 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
@@ -28,10 +28,12 @@ import scala.language.postfixOps
import com.google.common.io.Files
import org.apache.commons.lang3.SerializationUtils
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.server.MiniYARNCluster
import org.scalatest.{BeforeAndAfterAll, Matchers}
import org.scalatest.concurrent.Eventually._
+import org.scalatest.time
import org.apache.spark._
import org.apache.spark.deploy.yarn.config._
@@ -78,6 +80,28 @@ abstract
class BaseYarnClusterSuite val logConfFile = new File(logConfDir, "log4j.properties") Files.write(LOG4J_CONF, logConfFile, StandardCharsets.UTF_8) + restartCluster() + + fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) + hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) + assert(hadoopConfDir.mkdir()) + File.createTempFile("token", ".txt", hadoopConfDir) + } + + override def afterAll() { + try { + yarnCluster.stop() + } finally { + System.setProperties(oldSystemProperties) + super.afterAll() + } + } + + protected def restartCluster(): Unit = { + if (yarnCluster != null) { + yarnCluster.stop() + } + // Disable the disk utilization check to avoid the test hanging when people's disks are // getting full. val yarnConf = newYarnConfig() @@ -113,20 +137,6 @@ abstract class BaseYarnClusterSuite } logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") - - fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) - hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) - assert(hadoopConfDir.mkdir()) - File.createTempFile("token", ".txt", hadoopConfDir) - } - - override def afterAll() { - try { - yarnCluster.stop() - } finally { - System.setProperties(oldSystemProperties) - super.afterAll() - } } protected def runSpark( @@ -137,7 +147,9 @@ abstract class BaseYarnClusterSuite extraClassPath: Seq[String] = Nil, extraJars: Seq[String] = Nil, extraConf: Map[String, String] = Map(), - extraEnv: Map[String, String] = Map()): SparkAppHandle.State = { + extraEnv: Map[String, String] = Map(), + numExecutors: Int = 1, + executionTimeout: time.Span = 2 minutes): SparkAppHandle.State = { val deployMode = if (clientMode) "client" else "cluster" val propsFile = createConfFile(extraClassPath = extraClassPath, extraConf = extraConf) val env = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv @@ -152,7 +164,7 @@ abstract class BaseYarnClusterSuite launcher.setSparkHome(sys.props("spark.test.home")) .setMaster("yarn") .setDeployMode(deployMode) - .setConf("spark.executor.instances", "1") + .setConf("spark.executor.instances", numExecutors.toString) .setPropertiesFile(propsFile) .addAppArgs(appArgs.toArray: _*) @@ -167,7 +179,7 @@ abstract class BaseYarnClusterSuite val handle = launcher.startApplication() try { - eventually(timeout(2 minutes), interval(1 second)) { + eventually(timeout(executionTimeout), interval(1 second)) { assert(handle.getState().isFinal()) } } finally { @@ -238,4 +250,30 @@ abstract class BaseYarnClusterSuite propsFile.getAbsolutePath() } + protected def getClusterWorkDir: File = yarnCluster.getTestWorkDir + + /** + * Gracefully decommissions the only node in the mini YARN cluster, if that + * functionality is available in the Hadoop version that it is configured. + * Throws an exception in case the decommissioning functionality is not available. + * @return the host that will be decommissioned. 
+ */ + protected def gracefullyDecommissionNode(conf: Configuration, + excludedHostsFile: File, + decommissionTimeout: time.Span): String = { + val resourceManager = yarnCluster.getResourceManager + val nodesListManager = resourceManager.getRMContext.getNodesListManager + val clusterNodes = Map(resourceManager.getRMContext.getRMNodes.asScala.toSeq : _ *) + assert(!clusterNodes.isEmpty) + // note the MiniYARNCluster will always have a single node + val hostToExclude = clusterNodes.keysIterator.next().getHost + Files.append(s"$hostToExclude ${decommissionTimeout.toSeconds}${sys.props("line.separator")}", + excludedHostsFile, StandardCharsets.UTF_8) + // use reflection so this compiles for other YARN versions, but fails with a + // reflection exception if executed with incompatible versions of YARN + nodesListManager.getClass + .getMethod("refreshNodes", classOf[Configuration], java.lang.Boolean.TYPE) + .invoke(nodesListManager, conf, true.asInstanceOf[java.lang.Object]) + hostToExclude + } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index cb1e3c526851..3e7ae2ae95c5 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -20,19 +20,22 @@ package org.apache.spark.deploy.yarn import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.mockito.{Matchers => MockitoMatchers} import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfterEach, Matchers} -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.{HostState, SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.yarn.YarnAllocator._ import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.HostStatusUpdate import org.apache.spark.util.ManualClock class MockResolver extends SparkRackResolver { @@ -83,7 +86,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter def createAllocator( maxExecutors: Int = 5, - rmClient: AMRMClient[ContainerRequest] = rmClient): YarnAllocator = { + rmClient: AMRMClient[ContainerRequest] = rmClient, + driverRef: RpcEndpointRef = mock(classOf[RpcEndpointRef])): YarnAllocator = { val args = Array( "--jar", "somejar.jar", "--class", "SomeClass") @@ -94,7 +98,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter .set("spark.executor.memory", "2048") new YarnAllocator( "not used", - mock(classOf[RpcEndpointRef]), + driverRef, conf, sparkConfClone, rmClient, @@ -350,4 +354,37 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter clock.advance(50 * 1000L) handler.getNumExecutorsFailed should be (0) } + + test("HostStatusUpdate signal on YARN node state change") { + val mockAmClient = mock(classOf[AMRMClient[ContainerRequest]]) + val 
mockAllocateResponse = mock(classOf[AllocateResponse]) + val mockNodeReport1 = mock(classOf[NodeReport]) + val mockNodeReport2 = mock(classOf[NodeReport]) + val mockNodeId1 = mock(classOf[NodeId]) + val mockNodeId2 = mock(classOf[NodeId]) + + val nodeState1 = HostState.toYarnState(HostState.Decommissioning) + assert(nodeState1.isDefined) + val nodeState2 = HostState.toYarnState(HostState.Running) + assert(nodeState2.isDefined) + + when(mockNodeId1.getHost).thenReturn("host1") + when(mockNodeId2.getHost).thenReturn("host2") + when(mockNodeReport1.getNodeState).thenReturn(NodeState.valueOf(nodeState1.get)) + when(mockNodeReport2.getNodeState).thenReturn(NodeState.valueOf(nodeState2.get)) + when(mockNodeReport1.getNodeId).thenReturn(mockNodeId1) + when(mockNodeReport2.getNodeId).thenReturn(mockNodeId2) + + when(mockAllocateResponse.getUpdatedNodes).thenReturn(List(mockNodeReport1, + mockNodeReport2).asJava) + when(mockAmClient.allocate(MockitoMatchers.anyFloat())).thenReturn(mockAllocateResponse) + + val driverRef = mock(classOf[RpcEndpointRef]) + val handler = createAllocator(4, mockAmClient, driverRef) + + handler.allocateResources() + + verify(driverRef).send(HostStatusUpdate("host1", HostState.Decommissioning)) + verify(driverRef).send(HostStatusUpdate("host2", HostState.Running)) + } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnDecommissioningSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnDecommissioningSuite.scala new file mode 100644 index 000000000000..b714f4a0e0bb --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnDecommissioningSuite.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.File
+import java.nio.charset.StandardCharsets
+import java.nio.file.{Files => JFiles, Paths}
+import java.util.concurrent.{Executors, TimeUnit}
+
+import scala.collection.JavaConverters._
+import scala.concurrent.{ExecutionContext, Future, Promise}
+import scala.concurrent.duration._
+import scala.io.Source
+import scala.language.postfixOps
+import scala.util.Try
+
+import com.google.common.io.Files
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.scalatest.{BeforeAndAfter, Matchers}
+import org.scalatest.exceptions.TestFailedDueToTimeoutException
+
+import org.apache.spark.{HostState, SparkConf, SparkContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.tags.ExtendedYarnTest
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Integration test for YARN's graceful decommission mechanism; these tests use a mini
+ * YARN cluster to run Spark-on-YARN applications, and require the Spark assembly to be built
+ * before they can be successfully run.
+ * Tests trigger the decommission of the only node in the mini YARN cluster, and then check
+ * in the YARN container logs that the YARN node transitions were received at the driver.
+ */
+@ExtendedYarnTest
+class YarnDecommissioningSuite extends BaseYarnClusterSuite with BeforeAndAfter {
+
+ private val (excludedHostsFile, syncFile) = {
+ val (excludedHostsFile, syncFile) = (File.createTempFile("yarn-excludes", null, tempDir),
+ File.createTempFile("syncFile", null, tempDir))
+ excludedHostsFile.deleteOnExit()
+ syncFile.deleteOnExit()
+ logInfo(s"Using YARN excludes file ${excludedHostsFile.getAbsolutePath}")
+ logInfo(s"Using sync file ${syncFile.getAbsolutePath}")
+ (excludedHostsFile, syncFile)
+ }
+ // used to avoid restarting the MiniYARNCluster on the first test run
+ private var firstTestRun = true
+ private val executorService = Executors.newSingleThreadScheduledExecutor()
+ private implicit val ec = ExecutionContext.fromExecutorService(executorService)
+
+ override val newYarnConfig: YarnConfiguration = {
+ val conf = new YarnConfiguration()
+ conf.set("yarn.resourcemanager.nodes.exclude-path", excludedHostsFile.getAbsolutePath)
+ conf
+ }
+
+ private val decommissionStates = Set(HostState.Decommissioning,
+ HostState.Decommissioned).map{ state =>
+ val yarnStateOpt = HostState.toYarnState(state)
+ assert(yarnStateOpt.isDefined,
+ s"Spark host state $state should have a translation to YARN state")
+ yarnStateOpt.get
+ }
+
+ before {
+ if (!firstTestRun) {
+ Files.write("", excludedHostsFile, StandardCharsets.UTF_8)
+ Files.write("", syncFile, StandardCharsets.UTF_8)
+ restartCluster()
+ }
+ firstTestRun = false
+ }
+
+ test("Spark application master gets notified on node decommissioning when running in" + " cluster mode") {
+ val excludedHostStateUpdates = testNodeDecommission(clientMode = false)
+ excludedHostStateUpdates shouldEqual(decommissionStates)
+ }
+
+ test("Spark application master gets notified on node decommissioning when running in" + " client mode") {
+ val excludedHostStateUpdates = testNodeDecommission(clientMode = true)
+ // In client mode the node doesn't always have time to reach the decommissioned state
+ assert(excludedHostStateUpdates.subsetOf(decommissionStates))
+ assert(excludedHostStateUpdates
+ .contains(HostState.toYarnState(HostState.Decommissioning).get))
+ }
+
+ /**
+ * @return a set of strings for the YARN decommission-related states the only node in
+ * the MiniYARNCluster has transitioned to after the
Spark job has started. + */ + private def testNodeDecommission(clientMode: Boolean): Set[String] = { + val excludedHostPromise = Promise[String] + scheduleDecommissionRunnable(excludedHostPromise) + + // surface exceptions in the executor service + val excludedHostFuture = excludedHostPromise.future + excludedHostFuture.onFailure { case t => throw t } + // we expect a timeout exception because the job will fail when the only available node + // is decommissioned after its timeout + intercept[TestFailedDueToTimeoutException] { + runSpark(clientMode, mainClassName(YarnDecommissioningDriver.getClass), + appArgs = Seq(syncFile.getAbsolutePath), + extraConf = Map(), + numExecutors = 2, + executionTimeout = 2 minutes) + } + assert(excludedHostPromise.isCompleted, "graceful decommission was not launched for any node") + val excludedHost = ThreadUtils.awaitResult(excludedHostFuture, 1 millisecond) + assert(excludedHost.length > 0) + getExcludedHostStateUpdate(excludedHost) + } + + /** + * This method repeatedly schedules a task that checks the contents of the syncFile used to + * synchronize with the Spark driver. When the syncFile is updated with the sync text then + * YARN's graceful decommission mechanism is triggered, and the excluded host is returned + * by completing excludedHostPromise. + */ + private def scheduleDecommissionRunnable(excludedHostPromise: Promise[String]): Unit = { + def decommissionRunnable(): Runnable = new Runnable() { + override def run() { + if (syncFile.exists() && + Files.toString(syncFile, StandardCharsets.UTF_8) + .equals(YarnDecommissioningDriver.SYNC_TEXT)) { + excludedHostPromise.complete(Try{ + logInfo("Launching graceful decommission of a node in YARN") + gracefullyDecommissionNode(newYarnConfig, excludedHostsFile, + decommissionTimeout = 10 seconds) + }) + } else { + logDebug("Waiting for sync file to be updated by the driver") + executorService.schedule(decommissionRunnable(), 100, TimeUnit.MILLISECONDS) + } + } + } + executorService.schedule(decommissionRunnable(), 1, TimeUnit.SECONDS) + } + + /** + * This method should be called after the Spark application has completed, to parse + * the container logs for messages about Yarn decommission related states involving + * the node that was decommissioned. + */ + private def getExcludedHostStateUpdate(excludedHost: String): Set[String] = { + val stateChangeRe = { + val decommissionStateRe = decommissionStates.mkString("|") + "(?:%s.*(%s))|(?:(%s).*%s)".format(excludedHost, decommissionStateRe, + decommissionStateRe, excludedHost).r + } + (for { + file <- JFiles.walk(Paths.get(getClusterWorkDir.getAbsolutePath)) + .iterator().asScala + if file.getFileName.toString == "stderr" + line <- Source.fromFile(file.toFile).getLines() + matchingSubgroups <- stateChangeRe.findFirstMatchIn(line) + .map(_.subgroups.filter(_ != null)).toSeq + group <- matchingSubgroups + } yield group).toSet + } +} + +private object YarnDecommissioningDriver extends Logging with Matchers { + + val SYNC_TEXT = "First action completed. Start decommissioning." 
+ val WAIT_TIMEOUT_MILLIS = 10000
+ val DECOMMISSION_WAIT_TIME_MILLIS = 500000
+
+ def main(args: Array[String]): Unit = {
+ if (args.length != 1) {
+ // scalastyle:off println
+ System.err.println(
+ s"""
+ |Invalid command line: ${args.mkString(" ")}
+ |
+ |Usage: YarnDecommissioningDriver [sync file]
+ """.stripMargin)
+ // scalastyle:on println
+ System.exit(1)
+ }
+ val sc = new SparkContext(new SparkConf()
+ .setAppName("Yarn Decommissioning Test"))
+ try {
+ logInfo("Starting YarnDecommissioningDriver")
+ val counts = sc.parallelize(1 to 10, 4)
+ .map { x => (x % 7, x) }
+ .reduceByKey(_ + _).collect()
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ logInfo(s"Got ${counts.mkString(",")}")
+
+ val syncFile = new File(args(0))
+ Files.append(SYNC_TEXT, syncFile, StandardCharsets.UTF_8)
+ logInfo(s"Sync file ${syncFile} written")
+
+ // Wait for decommissioning and then for decommissioned; the timeout in
+ // the corresponding call to runSpark will interrupt this sleep
+ Thread.sleep(DECOMMISSION_WAIT_TIME_MILLIS)
+ } catch {
+ case e: Exception =>
+ logError(s"Driver exception: ${e.getMessage}", e)
+ } finally {
+ sc.stop()
+ }
+ }
+}
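
A minimal usage sketch for reviewers, not part of the patch itself: the behavior exercised by the suites above is switched on through the two new decommissioning configuration keys introduced by this change, next to the existing task-failure blacklisting flag. The key names correspond to the new ConfigBuilder entries in this change; the timeout and application name below are arbitrary examples.

    import org.apache.spark.SparkConf

    // Hypothetical application-side configuration enabling decommissioning blacklisting on YARN.
    val conf = new SparkConf()
      .setAppName("decommissioning-blacklisting-example")       // hypothetical app name
      .set("spark.blacklist.enabled", "true")                    // existing task-failure blacklisting
      .set("spark.blacklist.decommissioning.enabled", "true")    // blacklist nodes YARN reports as decommissioning
      .set("spark.blacklist.decommissioning.timeout", "50min")   // optional; example value only

With such a configuration, a node-state change reported by the ResourceManager is forwarded by YarnAllocator to the driver as a HostStatusUpdate message and handled in YarnSchedulerBackend, which is the path the mocked AllocateResponse in YarnAllocatorSuite verifies.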