Skip to content

Commit 983f490

Browse files
author
Marcelo Vanzin
committed
[SPARK-19220][UI] Make redirection to HTTPS apply to all URIs.
The redirect handler was installed only for the root of the server; any other context ended up being served directly through the HTTP port. Since every sub page (e.g. application UIs in the history server) is a separate servlet context, this meant that everything but the root was still accessible via HTTP. The change gives each connector its own name, and binds contexts to specific connectors so that content is only served through the HTTPS connector when it's enabled. In that case, the only thing that binds to the HTTP connector is the redirect handler. Tested with new unit tests and by checking a live history server.
1 parent b0e8eb6 commit 983f490

File tree

5 files changed

+147
-55
lines changed

5 files changed

+147
-55
lines changed

core/src/main/scala/org/apache/spark/TestUtils.scala

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,15 @@
1818
package org.apache.spark
1919

2020
import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream}
21-
import java.net.{URI, URL}
21+
import java.net.{HttpURLConnection, URI, URL}
2222
import java.nio.charset.StandardCharsets
23+
import java.security.SecureRandom
24+
import java.security.cert.X509Certificate
2325
import java.util.Arrays
2426
import java.util.concurrent.{CountDownLatch, TimeUnit}
2527
import java.util.jar.{JarEntry, JarOutputStream}
28+
import javax.net.ssl._
29+
import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
2630

2731
import scala.collection.JavaConverters._
2832
import scala.collection.mutable
@@ -31,7 +35,6 @@ import scala.sys.process.{Process, ProcessLogger}
3135
import scala.util.Try
3236

3337
import com.google.common.io.{ByteStreams, Files}
34-
import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
3538

3639
import org.apache.spark.executor.TaskMetrics
3740
import org.apache.spark.scheduler._
@@ -194,6 +197,37 @@ private[spark] object TestUtils {
194197
attempt.isSuccess && attempt.get == 0
195198
}
196199

200+
/**
 * Returns the response code from an HTTP(S) URL.
 *
 * For HTTPS URLs, certificate and host name validation are disabled so that tests can talk
 * to servers using self-signed certificates. This must never be used outside of tests.
 *
 * @param url the URL to request.
 * @param method the HTTP request method to use (defaults to "GET").
 * @return the HTTP status code returned by the server.
 */
def httpResponseCode(url: URL, method: String = "GET"): Int = {
  val connection = url.openConnection().asInstanceOf[HttpURLConnection]
  connection.setRequestMethod(method)

  // Disable cert and host name validation for HTTPS tests.
  connection match {
    case https: HttpsURLConnection =>
      val trustManager = new X509TrustManager {
        // The X509TrustManager contract requires a non-null (possibly empty) array here.
        override def getAcceptedIssuers(): Array[X509Certificate] = Array.empty
        override def checkClientTrusted(chain: Array[X509Certificate], authType: String): Unit = {}
        override def checkServerTrusted(chain: Array[X509Certificate], authType: String): Unit = {}
      }
      val verifier = new HostnameVerifier {
        override def verify(hostname: String, session: SSLSession): Boolean = true
      }
      val sslCtx = SSLContext.getInstance("SSL")
      sslCtx.init(null, Array[TrustManager](trustManager), new SecureRandom())
      https.setSSLSocketFactory(sslCtx.getSocketFactory())
      https.setHostnameVerifier(verifier)

    case _ =>
      // Plain HTTP: nothing to configure.
  }

  try {
    connection.connect()
    connection.getResponseCode()
  } finally {
    // Always release the connection, even if the request failed.
    connection.disconnect()
  }
}
230+
197231
}
198232

199233

core/src/main/scala/org/apache/spark/ui/JettyUtils.scala

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ import org.apache.spark.util.Utils
4545
*/
4646
private[spark] object JettyUtils extends Logging {
4747

48+
// Name of the connector that UI contexts bind to (via the "@<name>" virtual host): the HTTPS
// connector when SSL is enabled, otherwise the plain HTTP connector.
val SPARK_CONNECTOR_NAME = "Spark"
49+
// Name given to the plain HTTP connector when SSL is enabled; the HTTP -> HTTPS redirect
// handler binds to this name.
val REDIRECT_CONNECTOR_NAME = "HttpsRedirect"
50+
4851
// Base type for a function that returns something based on an HTTP request. Allows for
4952
// implicit conversion from many types of functions to jetty Handlers.
5053
type Responder[T] = HttpServletRequest => T
@@ -278,13 +281,15 @@ private[spark] object JettyUtils extends Logging {
278281
addFilters(handlers, conf)
279282

280283
val gzipHandlers = handlers.map { h =>
  // Bind every initial context to the connector named SPARK_CONNECTOR_NAME, the same way
  // ServerInfo.addHandler does for contexts attached later. When SSL is enabled the plain
  // HTTP connector is named REDIRECT_CONNECTOR_NAME instead, so these contexts do not
  // match it and only the redirect handler does.
  h.setVirtualHosts(Array("@" + SPARK_CONNECTOR_NAME))

  // Wrap the context in a GzipHandler before it is added to the handler collection.
  val gzipHandler = new GzipHandler
  gzipHandler.setHandler(h)
  gzipHandler
}
285290

286291
// Bind to the given port, or throw a java.net.BindException if the port is occupied
287-
def connect(currentPort: Int): (Server, Int) = {
292+
def connect(currentPort: Int): ((Server, Option[Int]), Int) = {
288293
val pool = new QueuedThreadPool
289294
if (serverName.nonEmpty) {
290295
pool.setName(serverName)
@@ -306,23 +311,31 @@ private[spark] object JettyUtils extends Logging {
306311
httpConnector.setPort(currentPort)
307312
connectors += httpConnector
308313

309-
sslOptions.createJettySslContextFactory().foreach { factory =>
310-
// If the new port wraps around, do not try a privileged port.
311-
val securePort =
312-
if (currentPort != 0) {
313-
(currentPort + 400 - 1024) % (65536 - 1024) + 1024
314-
} else {
315-
0
316-
}
317-
val scheme = "https"
318-
// Create a connector on port securePort to listen for HTTPS requests
319-
val connector = new ServerConnector(server, factory)
320-
connector.setPort(securePort)
321-
322-
connectors += connector
323-
324-
// redirect the HTTP requests to HTTPS port
325-
collection.addHandler(createRedirectHttpsHandler(securePort, scheme))
314+
val httpsConnector = sslOptions.createJettySslContextFactory() match {
315+
case Some(factory) =>
316+
// If the new port wraps around, do not try a privileged port.
317+
val securePort =
318+
if (currentPort != 0) {
319+
(currentPort + 400 - 1024) % (65536 - 1024) + 1024
320+
} else {
321+
0
322+
}
323+
val scheme = "https"
324+
// Create a connector on port securePort to listen for HTTPS requests
325+
val connector = new ServerConnector(server, factory)
326+
connector.setPort(securePort)
327+
connector.setName(SPARK_CONNECTOR_NAME)
328+
connectors += connector
329+
330+
// redirect the HTTP requests to HTTPS port
331+
httpConnector.setName(REDIRECT_CONNECTOR_NAME)
332+
collection.addHandler(createRedirectHttpsHandler(securePort, scheme))
333+
Some(connector)
334+
335+
case None =>
336+
// No SSL, so the HTTP connector becomes the official one where all contexts bind.
337+
httpConnector.setName(SPARK_CONNECTOR_NAME)
338+
None
326339
}
327340

328341
gzipHandlers.foreach(collection.addHandler)
@@ -347,7 +360,7 @@ private[spark] object JettyUtils extends Logging {
347360
server.setHandler(collection)
348361
try {
349362
server.start()
350-
(server, httpConnector.getLocalPort)
363+
((server, httpsConnector.map(_.getLocalPort())), httpConnector.getLocalPort)
351364
} catch {
352365
case e: Exception =>
353366
server.stop()
@@ -356,13 +369,15 @@ private[spark] object JettyUtils extends Logging {
356369
}
357370
}
358371

359-
val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName)
360-
ServerInfo(server, boundPort, collection)
372+
val ((server, securePort), boundPort) = Utils.startServiceOnPort(port, connect, conf,
373+
serverName)
374+
ServerInfo(server, boundPort, securePort, collection)
361375
}
362376

363377
private def createRedirectHttpsHandler(securePort: Int, scheme: String): ContextHandler = {
364378
val redirectHandler: ContextHandler = new ContextHandler
365379
redirectHandler.setContextPath("/")
380+
redirectHandler.setVirtualHosts(Array("@" + REDIRECT_CONNECTOR_NAME))
366381
redirectHandler.setHandler(new AbstractHandler {
367382
override def handle(
368383
target: String,
@@ -442,7 +457,23 @@ private[spark] object JettyUtils extends Logging {
442457
private[spark] case class ServerInfo(
443458
server: Server,
444459
boundPort: Int,
445-
rootHandler: ContextHandlerCollection) {
460+
securePort: Option[Int],
461+
private val rootHandler: ContextHandlerCollection) {
462+
463+
/**
 * Attaches `handler` to this server's root handler collection.
 *
 * The handler's virtual hosts are set to the connector named
 * `JettyUtils.SPARK_CONNECTOR_NAME`, and the handler is started if it is not already running.
 */
def addHandler(handler: ContextHandler): Unit = {
  val connectorBinding = Array("@" + JettyUtils.SPARK_CONNECTOR_NAME)
  handler.setVirtualHosts(connectorBinding)
  rootHandler.addHandler(handler)

  val alreadyRunning = handler.isStarted()
  if (!alreadyRunning) handler.start()
}
470+
471+
/**
 * Detaches `handler` from this server's root handler collection, then stops it if it is
 * still running.
 */
def removeHandler(handler: ContextHandler): Unit = {
  rootHandler.removeHandler(handler)

  val stillRunning = handler.isStarted
  if (stillRunning) handler.stop()
}
446477

447478
def stop(): Unit = {
448479
server.stop()

core/src/main/scala/org/apache/spark/ui/WebUI.scala

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,23 +91,13 @@ private[spark] abstract class WebUI(
9191
/** Attach a handler to this UI. */
9292
def attachHandler(handler: ServletContextHandler) {
9393
handlers += handler
94-
serverInfo.foreach { info =>
95-
info.rootHandler.addHandler(handler)
96-
if (!handler.isStarted) {
97-
handler.start()
98-
}
99-
}
94+
serverInfo.foreach(_.addHandler(handler))
10095
}
10196

10297
/** Detach a handler from this UI. */
10398
def detachHandler(handler: ServletContextHandler) {
10499
handlers -= handler
105-
serverInfo.foreach { info =>
106-
info.rootHandler.removeHandler(handler)
107-
if (handler.isStarted) {
108-
handler.stop()
109-
}
110-
}
100+
serverInfo.foreach(_.removeHandler(handler))
111101
}
112102

113103
/**

core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -475,8 +475,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
475475
val url = new URL(
476476
sc.ui.get.webUrl.stripSuffix("/") + "/stages/stage/kill/?id=0")
477477
// SPARK-6846: should be POST only but YARN AM doesn't proxy POST
478-
getResponseCode(url, "GET") should be (200)
479-
getResponseCode(url, "POST") should be (200)
478+
TestUtils.httpResponseCode(url, "GET") should be (200)
479+
TestUtils.httpResponseCode(url, "POST") should be (200)
480480
}
481481
}
482482
}
@@ -488,8 +488,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
488488
val url = new URL(
489489
sc.ui.get.webUrl.stripSuffix("/") + "/jobs/job/kill/?id=0")
490490
// SPARK-6846: should be POST only but YARN AM doesn't proxy POST
491-
getResponseCode(url, "GET") should be (200)
492-
getResponseCode(url, "POST") should be (200)
491+
TestUtils.httpResponseCode(url, "GET") should be (200)
492+
TestUtils.httpResponseCode(url, "POST") should be (200)
493493
}
494494
}
495495
}
@@ -671,17 +671,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
671671
}
672672
}
673673

674-
def getResponseCode(url: URL, method: String): Int = {
675-
val connection = url.openConnection().asInstanceOf[HttpURLConnection]
676-
connection.setRequestMethod(method)
677-
try {
678-
connection.connect()
679-
connection.getResponseCode()
680-
} finally {
681-
connection.disconnect()
682-
}
683-
}
684-
685674
def goToUi(sc: SparkContext, path: String): Unit = {
686675
goToUi(sc.ui.get, path)
687676
}

core/src/test/scala/org/apache/spark/ui/UISuite.scala

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
package org.apache.spark.ui
1919

2020
import java.net.{BindException, ServerSocket}
21-
import java.net.URI
22-
import javax.servlet.http.HttpServletRequest
21+
import java.net.{URI, URL}
22+
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
2323

2424
import scala.io.Source
2525

26-
import org.eclipse.jetty.servlet.ServletContextHandler
26+
import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
2727
import org.mockito.Mockito.{mock, when}
2828
import org.scalatest.concurrent.Eventually._
2929
import org.scalatest.time.SpanSugar._
@@ -167,6 +167,7 @@ class UISuite extends SparkFunSuite {
167167
val boundPort = serverInfo.boundPort
168168
assert(server.getState === "STARTED")
169169
assert(boundPort != 0)
170+
assert(serverInfo.securePort.isDefined)
170171
intercept[BindException] {
171172
socket = new ServerSocket(boundPort)
172173
}
@@ -227,8 +228,55 @@ class UISuite extends SparkFunSuite {
227228
assert(newHeader === null)
228229
}
229230

231+
// Verifies that, with SSL enabled, every URI requested on the HTTP port gets a redirect
// (302), while the same URIs are served normally (200) on the HTTPS port. This covers both
// a context present at server start ("/") and one attached afterwards ("/test").
test("http -> https redirect applies to all URIs") {
  var serverInfo: ServerInfo = null
  try {
    val servlet = new HttpServlet() {
      override def doGet(req: HttpServletRequest, res: HttpServletResponse): Unit = {
        res.sendError(HttpServletResponse.SC_OK)
      }
    }

    def newContext(path: String): ServletContextHandler = {
      val ctx = new ServletContextHandler()
      ctx.setContextPath(path)
      ctx.addServlet(new ServletHolder(servlet), "/*")
      ctx
    }

    val (conf, sslOptions) = sslEnabledConf()
    serverInfo = JettyUtils.startJettyServer(
      "0.0.0.0", 0, sslOptions, Seq[ServletContextHandler](newContext("/")), conf)
    assert(serverInfo.server.getState === "STARTED")

    // addHandler() starts the context itself, so no explicit start() is needed here.
    serverInfo.addHandler(newContext("/test"))

    val securePort = serverInfo.securePort.getOrElse(fail("HTTPS connector was not started"))

    val tests = Seq(
      ("http", serverInfo.boundPort, HttpServletResponse.SC_FOUND),
      ("https", securePort, HttpServletResponse.SC_OK))

    tests.foreach { case (scheme, port, expected) =>
      val urls = Seq(
        s"$scheme://localhost:$port",
        s"$scheme://localhost:$port/",
        s"$scheme://localhost:$port/test",
        s"$scheme://localhost:$port/test/foo")
      urls.foreach { url =>
        val rc = TestUtils.httpResponseCode(new URL(url))
        assert(rc === expected, s"Unexpected status $rc for $url")
      }
    }
  } finally {
    stopServer(serverInfo)
  }
}
277+
230278
def stopServer(info: ServerInfo): Unit = {
231-
if (info != null && info.server != null) info.server.stop
279+
if (info != null) info.stop()
232280
}
233281

234282
def closeSocket(socket: ServerSocket): Unit = {

0 commit comments

Comments
 (0)