@@ -108,6 +108,13 @@ private[deploy] object DeployMessages {

case class Heartbeat(workerId: String, worker: RpcEndpointRef) extends DeployMessage

/**
* Used by the MasterWebUI to request the master to decommission all workers that are active on
* any of the given hostnames.
* @param hostnames A list of hostnames without ports, e.g. "localhost", "foo.bar.com"
*/
case class DecommissionWorkersOnHosts(hostnames: Seq[String])

// Master to Worker

sealed trait RegisterWorkerResponse
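For context, the Web UI sends this message to the master over RPC and waits for the reply, which is the number of workers that matched the given hostnames. A minimal sketch of such a caller, assuming an `RpcEndpointRef` pointing at an active (non-STANDBY) Master is available:

```scala
// Sketch only: asking the master to decommission workers on the given hosts.
// `masterRef` is an assumed reference to an active Master endpoint.
import org.apache.spark.deploy.DeployMessages.DecommissionWorkersOnHosts
import org.apache.spark.rpc.RpcEndpointRef

def decommissionHosts(masterRef: RpcEndpointRef, hosts: Seq[String]): Integer = {
  // The reply is the number of workers that matched the hostnames.
  masterRef.askSync[Integer](DecommissionWorkersOnHosts(hosts))
}
```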
core/src/main/scala/org/apache/spark/deploy/master/Master.scala (37 additions, 0 deletions)
@@ -22,7 +22,9 @@ import java.util.{Date, Locale}
import java.util.concurrent.{ScheduledFuture, TimeUnit}

import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.collection.mutable
import scala.util.Random
import scala.util.control.NonFatal

import org.apache.spark.{SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState, SparkHadoopUtil}
@@ -525,6 +527,13 @@ private[deploy] class Master(
case KillExecutors(appId, executorIds) =>
val formattedExecutorIds = formatExecutorIds(executorIds)
context.reply(handleKillExecutors(appId, formattedExecutorIds))

case DecommissionWorkersOnHosts(hostnames) =>
if (state != RecoveryState.STANDBY) {
context.reply(decommissionWorkersOnHosts(hostnames))
} else {
context.reply(0)
}
}

override def onDisconnected(address: RpcAddress): Unit = {
@@ -863,6 +872,34 @@ private[deploy] class Master(
true
}

/**
* Decommission all workers that are active on any of the given hostnames. The decommissioning
* is done asynchronously by enqueueing WorkerDecommission messages to self. No checks are made
* on the prior state of the worker, so an already decommissioned worker will match as well.
*
* @param hostnames A list of hostnames without ports, e.g. "localhost", "foo.bar.com"
* @return the number of workers that matched the hostnames
*/
private def decommissionWorkersOnHosts(hostnames: Seq[String]): Integer = {
val hostnamesSet = hostnames.map(_.toLowerCase(Locale.ROOT)).toSet
val workersToRemove = addressToWorker
.filterKeys(addr => hostnamesSet.contains(addr.host.toLowerCase(Locale.ROOT)))
.values

val workersToRemoveHostPorts = workersToRemove.map(_.hostPort)
logInfo(s"Decommissioning the workers with host:ports ${workersToRemoveHostPorts}")

// The workers are removed async to avoid blocking the receive loop for the entire batch
workersToRemove.foreach(wi => {
logInfo(s"Sending the worker decommission to ${wi.id} and ${wi.endpoint}")
self.send(WorkerDecommission(wi.id, wi.endpoint))
})

// Return the count of workers actually removed
workersToRemove.size
}

private def decommissionWorker(worker: WorkerInfo): Unit = {
if (worker.state != WorkerState.DECOMMISSIONED) {
logInfo("Decommissioning worker %s on %s:%d".format(worker.id, worker.host, worker.port))
@@ -17,9 +17,14 @@

package org.apache.spark.deploy.master.ui

import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
import java.net.{InetAddress, NetworkInterface, SocketException}
import java.util.Locale
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}

import org.apache.spark.deploy.DeployMessages.{DecommissionWorkersOnHosts, MasterStateResponse, RequestMasterState}
import org.apache.spark.deploy.master.Master
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.UI.MASTER_UI_DECOMMISSION_ALLOW_MODE
import org.apache.spark.internal.config.UI.UI_KILL_ENABLED
import org.apache.spark.ui.{SparkUI, WebUI}
import org.apache.spark.ui.JettyUtils._
@@ -36,6 +41,7 @@ class MasterWebUI(

val masterEndpointRef = master.self
val killEnabled = master.conf.get(UI_KILL_ENABLED)
val decommissionAllowMode = master.conf.get(MASTER_UI_DECOMMISSION_ALLOW_MODE)

initialize()

@@ -49,6 +55,27 @@
"/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST")))
attachHandler(createRedirectHandler(
"/driver/kill", "/", masterPage.handleDriverKillRequest, httpMethods = Set("POST")))
attachHandler(createServletHandler("/workers/kill", new HttpServlet {
override def doPost(req: HttpServletRequest, resp: HttpServletResponse): Unit = {
val hostnames: Seq[String] = Option(req.getParameterValues("host"))
.getOrElse(Array[String]()).toSeq
if (!isDecommissioningRequestAllowed(req)) {
resp.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED)
} else {
val removedWorkers = masterEndpointRef.askSync[Integer](
DecommissionWorkersOnHosts(hostnames))
logInfo(s"Decommissioning of hosts $hostnames decommissioned $removedWorkers workers")
if (removedWorkers > 0) {
resp.setStatus(HttpServletResponse.SC_OK)
} else if (removedWorkers == 0) {
resp.sendError(HttpServletResponse.SC_NOT_FOUND)
} else {
// We shouldn't even see this case.
resp.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
}
}
}
}, ""))
}

def addProxy(): Unit = {
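The handler above accepts one or more `host` parameters on a POST to `/workers/kill` and replies 200 when at least one worker matched, 404 when none did, and 405 when the request is not allowed. A minimal sketch of exercising it from Scala, where the UI host and port are placeholders:

```scala
// Sketch only: POST to the Master Web UI's /workers/kill endpoint.
// `uiHost`/`uiPort` are assumed to locate a running Master Web UI.
import java.net.{HttpURLConnection, URL}

def killWorkersOnHosts(uiHost: String, uiPort: Int, hosts: Seq[String]): Int = {
  val query = hosts.map(h => s"host=$h").mkString("&")
  val conn = new URL(s"http://$uiHost:$uiPort/workers/kill?$query")
    .openConnection().asInstanceOf[HttpURLConnection]
  conn.setRequestMethod("POST")
  // 200 = some workers decommissioned, 404 = no worker matched, 405 = request not allowed
  conn.getResponseCode
}
```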
@@ -64,6 +91,25 @@ class MasterWebUI(
maybeWorkerUiAddress.orElse(maybeAppUiAddress)
}

private def isLocal(address: InetAddress): Boolean = {
if (address.isAnyLocalAddress || address.isLoopbackAddress) {
return true
}
try {
NetworkInterface.getByInetAddress(address) != null
} catch {
case _: SocketException => false
}
}

private def isDecommissioningRequestAllowed(req: HttpServletRequest): Boolean = {
decommissionAllowMode match {
case "ALLOW" => true
case "LOCAL" => isLocal(InetAddress.getByName(req.getRemoteAddr))
case _ => false
}
}

}

private[master] object MasterWebUI {
core/src/main/scala/org/apache/spark/internal/config/UI.scala (12 additions, 0 deletions)
@@ -17,6 +17,7 @@

package org.apache.spark.internal.config

import java.util.Locale
import java.util.concurrent.TimeUnit

import org.apache.spark.network.util.ByteUnit
Expand Down Expand Up @@ -191,4 +192,15 @@ private[spark] object UI {
.version("3.0.0")
.stringConf
.createOptional

val MASTER_UI_DECOMMISSION_ALLOW_MODE = ConfigBuilder("spark.master.ui.decommission.allow.mode")
.doc("Specifies the behavior of the Master Web UI's /workers/kill endpoint. Possible choices" +
" are: `LOCAL` means allow this endpoint from IPs that are local to the machine running" +
" the Master, `DENY` means to completely disable this endpoint, `ALLOW` means to allow" +
" calling this endpoint from any IP.")
.internal()
.version("3.1.0")
.stringConf
.transform(_.toUpperCase(Locale.ROOT))
.createWithDefault("LOCAL")
}
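The endpoint's exposure is controlled by this config, which defaults to `LOCAL`. As a sketch, a deployment that wants to allow the endpoint from any IP could set the key shown above explicitly:

```scala
// Sketch only: opening the /workers/kill endpoint to any caller (the default is LOCAL).
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.master.ui.decommission.allow.mode", "ALLOW")
```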
@@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicInteger

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.collection.mutable.{HashMap, HashSet}
import scala.concurrent.duration._
import scala.io.Source
import scala.reflect.ClassTag
@@ -726,6 +726,65 @@ class MasterSuite extends SparkFunSuite
}
}

def testWorkerDecommissioning(
numWorkers: Int,
numWorkersExpectedToDecom: Int,
hostnames: Seq[String]): Unit = {
val conf = new SparkConf()
val master = makeAliveMaster(conf)
val workerRegs = (1 to numWorkers).map{idx =>
val worker = new MockWorker(master.self, conf)
worker.rpcEnv.setupEndpoint("worker", worker)
val workerReg = RegisterWorker(
worker.id,
"localhost",
worker.self.address.port,
worker.self,
10,
1024,
"http://localhost:8080",
RpcAddress("localhost", 10000))
master.self.send(workerReg)
workerReg
}

eventually(timeout(10.seconds)) {
val masterState = master.self.askSync[MasterStateResponse](RequestMasterState)
assert(masterState.workers.length === numWorkers)
assert(masterState.workers.forall(_.state == WorkerState.ALIVE))
assert(masterState.workers.map(_.id).toSet == workerRegs.map(_.id).toSet)
}

val decomWorkersCount = master.self.askSync[Integer](DecommissionWorkersOnHosts(hostnames))
assert(decomWorkersCount === numWorkersExpectedToDecom)

// Decommissioning is asynchronous, so wait for the workers to actually be decommissioned by
// polling the master's state.
eventually(timeout(30.seconds)) {
val masterState = master.self.askSync[MasterStateResponse](RequestMasterState)
assert(masterState.workers.length === numWorkers)
val workersActuallyDecomed = masterState.workers.count(_.state == WorkerState.DECOMMISSIONED)
assert(workersActuallyDecomed === numWorkersExpectedToDecom)
}

// Decommissioning a worker again should return the same answer since we want this call to be
// idempotent.
val decomWorkersCountAgain = master.self.askSync[Integer](DecommissionWorkersOnHosts(hostnames))
assert(decomWorkersCountAgain === numWorkersExpectedToDecom)
}

test("All workers on a host should be decommissioned") {
testWorkerDecommissioning(2, 2, Seq("LoCalHost", "localHOST"))
}

test("No workers should be decommissioned with invalid host") {
testWorkerDecommissioning(2, 0, Seq("NoSuchHost1", "NoSuchHost2"))
}

test("Only worker on host should be decommissioned") {
testWorkerDecommissioning(1, 1, Seq("lOcalHost", "NoSuchHost"))
}

test("SPARK-19900: there should be a corresponding driver for the app after relaunching driver") {
val conf = new SparkConf().set(WORKER_TIMEOUT, 1L)
val master = makeAliveMaster(conf)
@@ -21,22 +21,24 @@ import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date
import javax.servlet.http.HttpServletResponse

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployMessages.{DecommissionWorkersOnHosts, KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.internal.config.UI
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

val conf = new SparkConf
val conf = new SparkConf()
val securityMgr = new SecurityManager(conf)
val rpcEnv = mock(classOf[RpcEnv])
val master = mock(classOf[Master])
@@ -88,10 +90,32 @@ class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {
verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
}

private def convPostDataToString(data: Map[String, String]): String = {
private def testKillWorkers(hostnames: Seq[String]): Unit = {
val url = s"http://localhost:${masterWebUI.boundPort}/workers/kill/"
val body = convPostDataToString(hostnames.map(("host", _)))
val conn = sendHttpRequest(url, "POST", body)
// The master is mocked here, so cannot assert on the response code
conn.getResponseCode
// Verify that the master was asked to decommission the workers on the given hosts
verify(masterEndpointRef).askSync[Integer](DecommissionWorkersOnHosts(hostnames))
}

test("Kill one host") {
testKillWorkers(Seq("localhost"))
}

test("Kill multiple hosts") {
testKillWorkers(Seq("noSuchHost", "LocalHost"))
}

private def convPostDataToString(data: Seq[(String, String)]): String = {
(for ((name, value) <- data) yield s"$name=$value").mkString("&")
}

private def convPostDataToString(data: Map[String, String]): String = {
convPostDataToString(data.toSeq)
}

/**
* Send an HTTP request to the given URL using the method and the body specified.
* Return the connection object.