Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions core/src/main/scala/org/apache/spark/TestUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,22 @@
package org.apache.spark

import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream}
import java.net.{URI, URL}
import java.net.{HttpURLConnection, URI, URL}
import java.nio.charset.StandardCharsets
import java.nio.file.Paths
import java.security.SecureRandom
import java.security.cert.X509Certificate
import java.util.Arrays
import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.util.jar.{JarEntry, JarOutputStream}
import javax.net.ssl._
import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import com.google.common.io.{ByteStreams, Files}
import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
Expand Down Expand Up @@ -182,6 +185,37 @@ private[spark] object TestUtils {
assert(spillListener.numSpilledStages == 0, s"expected $identifier to not spill, but did")
}

/**
* Returns the response code from an HTTP(S) URL.
*/
def httpResponseCode(url: URL, method: String = "GET"): Int = {
val connection = url.openConnection().asInstanceOf[HttpURLConnection]
connection.setRequestMethod(method)

// Disable cert and host name validation for HTTPS tests.
if (connection.isInstanceOf[HttpsURLConnection]) {
val sslCtx = SSLContext.getInstance("SSL")
val trustManager = new X509TrustManager {
override def getAcceptedIssuers(): Array[X509Certificate] = null
override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {}
override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {}
}
val verifier = new HostnameVerifier() {
override def verify(hostname: String, session: SSLSession): Boolean = true
}
sslCtx.init(null, Array(trustManager), new SecureRandom())
connection.asInstanceOf[HttpsURLConnection].setSSLSocketFactory(sslCtx.getSocketFactory())
connection.asInstanceOf[HttpsURLConnection].setHostnameVerifier(verifier)
}

try {
connection.connect()
connection.getResponseCode()
} finally {
connection.disconnect()
}
}

}


Expand Down
87 changes: 61 additions & 26 deletions core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ import org.apache.spark.util.Utils
*/
private[spark] object JettyUtils extends Logging {

val SPARK_CONNECTOR_NAME = "Spark"
val REDIRECT_CONNECTOR_NAME = "HttpsRedirect"

// Base type for a function that returns something based on an HTTP request. Allows for
// implicit conversion from many types of functions to jetty Handlers.
type Responder[T] = HttpServletRequest => T
Expand Down Expand Up @@ -274,25 +277,28 @@ private[spark] object JettyUtils extends Logging {
conf: SparkConf,
serverName: String = ""): ServerInfo = {

val collection = new ContextHandlerCollection
addFilters(handlers, conf)

val gzipHandlers = handlers.map { h =>
h.setVirtualHosts(Array("@" + SPARK_CONNECTOR_NAME))

val gzipHandler = new GzipHandler
gzipHandler.setHandler(h)
gzipHandler
}

// Bind to the given port, or throw a java.net.BindException if the port is occupied
def connect(currentPort: Int): (Server, Int) = {
def connect(currentPort: Int): ((Server, Option[Int]), Int) = {
val pool = new QueuedThreadPool
if (serverName.nonEmpty) {
pool.setName(serverName)
}
pool.setDaemon(true)

val server = new Server(pool)
val connectors = new ArrayBuffer[ServerConnector]
val connectors = new ArrayBuffer[ServerConnector]()
val collection = new ContextHandlerCollection

// Create a connector on port currentPort to listen for HTTP requests
val httpConnector = new ServerConnector(
server,
Expand All @@ -306,26 +312,33 @@ private[spark] object JettyUtils extends Logging {
httpConnector.setPort(currentPort)
connectors += httpConnector

sslOptions.createJettySslContextFactory().foreach { factory =>
// If the new port wraps around, do not try a privileged port.
val securePort =
if (currentPort != 0) {
(currentPort + 400 - 1024) % (65536 - 1024) + 1024
} else {
0
}
val scheme = "https"
// Create a connector on port securePort to listen for HTTPS requests
val connector = new ServerConnector(server, factory)
connector.setPort(securePort)

connectors += connector

// redirect the HTTP requests to HTTPS port
collection.addHandler(createRedirectHttpsHandler(securePort, scheme))
val httpsConnector = sslOptions.createJettySslContextFactory() match {
case Some(factory) =>
// If the new port wraps around, do not try a privileged port.
val securePort =
if (currentPort != 0) {
(currentPort + 400 - 1024) % (65536 - 1024) + 1024
} else {
0
}
val scheme = "https"
// Create a connector on port securePort to listen for HTTPS requests
val connector = new ServerConnector(server, factory)
connector.setPort(securePort)
connector.setName(SPARK_CONNECTOR_NAME)
connectors += connector

// redirect the HTTP requests to HTTPS port
httpConnector.setName(REDIRECT_CONNECTOR_NAME)
collection.addHandler(createRedirectHttpsHandler(securePort, scheme))
Some(connector)

case None =>
// No SSL, so the HTTP connector becomes the official one where all contexts bind.
httpConnector.setName(SPARK_CONNECTOR_NAME)
None
}

gzipHandlers.foreach(collection.addHandler)
// As each acceptor and each selector will use one thread, the number of threads should at
// least be the number of acceptors and selectors plus 1. (See SPARK-13776)
var minThreads = 1
Expand All @@ -337,17 +350,20 @@ private[spark] object JettyUtils extends Logging {
// The number of selectors always equals to the number of acceptors
minThreads += connector.getAcceptors * 2
}
server.setConnectors(connectors.toArray)
pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads))

val errorHandler = new ErrorHandler()
errorHandler.setShowStacks(true)
errorHandler.setServer(server)
server.addBean(errorHandler)

gzipHandlers.foreach(collection.addHandler)
server.setHandler(collection)

server.setConnectors(connectors.toArray)
try {
server.start()
(server, httpConnector.getLocalPort)
((server, httpsConnector.map(_.getLocalPort())), httpConnector.getLocalPort)
} catch {
case e: Exception =>
server.stop()
Expand All @@ -356,13 +372,16 @@ private[spark] object JettyUtils extends Logging {
}
}

val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName)
ServerInfo(server, boundPort, collection)
val ((server, securePort), boundPort) = Utils.startServiceOnPort(port, connect, conf,
serverName)
ServerInfo(server, boundPort, securePort,
server.getHandler().asInstanceOf[ContextHandlerCollection])
}

private def createRedirectHttpsHandler(securePort: Int, scheme: String): ContextHandler = {
val redirectHandler: ContextHandler = new ContextHandler
redirectHandler.setContextPath("/")
redirectHandler.setVirtualHosts(Array("@" + REDIRECT_CONNECTOR_NAME))
redirectHandler.setHandler(new AbstractHandler {
override def handle(
target: String,
Expand Down Expand Up @@ -442,7 +461,23 @@ private[spark] object JettyUtils extends Logging {
private[spark] case class ServerInfo(
server: Server,
boundPort: Int,
rootHandler: ContextHandlerCollection) {
securePort: Option[Int],
private val rootHandler: ContextHandlerCollection) {

def addHandler(handler: ContextHandler): Unit = {
handler.setVirtualHosts(Array("@" + JettyUtils.SPARK_CONNECTOR_NAME))
rootHandler.addHandler(handler)
if (!handler.isStarted()) {
handler.start()
}
}

def removeHandler(handler: ContextHandler): Unit = {
rootHandler.removeHandler(handler)
if (handler.isStarted) {
handler.stop()
}
}

def stop(): Unit = {
server.stop()
Expand Down
14 changes: 2 additions & 12 deletions core/src/main/scala/org/apache/spark/ui/WebUI.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,23 +91,13 @@ private[spark] abstract class WebUI(
/** Attach a handler to this UI. */
def attachHandler(handler: ServletContextHandler) {
handlers += handler
serverInfo.foreach { info =>
info.rootHandler.addHandler(handler)
if (!handler.isStarted) {
handler.start()
}
}
serverInfo.foreach(_.addHandler(handler))
}

/** Detach a handler from this UI. */
def detachHandler(handler: ServletContextHandler) {
handlers -= handler
serverInfo.foreach { info =>
info.rootHandler.removeHandler(handler)
if (handler.isStarted) {
handler.stop()
}
}
serverInfo.foreach(_.removeHandler(handler))
}

/**
Expand Down
19 changes: 4 additions & 15 deletions core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -475,8 +475,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
val url = new URL(
sc.ui.get.appUIAddress.stripSuffix("/") + "/stages/stage/kill/?id=0")
// SPARK-6846: should be POST only but YARN AM doesn't proxy POST
getResponseCode(url, "GET") should be (200)
getResponseCode(url, "POST") should be (200)
TestUtils.httpResponseCode(url, "GET") should be (200)
TestUtils.httpResponseCode(url, "POST") should be (200)
}
}
}
Expand All @@ -488,8 +488,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
val url = new URL(
sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs/job/kill/?id=0")
// SPARK-6846: should be POST only but YARN AM doesn't proxy POST
getResponseCode(url, "GET") should be (200)
getResponseCode(url, "POST") should be (200)
TestUtils.httpResponseCode(url, "GET") should be (200)
TestUtils.httpResponseCode(url, "POST") should be (200)
}
}
}
Expand Down Expand Up @@ -671,17 +671,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
}
}

def getResponseCode(url: URL, method: String): Int = {
val connection = url.openConnection().asInstanceOf[HttpURLConnection]
connection.setRequestMethod(method)
try {
connection.connect()
connection.getResponseCode()
} finally {
connection.disconnect()
}
}

def goToUi(sc: SparkContext, path: String): Unit = {
goToUi(sc.ui.get, path)
}
Expand Down
56 changes: 52 additions & 4 deletions core/src/test/scala/org/apache/spark/ui/UISuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
package org.apache.spark.ui

import java.net.{BindException, ServerSocket}
import java.net.URI
import javax.servlet.http.HttpServletRequest
import java.net.{URI, URL}
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}

import scala.io.Source

import org.eclipse.jetty.servlet.ServletContextHandler
import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
import org.mockito.Mockito.{mock, when}
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
Expand Down Expand Up @@ -167,6 +167,7 @@ class UISuite extends SparkFunSuite {
val boundPort = serverInfo.boundPort
assert(server.getState === "STARTED")
assert(boundPort != 0)
assert(serverInfo.securePort.isDefined)
intercept[BindException] {
socket = new ServerSocket(boundPort)
}
Expand Down Expand Up @@ -228,8 +229,55 @@ class UISuite extends SparkFunSuite {
assert(newHeader === null)
}

test("http -> https redirect applies to all URIs") {
var serverInfo: ServerInfo = null
try {
val servlet = new HttpServlet() {
override def doGet(req: HttpServletRequest, res: HttpServletResponse): Unit = {
res.sendError(HttpServletResponse.SC_OK)
}
}

def newContext(path: String): ServletContextHandler = {
val ctx = new ServletContextHandler()
ctx.setContextPath(path)
ctx.addServlet(new ServletHolder(servlet), "/root")
ctx
}

val (conf, sslOptions) = sslEnabledConf()
serverInfo = JettyUtils.startJettyServer("0.0.0.0", 0, sslOptions,
Seq[ServletContextHandler](newContext("/"), newContext("/test1")),
conf)
assert(serverInfo.server.getState === "STARTED")

val testContext = newContext("/test2")
serverInfo.addHandler(testContext)
testContext.start()

val httpPort = serverInfo.boundPort

val tests = Seq(
("http", serverInfo.boundPort, HttpServletResponse.SC_FOUND),
("https", serverInfo.securePort.get, HttpServletResponse.SC_OK))

tests.foreach { case (scheme, port, expected) =>
val urls = Seq(
s"$scheme://localhost:$port/root",
s"$scheme://localhost:$port/test1/root",
s"$scheme://localhost:$port/test2/root")
urls.foreach { url =>
val rc = TestUtils.httpResponseCode(new URL(url))
assert(rc === expected, s"Unexpected status $rc for $url")
}
}
} finally {
stopServer(serverInfo)
}
}

def stopServer(info: ServerInfo): Unit = {
if (info != null && info.server != null) info.server.stop
if (info != null) info.stop()
}

def closeSocket(socket: ServerSocket): Unit = {
Expand Down