diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index 18305ad3746a6..c8c6e5a192a24 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -108,6 +108,13 @@ private[deploy] object DeployMessages {
 
   case class Heartbeat(workerId: String, worker: RpcEndpointRef) extends DeployMessage
 
+  /**
+   * Used by the MasterWebUI to request the master to decommission all workers that are active on
+   * any of the given hostnames.
+   * @param hostnames A list of hostnames without ports, e.g. "localhost" or "foo.bar.com"
+   */
+  case class DecommissionWorkersOnHosts(hostnames: Seq[String])
+
   // Master to Worker
 
   sealed trait RegisterWorkerResponse
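Note: callers are expected to send this message with askSync and read back the number of matched workers, as the web UI handler later in this diff does. A minimal sketch, assuming a masterEndpointRef: RpcEndpointRef pointing at a live (non-standby) master is in scope:

    import org.apache.spark.deploy.DeployMessages.DecommissionWorkersOnHosts

    // The reply is the count of workers whose host matched, not the count already
    // decommissioned; the actual decommissioning happens asynchronously.
    val matched = masterEndpointRef.askSync[Integer](
      DecommissionWorkersOnHosts(Seq("host1.example.com", "host2.example.com")))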
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index d2e65db970380..0070df1d66dee 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -22,7 +22,9 @@ import java.util.{Date, Locale}
 import java.util.concurrent.{ScheduledFuture, TimeUnit}
 
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.mutable
 import scala.util.Random
+import scala.util.control.NonFatal
 
 import org.apache.spark.{SecurityManager, SparkConf, SparkException}
 import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState, SparkHadoopUtil}
@@ -525,6 +527,13 @@
     case KillExecutors(appId, executorIds) =>
       val formattedExecutorIds = formatExecutorIds(executorIds)
       context.reply(handleKillExecutors(appId, formattedExecutorIds))
+
+    case DecommissionWorkersOnHosts(hostnames) =>
+      if (state != RecoveryState.STANDBY) {
+        context.reply(decommissionWorkersOnHosts(hostnames))
+      } else {
+        context.reply(0)
+      }
   }
 
   override def onDisconnected(address: RpcAddress): Unit = {
@@ -863,6 +872,34 @@
     true
   }
 
+  /**
+   * Decommission all workers that are active on any of the given hostnames. The decommissioning
+   * is done asynchronously by enqueueing WorkerDecommission messages to self. No check is made
+   * of the worker's prior state, so an already decommissioned worker will match as well.
+   *
+   * @param hostnames A list of hostnames without ports, e.g. "localhost" or "foo.bar.com"
+   * @return The number of workers that matched the hostnames.
+   */
+  private def decommissionWorkersOnHosts(hostnames: Seq[String]): Integer = {
+    val hostnamesSet = hostnames.map(_.toLowerCase(Locale.ROOT)).toSet
+    val workersToRemove = addressToWorker
+      .filterKeys(addr => hostnamesSet.contains(addr.host.toLowerCase(Locale.ROOT)))
+      .values
+
+    val workersToRemoveHostPorts = workersToRemove.map(_.hostPort)
+    logInfo(s"Decommissioning the workers with host:ports ${workersToRemoveHostPorts}")
+
+    // The workers are decommissioned asynchronously to avoid blocking the receive loop on the
+    // entire batch.
+    workersToRemove.foreach(wi => {
+      logInfo(s"Sending the worker decommission to ${wi.id} and ${wi.endpoint}")
+      self.send(WorkerDecommission(wi.id, wi.endpoint))
+    })
+
+    // Return the count of workers that matched, i.e. were sent a decommission message.
+    workersToRemove.size
+  }
+
   private def decommissionWorker(worker: WorkerInfo): Unit = {
     if (worker.state != WorkerState.DECOMMISSIONED) {
       logInfo("Decommissioning worker %s on %s:%d".format(worker.id, worker.host, worker.port))
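Note: the WorkerDecommission messages enqueued to self above are drained by the Master's existing receive handler, which is what ultimately drives each matched worker through decommissionWorker. Roughly along these lines (a sketch of the pre-existing handler this change relies on, not code added by this diff):

    case WorkerDecommission(id, workerRef) =>
      logInfo("Recording worker %s decommissioning".format(id))
      if (state == RecoveryState.STANDBY) {
        workerRef.send(MasterInStandby)
      } else {
        // idToWorker.get returns an Option, so an unknown worker id is skipped silently.
        idToWorker.get(id).foreach(decommissionWorker)
      }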
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index 86554ec4ec1c9..035f9d379471c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -17,9 +17,14 @@
 
 package org.apache.spark.deploy.master.ui
 
-import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
+import java.net.{InetAddress, NetworkInterface, SocketException}
+import java.util.Locale
+import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
+
+import org.apache.spark.deploy.DeployMessages.{DecommissionWorkersOnHosts, MasterStateResponse, RequestMasterState}
 import org.apache.spark.deploy.master.Master
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.UI.MASTER_UI_DECOMMISSION_ALLOW_MODE
 import org.apache.spark.internal.config.UI.UI_KILL_ENABLED
 import org.apache.spark.ui.{SparkUI, WebUI}
 import org.apache.spark.ui.JettyUtils._
@@ -36,6 +41,7 @@ class MasterWebUI(
 
   val masterEndpointRef = master.self
   val killEnabled = master.conf.get(UI_KILL_ENABLED)
+  val decommissionAllowMode = master.conf.get(MASTER_UI_DECOMMISSION_ALLOW_MODE)
 
   initialize()
 
@@ -49,6 +55,27 @@
       "/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST")))
     attachHandler(createRedirectHandler(
       "/driver/kill", "/", masterPage.handleDriverKillRequest, httpMethods = Set("POST")))
+    attachHandler(createServletHandler("/workers/kill", new HttpServlet {
+      override def doPost(req: HttpServletRequest, resp: HttpServletResponse): Unit = {
+        val hostnames: Seq[String] = Option(req.getParameterValues("host"))
+          .getOrElse(Array[String]()).toSeq
+        if (!isDecommissioningRequestAllowed(req)) {
+          resp.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED)
+        } else {
+          val removedWorkers = masterEndpointRef.askSync[Integer](
+            DecommissionWorkersOnHosts(hostnames))
+          logInfo(s"Decommissioning of hosts $hostnames decommissioned $removedWorkers workers")
+          if (removedWorkers > 0) {
+            resp.setStatus(HttpServletResponse.SC_OK)
+          } else if (removedWorkers == 0) {
+            resp.sendError(HttpServletResponse.SC_NOT_FOUND)
+          } else {
+            // Unreachable: the count of decommissioned workers can never be negative.
+            resp.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
+          }
+        }
+      }
+    }, ""))
   }
 
   def addProxy(): Unit = {
@@ -64,6 +91,25 @@
     maybeWorkerUiAddress.orElse(maybeAppUiAddress)
   }
 
+  private def isLocal(address: InetAddress): Boolean = {
+    if (address.isAnyLocalAddress || address.isLoopbackAddress) {
+      return true
+    }
+    try {
+      NetworkInterface.getByInetAddress(address) != null
+    } catch {
+      case _: SocketException => false
+    }
+  }
+
+  private def isDecommissioningRequestAllowed(req: HttpServletRequest): Boolean = {
+    decommissionAllowMode match {
+      case "ALLOW" => true
+      case "LOCAL" => isLocal(InetAddress.getByName(req.getRemoteAddr))
+      case _ => false
+    }
+  }
+
 }
 
 private[master] object MasterWebUI {
diff --git a/core/src/main/scala/org/apache/spark/internal/config/UI.scala b/core/src/main/scala/org/apache/spark/internal/config/UI.scala
index 231eecf086bbe..fcbe2b9775841 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/UI.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/UI.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.internal.config
 
+import java.util.Locale
 import java.util.concurrent.TimeUnit
 
 import org.apache.spark.network.util.ByteUnit
@@ -191,4 +192,15 @@
     .version("3.0.0")
     .stringConf
     .createOptional
+
+  val MASTER_UI_DECOMMISSION_ALLOW_MODE = ConfigBuilder("spark.master.ui.decommission.allow.mode")
+    .doc("Specifies the behavior of the Master Web UI's /workers/kill endpoint. Possible choices" +
+      " are: `LOCAL` means allow this endpoint only from IPs that are local to the machine" +
+      " running the Master, `DENY` means completely disable this endpoint, `ALLOW` means allow" +
+      " calling this endpoint from any IP.")
+    .internal()
+    .version("3.1.0")
+    .stringConf
+    .transform(_.toUpperCase(Locale.ROOT))
+    .createWithDefault("LOCAL")
 }
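Note: since the config is marked internal() it gets no user-facing documentation, so for reviewers: it is read by the standalone Master process at startup, e.g. from conf/spark-defaults.conf or via SPARK_MASTER_OPTS. Shown below as a SparkConf call purely for illustration; the key and values are the ones defined above:

    import org.apache.spark.SparkConf

    // LOCAL (default): only requests originating from the master machine are honored.
    // DENY: the /workers/kill endpoint is disabled entirely.
    // ALLOW: any client that can reach the web UI may decommission workers.
    val conf = new SparkConf().set("spark.master.ui.decommission.allow.mode", "ALLOW")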
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
index 91128af82b022..d98a6b29be9e8 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
@@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.mutable.{HashMap, HashSet}
 import scala.concurrent.duration._
 import scala.io.Source
 import scala.reflect.ClassTag
@@ -726,6 +726,65 @@
     }
   }
 
+  def testWorkerDecommissioning(
+      numWorkers: Int,
+      numWorkersExpectedToDecom: Int,
+      hostnames: Seq[String]): Unit = {
+    val conf = new SparkConf()
+    val master = makeAliveMaster(conf)
+    val workerRegs = (1 to numWorkers).map { idx =>
+      val worker = new MockWorker(master.self, conf)
+      worker.rpcEnv.setupEndpoint("worker", worker)
+      val workerReg = RegisterWorker(
+        worker.id,
+        "localhost",
+        worker.self.address.port,
+        worker.self,
+        10,
+        1024,
+        "http://localhost:8080",
+        RpcAddress("localhost", 10000))
+      master.self.send(workerReg)
+      workerReg
+    }
+
+    eventually(timeout(10.seconds)) {
+      val masterState = master.self.askSync[MasterStateResponse](RequestMasterState)
+      assert(masterState.workers.length === numWorkers)
+      assert(masterState.workers.forall(_.state == WorkerState.ALIVE))
+      assert(masterState.workers.map(_.id).toSet == workerRegs.map(_.id).toSet)
+    }
+
+    val decomWorkersCount =
+      master.self.askSync[Integer](DecommissionWorkersOnHosts(hostnames))
+    assert(decomWorkersCount === numWorkersExpectedToDecom)
+
+    // Decommissioning is asynchronous, so wait for the workers to actually be decommissioned by
+    // polling the master's state.
+    eventually(timeout(30.seconds)) {
+      val masterState = master.self.askSync[MasterStateResponse](RequestMasterState)
+      assert(masterState.workers.length === numWorkers)
+      val workersActuallyDecomed = masterState.workers.count(_.state == WorkerState.DECOMMISSIONED)
+      assert(workersActuallyDecomed === numWorkersExpectedToDecom)
+    }
+
+    // Decommissioning a worker again should return the same answer, since we want this call to
+    // be idempotent.
+    val decomWorkersCountAgain = master.self.askSync[Integer](DecommissionWorkersOnHosts(hostnames))
+    assert(decomWorkersCountAgain === numWorkersExpectedToDecom)
+  }
+
+  test("All workers on a host should be decommissioned") {
+    testWorkerDecommissioning(2, 2, Seq("LoCalHost", "localHOST"))
+  }
+
+  test("No workers should be decommissioned with invalid host") {
+    testWorkerDecommissioning(2, 0, Seq("NoSuchHost1", "NoSuchHost2"))
+  }
+
+  test("Only worker on host should be decommissioned") {
+    testWorkerDecommissioning(1, 1, Seq("lOcalHost", "NoSuchHost"))
+  }
+
   test("SPARK-19900: there should be a corresponding driver for the app after relaunching driver") {
     val conf = new SparkConf().set(WORKER_TIMEOUT, 1L)
     val master = makeAliveMaster(conf)
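Note: the mixed-case hostnames in these tests ("LoCalHost", "localHOST", "lOcalHost") are deliberate: both the requested hostnames and the workers' addresses are lower-cased with Locale.ROOT before matching, so the match is case-insensitive. A self-contained illustration of that normalization:

    import java.util.Locale

    val requested = Seq("LoCalHost", "localHOST").map(_.toLowerCase(Locale.ROOT)).toSet
    assert(requested == Set("localhost"))  // both spellings collapse to one entry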
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala
index e2d7facdd77e0..35de457ec48ce 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala
@@ -21,6 +21,7 @@ import java.io.DataOutputStream
 import java.net.{HttpURLConnection, URL}
 import java.nio.charset.StandardCharsets
 import java.util.Date
+import javax.servlet.http.HttpServletResponse
 
 import scala.collection.mutable.HashMap
 
@@ -28,15 +29,16 @@ import org.mockito.Mockito.{mock, times, verify, when}
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
-import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
+import org.apache.spark.deploy.DeployMessages.{DecommissionWorkersOnHosts, KillDriverResponse, RequestKillDriver}
 import org.apache.spark.deploy.DeployTestUtils._
 import org.apache.spark.deploy.master._
+import org.apache.spark.internal.config.UI
 import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}
 
 class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {
 
-  val conf = new SparkConf
+  val conf = new SparkConf()
   val securityMgr = new SecurityManager(conf)
   val rpcEnv = mock(classOf[RpcEnv])
   val master = mock(classOf[Master])
@@ -88,10 +90,32 @@ class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {
     verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
   }
 
-  private def convPostDataToString(data: Map[String, String]): String = {
+  private def testKillWorkers(hostnames: Seq[String]): Unit = {
+    val url = s"http://localhost:${masterWebUI.boundPort}/workers/kill/"
+    val body = convPostDataToString(hostnames.map(("host", _)))
+    val conn = sendHttpRequest(url, "POST", body)
+    // The master is mocked here, so we cannot assert on the response code.
+    conn.getResponseCode
+    // Verify that the master was asked to decommission the workers on the given hostnames.
+    verify(masterEndpointRef).askSync[Integer](DecommissionWorkersOnHosts(hostnames))
+  }
+
+  test("Kill one host") {
+    testKillWorkers(Seq("localhost"))
+  }
+
+  test("Kill multiple hosts") {
+    testKillWorkers(Seq("noSuchHost", "LocalHost"))
+  }
+
+  private def convPostDataToString(data: Seq[(String, String)]): String = {
     (for ((name, value) <- data) yield s"$name=$value").mkString("&")
   }
 
+  private def convPostDataToString(data: Map[String, String]): String = {
+    convPostDataToString(data.toSeq)
+  }
+
   /**
    * Send an HTTP request to the given URL using the method and the body specified.
    * Return the connection object.
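Note: end to end, the new endpoint takes a form-encoded POST with one `host` parameter per hostname, mirroring what testKillWorkers builds above. A minimal standalone client sketch using only the JDK (the master URL and hostnames are placeholders):

    import java.io.DataOutputStream
    import java.net.{HttpURLConnection, URL}
    import java.nio.charset.StandardCharsets

    val conn = new URL("http://master-host:8080/workers/kill/")
      .openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod("POST")
    conn.setDoOutput(true)
    val body = Seq("host1.example.com", "host2.example.com")
      .map(h => s"host=$h").mkString("&")
    val out = new DataOutputStream(conn.getOutputStream)
    out.write(body.getBytes(StandardCharsets.UTF_8))
    out.close()
    // 200: workers matched and decommissioning enqueued; 404: no worker matched;
    // 405: request rejected by spark.master.ui.decommission.allow.mode.
    println(conn.getResponseCode)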