diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index fd0477541ef0..109104f0a537 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -18,11 +18,15 @@
 package org.apache.spark
 
 import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream}
-import java.net.{URI, URL}
+import java.net.{HttpURLConnection, URI, URL}
 import java.nio.charset.StandardCharsets
+import java.security.SecureRandom
+import java.security.cert.X509Certificate
 import java.util.Arrays
 import java.util.concurrent.{CountDownLatch, TimeUnit}
 import java.util.jar.{JarEntry, JarOutputStream}
+import javax.net.ssl._
+import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -31,7 +35,6 @@ import scala.sys.process.{Process, ProcessLogger}
 import scala.util.Try
 
 import com.google.common.io.{ByteStreams, Files}
-import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
 
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.scheduler._
@@ -194,6 +197,54 @@ private[spark] object TestUtils {
     attempt.isSuccess && attempt.get == 0
   }
 
+  /**
+   * Returns the response code from an HTTP(S) URL.
+   *
+   * For HTTPS URLs, certificate and host name validation are disabled, so this helper is
+   * only suitable for tests.
+   *
+   * @param url URL to send the request to.
+   * @param method HTTP request method to use.
+   * @return the status code of the response.
+   */
+  def httpResponseCode(url: URL, method: String = "GET"): Int = {
+    val connection = url.openConnection().asInstanceOf[HttpURLConnection]
+    connection.setRequestMethod(method)
+
+    // Disable cert and host name validation for HTTPS tests.
+    connection match {
+      case https: HttpsURLConnection =>
+        val sslCtx = SSLContext.getInstance("SSL")
+        val trustManager = new X509TrustManager {
+          // The X509TrustManager contract requires a non-null (possibly empty) array.
+          override def getAcceptedIssuers(): Array[X509Certificate] =
+            Array.empty[X509Certificate]
+          override def checkClientTrusted(
+              x509Certificates: Array[X509Certificate],
+              s: String): Unit = {}
+          override def checkServerTrusted(
+              x509Certificates: Array[X509Certificate],
+              s: String): Unit = {}
+        }
+        val verifier = new HostnameVerifier() {
+          override def verify(hostname: String, session: SSLSession): Boolean = true
+        }
+        sslCtx.init(null, Array(trustManager), new SecureRandom())
+        https.setSSLSocketFactory(sslCtx.getSocketFactory())
+        https.setHostnameVerifier(verifier)
+
+      case _ =>
+        // Plain HTTP: nothing to configure.
+    }
+
+    try {
+      connection.connect()
+      connection.getResponseCode()
+    } finally {
+      connection.disconnect()
+    }
+  }
+
 }
 
 
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 35c3c8d00f99..f713619cd7ec 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -45,6 +45,9 @@ import org.apache.spark.util.Utils
  */
 private[spark] object JettyUtils extends Logging {
 
+  val SPARK_CONNECTOR_NAME = "Spark"
+  val REDIRECT_CONNECTOR_NAME = "HttpsRedirect"
+
   // Base type for a function that returns something based on an HTTP request. Allows for
   // implicit conversion from many types of functions to jetty Handlers.
   type Responder[T] = HttpServletRequest => T
@@ -274,17 +277,18 @@ private[spark] object JettyUtils extends Logging {
       conf: SparkConf,
       serverName: String = ""): ServerInfo = {
 
-    val collection = new ContextHandlerCollection
     addFilters(handlers, conf)
 
     val gzipHandlers = handlers.map { h =>
+      h.setVirtualHosts(Array("@" + SPARK_CONNECTOR_NAME))
+
       val gzipHandler = new GzipHandler
       gzipHandler.setHandler(h)
       gzipHandler
     }
 
     // Bind to the given port, or throw a java.net.BindException if the port is occupied
-    def connect(currentPort: Int): (Server, Int) = {
+    def connect(currentPort: Int): ((Server, Option[Int]), Int) = {
       val pool = new QueuedThreadPool
       if (serverName.nonEmpty) {
         pool.setName(serverName)
@@ -292,7 +296,9 @@ private[spark] object JettyUtils extends Logging {
       pool.setDaemon(true)
 
       val server = new Server(pool)
-      val connectors = new ArrayBuffer[ServerConnector]
+      val connectors = new ArrayBuffer[ServerConnector]()
+      val collection = new ContextHandlerCollection
+
       // Create a connector on port currentPort to listen for HTTP requests
       val httpConnector = new ServerConnector(
         server,
@@ -306,26 +312,33 @@ private[spark] object JettyUtils extends Logging {
       httpConnector.setPort(currentPort)
       connectors += httpConnector
 
-      sslOptions.createJettySslContextFactory().foreach { factory =>
-        // If the new port wraps around, do not try a privileged port.
-        val securePort =
-          if (currentPort != 0) {
-            (currentPort + 400 - 1024) % (65536 - 1024) + 1024
-          } else {
-            0
-          }
-        val scheme = "https"
-        // Create a connector on port securePort to listen for HTTPS requests
-        val connector = new ServerConnector(server, factory)
-        connector.setPort(securePort)
-
-        connectors += connector
-
-        // redirect the HTTP requests to HTTPS port
-        collection.addHandler(createRedirectHttpsHandler(securePort, scheme))
+      val httpsConnector = sslOptions.createJettySslContextFactory() match {
+        case Some(factory) =>
+          // If the new port wraps around, do not try a privileged port.
+          val securePort =
+            if (currentPort != 0) {
+              (currentPort + 400 - 1024) % (65536 - 1024) + 1024
+            } else {
+              0
+            }
+          val scheme = "https"
+          // Create a connector on port securePort to listen for HTTPS requests
+          val connector = new ServerConnector(server, factory)
+          connector.setPort(securePort)
+          connector.setName(SPARK_CONNECTOR_NAME)
+          connectors += connector
+
+          // redirect the HTTP requests to HTTPS port
+          httpConnector.setName(REDIRECT_CONNECTOR_NAME)
+          collection.addHandler(createRedirectHttpsHandler(securePort, scheme))
+          Some(connector)
+
+        case None =>
+          // No SSL, so the HTTP connector becomes the official one where all contexts bind.
+          httpConnector.setName(SPARK_CONNECTOR_NAME)
+          None
       }
 
-      gzipHandlers.foreach(collection.addHandler)
       // As each acceptor and each selector will use one thread, the number of threads should at
       // least be the number of acceptors and selectors plus 1. (See SPARK-13776)
       var minThreads = 1
@@ -337,17 +350,20 @@ private[spark] object JettyUtils extends Logging {
         // The number of selectors always equals to the number of acceptors
         minThreads += connector.getAcceptors * 2
       }
-      server.setConnectors(connectors.toArray)
       pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads))
 
       val errorHandler = new ErrorHandler()
       errorHandler.setShowStacks(true)
       errorHandler.setServer(server)
       server.addBean(errorHandler)
+
+      gzipHandlers.foreach(collection.addHandler)
       server.setHandler(collection)
+
+      server.setConnectors(connectors.toArray)
       try {
         server.start()
-        (server, httpConnector.getLocalPort)
+        ((server, httpsConnector.map(_.getLocalPort())), httpConnector.getLocalPort)
       } catch {
         case e: Exception =>
           server.stop()
@@ -356,13 +372,16 @@ private[spark] object JettyUtils extends Logging {
       }
     }
 
-    val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName)
-    ServerInfo(server, boundPort, collection)
+    val ((server, securePort), boundPort) = Utils.startServiceOnPort(port, connect, conf,
+      serverName)
+    ServerInfo(server, boundPort, securePort,
+      server.getHandler().asInstanceOf[ContextHandlerCollection])
   }
 
   private def createRedirectHttpsHandler(securePort: Int, scheme: String): ContextHandler = {
     val redirectHandler: ContextHandler = new ContextHandler
     redirectHandler.setContextPath("/")
+    redirectHandler.setVirtualHosts(Array("@" + REDIRECT_CONNECTOR_NAME))
     redirectHandler.setHandler(new AbstractHandler {
       override def handle(
           target: String,
@@ -442,7 +461,28 @@ private[spark] object JettyUtils extends Logging {
 private[spark] case class ServerInfo(
     server: Server,
     boundPort: Int,
-    rootHandler: ContextHandlerCollection) {
+    securePort: Option[Int],
+    private val rootHandler: ContextHandlerCollection) {
+
+  /**
+   * Restricts the handler to the connector named SPARK_CONNECTOR_NAME, registers it with the
+   * root handler collection, and starts it if it is not already started.
+   */
+  def addHandler(handler: ContextHandler): Unit = {
+    handler.setVirtualHosts(Array("@" + JettyUtils.SPARK_CONNECTOR_NAME))
+    rootHandler.addHandler(handler)
+    if (!handler.isStarted()) {
+      handler.start()
+    }
+  }
+
+  /** Removes the handler from the root handler collection and stops it if it is started. */
+  def removeHandler(handler: ContextHandler): Unit = {
+    rootHandler.removeHandler(handler)
+    if (handler.isStarted()) {
+      handler.stop()
+    }
+  }
 
   def stop(): Unit = {
     server.stop()
diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
index b8604c52e6b0..a9480cc220c8 100644
--- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
@@ -91,23 +91,13 @@ private[spark] abstract class WebUI(
   /** Attach a handler to this UI. */
   def attachHandler(handler: ServletContextHandler) {
     handlers += handler
-    serverInfo.foreach { info =>
-      info.rootHandler.addHandler(handler)
-      if (!handler.isStarted) {
-        handler.start()
-      }
-    }
+    serverInfo.foreach(_.addHandler(handler))
   }
 
   /** Detach a handler from this UI. */
   def detachHandler(handler: ServletContextHandler) {
     handlers -= handler
-    serverInfo.foreach { info =>
-      info.rootHandler.removeHandler(handler)
-      if (handler.isStarted) {
-        handler.stop()
-      }
-    }
+    serverInfo.foreach(_.removeHandler(handler))
   }
 
   /**
diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
index f4786e3931c9..422837303642 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
@@ -475,8 +475,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
         val url = new URL(
           sc.ui.get.webUrl.stripSuffix("/") + "/stages/stage/kill/?id=0")
         // SPARK-6846: should be POST only but YARN AM doesn't proxy POST
-        getResponseCode(url, "GET") should be (200)
-        getResponseCode(url, "POST") should be (200)
+        TestUtils.httpResponseCode(url, "GET") should be (200)
+        TestUtils.httpResponseCode(url, "POST") should be (200)
       }
     }
   }
@@ -488,8 +488,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
         val url = new URL(
           sc.ui.get.webUrl.stripSuffix("/") + "/jobs/job/kill/?id=0")
         // SPARK-6846: should be POST only but YARN AM doesn't proxy POST
-        getResponseCode(url, "GET") should be (200)
-        getResponseCode(url, "POST") should be (200)
+        TestUtils.httpResponseCode(url, "GET") should be (200)
+        TestUtils.httpResponseCode(url, "POST") should be (200)
       }
     }
   }
@@ -671,17 +671,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
     }
   }
 
-  def getResponseCode(url: URL, method: String): Int = {
-    val connection = url.openConnection().asInstanceOf[HttpURLConnection]
-    connection.setRequestMethod(method)
-    try {
-      connection.connect()
-      connection.getResponseCode()
-    } finally {
-      connection.disconnect()
-    }
-  }
-
   def goToUi(sc: SparkContext, path: String): Unit = {
     goToUi(sc.ui.get, path)
   }
diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
index 68c7657cb315..aa67f49185e7 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
@@ -18,12 +18,12 @@
 package org.apache.spark.ui
 
 import java.net.{BindException, ServerSocket}
-import java.net.URI
-import javax.servlet.http.HttpServletRequest
+import java.net.{URI, URL}
+import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
 
 import scala.io.Source
 
-import org.eclipse.jetty.servlet.ServletContextHandler
+import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
 import org.mockito.Mockito.{mock, when}
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
@@ -167,6 +167,7 @@ class UISuite extends SparkFunSuite {
       val boundPort = serverInfo.boundPort
       assert(server.getState === "STARTED")
       assert(boundPort != 0)
+      assert(serverInfo.securePort.isDefined)
       intercept[BindException] {
         socket = new ServerSocket(boundPort)
       }
@@ -227,8 +228,53 @@ class UISuite extends SparkFunSuite {
     assert(newHeader === null)
   }
 
+  test("http -> https redirect applies to all URIs") {
+    var serverInfo: ServerInfo = null
+    try {
+      val servlet = new HttpServlet() {
+        override def doGet(req: HttpServletRequest, res: HttpServletResponse): Unit = {
+          res.sendError(HttpServletResponse.SC_OK)
+        }
+      }
+
+      def newContext(path: String): ServletContextHandler = {
+        val ctx = new ServletContextHandler()
+        ctx.setContextPath(path)
+        ctx.addServlet(new ServletHolder(servlet), "/root")
+        ctx
+      }
+
+      val (conf, sslOptions) = sslEnabledConf()
+      serverInfo = JettyUtils.startJettyServer("0.0.0.0", 0, sslOptions,
+        Seq[ServletContextHandler](newContext("/"), newContext("/test1")),
+        conf)
+      assert(serverInfo.server.getState === "STARTED")
+
+      // ServerInfo.addHandler starts the context if needed; no explicit start() required.
+      val testContext = newContext("/test2")
+      serverInfo.addHandler(testContext)
+
+      val tests = Seq(
+        ("http", serverInfo.boundPort, HttpServletResponse.SC_FOUND),
+        ("https", serverInfo.securePort.get, HttpServletResponse.SC_OK))
+
+      tests.foreach { case (scheme, port, expected) =>
+        val urls = Seq(
+          s"$scheme://localhost:$port/root",
+          s"$scheme://localhost:$port/test1/root",
+          s"$scheme://localhost:$port/test2/root")
+        urls.foreach { url =>
+          val rc = TestUtils.httpResponseCode(new URL(url))
+          assert(rc === expected, s"Unexpected status $rc for $url")
+        }
+      }
+    } finally {
+      stopServer(serverInfo)
+    }
+  }
+
   def stopServer(info: ServerInfo): Unit = {
-    if (info != null && info.server != null) info.server.stop
+    if (info != null) info.stop()
   }
 
   def closeSocket(socket: ServerSocket): Unit = {