Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions core/src/main/scala/org/apache/spark/TestUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@
package org.apache.spark

import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream}
import java.net.{URI, URL}
import java.net.{HttpURLConnection, URI, URL}
import java.nio.charset.StandardCharsets
import java.security.SecureRandom
import java.security.cert.X509Certificate
import java.util.Arrays
import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.util.jar.{JarEntry, JarOutputStream}
import javax.net.ssl._
import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}

import scala.collection.JavaConverters._
import scala.collection.mutable
Expand All @@ -31,7 +35,6 @@ import scala.sys.process.{Process, ProcessLogger}
import scala.util.Try

import com.google.common.io.{ByteStreams, Files}
import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
Expand Down Expand Up @@ -194,6 +197,37 @@ private[spark] object TestUtils {
attempt.isSuccess && attempt.get == 0
}

/**
* Returns the response code from an HTTP(S) URL.
*/
def httpResponseCode(url: URL, method: String = "GET"): Int = {
val connection = url.openConnection().asInstanceOf[HttpURLConnection]
connection.setRequestMethod(method)

// Disable cert and host name validation for HTTPS tests.
if (connection.isInstanceOf[HttpsURLConnection]) {
val sslCtx = SSLContext.getInstance("SSL")
val trustManager = new X509TrustManager {
override def getAcceptedIssuers(): Array[X509Certificate] = null
override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {}
override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {}
}
val verifier = new HostnameVerifier() {
override def verify(hostname: String, session: SSLSession): Boolean = true
}
sslCtx.init(null, Array(trustManager), new SecureRandom())
connection.asInstanceOf[HttpsURLConnection].setSSLSocketFactory(sslCtx.getSocketFactory())
connection.asInstanceOf[HttpsURLConnection].setHostnameVerifier(verifier)
}

try {
connection.connect()
connection.getResponseCode()
} finally {
connection.disconnect()
}
}

}


Expand Down
87 changes: 61 additions & 26 deletions core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ import org.apache.spark.util.Utils
*/
private[spark] object JettyUtils extends Logging {

val SPARK_CONNECTOR_NAME = "Spark"
val REDIRECT_CONNECTOR_NAME = "HttpsRedirect"

// Base type for a function that returns something based on an HTTP request. Allows for
// implicit conversion from many types of functions to jetty Handlers.
type Responder[T] = HttpServletRequest => T
Expand Down Expand Up @@ -274,25 +277,28 @@ private[spark] object JettyUtils extends Logging {
conf: SparkConf,
serverName: String = ""): ServerInfo = {

val collection = new ContextHandlerCollection
addFilters(handlers, conf)

val gzipHandlers = handlers.map { h =>
h.setVirtualHosts(Array("@" + SPARK_CONNECTOR_NAME))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this code here? setVirtualHosts should always be called in addHandler for each handler right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ContextHandlerCollection.addHandler does not call setVirtualHosts.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addHandler here means ServerInfo#addHandler, sorry for confusing.
And I noticed serverInfo in WebUI can be None and serverInfo.foreach(_.addHandler(handler)) is not called within WebUI#attachHandler in that case. So I understand it's reasonable change.
Can you add a test case for justification of this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand what you mean. Without this change the UI does not work at all. The test I added already covers it.

Copy link
Member

@sarutak sarutak Jan 25, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this change the UI does not work at all. The test I added already covers it.

Hmm, it's funny. I commented out this change and ran the test case you added (UISuite and UISeleniumSuite) but it passed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course I understand this change is needed. I've confirmed it manually.


val gzipHandler = new GzipHandler
gzipHandler.setHandler(h)
gzipHandler
}

// Bind to the given port, or throw a java.net.BindException if the port is occupied
def connect(currentPort: Int): (Server, Int) = {
def connect(currentPort: Int): ((Server, Option[Int]), Int) = {
val pool = new QueuedThreadPool
if (serverName.nonEmpty) {
pool.setName(serverName)
}
pool.setDaemon(true)

val server = new Server(pool)
val connectors = new ArrayBuffer[ServerConnector]
val connectors = new ArrayBuffer[ServerConnector]()
val collection = new ContextHandlerCollection

// Create a connector on port currentPort to listen for HTTP requests
val httpConnector = new ServerConnector(
server,
Expand All @@ -306,26 +312,33 @@ private[spark] object JettyUtils extends Logging {
httpConnector.setPort(currentPort)
connectors += httpConnector

sslOptions.createJettySslContextFactory().foreach { factory =>
// If the new port wraps around, do not try a privileged port.
val securePort =
if (currentPort != 0) {
(currentPort + 400 - 1024) % (65536 - 1024) + 1024
} else {
0
}
val scheme = "https"
// Create a connector on port securePort to listen for HTTPS requests
val connector = new ServerConnector(server, factory)
connector.setPort(securePort)

connectors += connector

// redirect the HTTP requests to HTTPS port
collection.addHandler(createRedirectHttpsHandler(securePort, scheme))
val httpsConnector = sslOptions.createJettySslContextFactory() match {
case Some(factory) =>
// If the new port wraps around, do not try a privileged port.
val securePort =
if (currentPort != 0) {
(currentPort + 400 - 1024) % (65536 - 1024) + 1024
} else {
0
}
val scheme = "https"
// Create a connector on port securePort to listen for HTTPS requests
val connector = new ServerConnector(server, factory)
connector.setPort(securePort)
connector.setName(SPARK_CONNECTOR_NAME)
connectors += connector

// redirect the HTTP requests to HTTPS port
httpConnector.setName(REDIRECT_CONNECTOR_NAME)
collection.addHandler(createRedirectHttpsHandler(securePort, scheme))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed one point.
If a port is already used, collection.addHandler will take place more than twice leading redirection doesn't work properly.
Of course, it's not your fault. If you fix it in this PR together, it's good but it's a separate issue so I'll fix in another PR otherwise.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually that is caused by my change, so let me fix it.

Some(connector)

case None =>
// No SSL, so the HTTP connector becomes the official one where all contexts bind.
httpConnector.setName(SPARK_CONNECTOR_NAME)
None
}

gzipHandlers.foreach(collection.addHandler)
// As each acceptor and each selector will use one thread, the number of threads should at
// least be the number of acceptors and selectors plus 1. (See SPARK-13776)
var minThreads = 1
Expand All @@ -337,17 +350,20 @@ private[spark] object JettyUtils extends Logging {
// The number of selectors always equals to the number of acceptors
minThreads += connector.getAcceptors * 2
}
server.setConnectors(connectors.toArray)
pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads))

val errorHandler = new ErrorHandler()
errorHandler.setShowStacks(true)
errorHandler.setServer(server)
server.addBean(errorHandler)

gzipHandlers.foreach(collection.addHandler)
server.setHandler(collection)

server.setConnectors(connectors.toArray)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you move server.setConnectors(connectors.toArray) and gzipHandlers.foreach(collection.addHandler)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly for grouping. "This is where all handlers are added to the server."

In any case I have another change (#16625) that kinda moves all this stuff around anyway...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

O.K. That's what I thought.

try {
server.start()
(server, httpConnector.getLocalPort)
((server, httpsConnector.map(_.getLocalPort())), httpConnector.getLocalPort)
} catch {
case e: Exception =>
server.stop()
Expand All @@ -356,13 +372,16 @@ private[spark] object JettyUtils extends Logging {
}
}

val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName)
ServerInfo(server, boundPort, collection)
val ((server, securePort), boundPort) = Utils.startServiceOnPort(port, connect, conf,
serverName)
ServerInfo(server, boundPort, securePort,
server.getHandler().asInstanceOf[ContextHandlerCollection])
}

private def createRedirectHttpsHandler(securePort: Int, scheme: String): ContextHandler = {
val redirectHandler: ContextHandler = new ContextHandler
redirectHandler.setContextPath("/")
redirectHandler.setVirtualHosts(Array("@" + REDIRECT_CONNECTOR_NAME))
redirectHandler.setHandler(new AbstractHandler {
override def handle(
target: String,
Expand Down Expand Up @@ -442,7 +461,23 @@ private[spark] object JettyUtils extends Logging {
private[spark] case class ServerInfo(
server: Server,
boundPort: Int,
rootHandler: ContextHandlerCollection) {
securePort: Option[Int],
private val rootHandler: ContextHandlerCollection) {

def addHandler(handler: ContextHandler): Unit = {
handler.setVirtualHosts(Array("@" + JettyUtils.SPARK_CONNECTOR_NAME))
rootHandler.addHandler(handler)
if (!handler.isStarted()) {
handler.start()
}
}

def removeHandler(handler: ContextHandler): Unit = {
rootHandler.removeHandler(handler)
if (handler.isStarted) {
handler.stop()
}
}

def stop(): Unit = {
server.stop()
Expand Down
14 changes: 2 additions & 12 deletions core/src/main/scala/org/apache/spark/ui/WebUI.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,23 +91,13 @@ private[spark] abstract class WebUI(
/** Attach a handler to this UI. */
def attachHandler(handler: ServletContextHandler) {
handlers += handler
serverInfo.foreach { info =>
info.rootHandler.addHandler(handler)
if (!handler.isStarted) {
handler.start()
}
}
serverInfo.foreach(_.addHandler(handler))
}

/** Detach a handler from this UI. */
def detachHandler(handler: ServletContextHandler) {
handlers -= handler
serverInfo.foreach { info =>
info.rootHandler.removeHandler(handler)
if (handler.isStarted) {
handler.stop()
}
}
serverInfo.foreach(_.removeHandler(handler))
}

/**
Expand Down
19 changes: 4 additions & 15 deletions core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -475,8 +475,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
val url = new URL(
sc.ui.get.webUrl.stripSuffix("/") + "/stages/stage/kill/?id=0")
// SPARK-6846: should be POST only but YARN AM doesn't proxy POST
getResponseCode(url, "GET") should be (200)
getResponseCode(url, "POST") should be (200)
TestUtils.httpResponseCode(url, "GET") should be (200)
TestUtils.httpResponseCode(url, "POST") should be (200)
}
}
}
Expand All @@ -488,8 +488,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
val url = new URL(
sc.ui.get.webUrl.stripSuffix("/") + "/jobs/job/kill/?id=0")
// SPARK-6846: should be POST only but YARN AM doesn't proxy POST
getResponseCode(url, "GET") should be (200)
getResponseCode(url, "POST") should be (200)
TestUtils.httpResponseCode(url, "GET") should be (200)
TestUtils.httpResponseCode(url, "POST") should be (200)
}
}
}
Expand Down Expand Up @@ -671,17 +671,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
}
}

def getResponseCode(url: URL, method: String): Int = {
val connection = url.openConnection().asInstanceOf[HttpURLConnection]
connection.setRequestMethod(method)
try {
connection.connect()
connection.getResponseCode()
} finally {
connection.disconnect()
}
}

def goToUi(sc: SparkContext, path: String): Unit = {
goToUi(sc.ui.get, path)
}
Expand Down
56 changes: 52 additions & 4 deletions core/src/test/scala/org/apache/spark/ui/UISuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
package org.apache.spark.ui

import java.net.{BindException, ServerSocket}
import java.net.URI
import javax.servlet.http.HttpServletRequest
import java.net.{URI, URL}
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}

import scala.io.Source

import org.eclipse.jetty.servlet.ServletContextHandler
import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
import org.mockito.Mockito.{mock, when}
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
Expand Down Expand Up @@ -167,6 +167,7 @@ class UISuite extends SparkFunSuite {
val boundPort = serverInfo.boundPort
assert(server.getState === "STARTED")
assert(boundPort != 0)
assert(serverInfo.securePort.isDefined)
intercept[BindException] {
socket = new ServerSocket(boundPort)
}
Expand Down Expand Up @@ -227,8 +228,55 @@ class UISuite extends SparkFunSuite {
assert(newHeader === null)
}

test("http -> https redirect applies to all URIs") {
var serverInfo: ServerInfo = null
try {
val servlet = new HttpServlet() {
override def doGet(req: HttpServletRequest, res: HttpServletResponse): Unit = {
res.sendError(HttpServletResponse.SC_OK)
}
}

def newContext(path: String): ServletContextHandler = {
val ctx = new ServletContextHandler()
ctx.setContextPath(path)
ctx.addServlet(new ServletHolder(servlet), "/root")
ctx
}

val (conf, sslOptions) = sslEnabledConf()
serverInfo = JettyUtils.startJettyServer("0.0.0.0", 0, sslOptions,
Seq[ServletContextHandler](newContext("/"), newContext("/test1")),
conf)
assert(serverInfo.server.getState === "STARTED")

val testContext = newContext("/test2")
serverInfo.addHandler(testContext)
testContext.start()

val httpPort = serverInfo.boundPort

val tests = Seq(
("http", serverInfo.boundPort, HttpServletResponse.SC_FOUND),
("https", serverInfo.securePort.get, HttpServletResponse.SC_OK))

tests.foreach { case (scheme, port, expected) =>
val urls = Seq(
s"$scheme://localhost:$port/root",
s"$scheme://localhost:$port/test1/root",
s"$scheme://localhost:$port/test2/root")
urls.foreach { url =>
val rc = TestUtils.httpResponseCode(new URL(url))
assert(rc === expected, s"Unexpected status $rc for $url")
}
}
} finally {
stopServer(serverInfo)
}
}

def stopServer(info: ServerInfo): Unit = {
if (info != null && info.server != null) info.server.stop
if (info != null) info.stop()
}

def closeSocket(socket: ServerSocket): Unit = {
Expand Down