Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion core/src/main/scala/org/apache/spark/TestUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,13 @@ private[spark] object TestUtils {
/**
* Returns the response code from an HTTP(S) URL.
*/
def httpResponseCode(url: URL, method: String = "GET"): Int = {
def httpResponseCode(
url: URL,
method: String = "GET",
headers: Seq[(String, String)] = Nil): Int = {
val connection = url.openConnection().asInstanceOf[HttpURLConnection]
connection.setRequestMethod(method)
headers.foreach { case (k, v) => connection.setRequestProperty(k, v) }

// Disable cert and host name validation for HTTPS tests.
if (connection.isInstanceOf[HttpsURLConnection]) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.apache.spark.status.api.v1

import java.util.zip.ZipOutputStream
import javax.servlet.ServletContext
import javax.servlet.http.HttpServletRequest
import javax.ws.rs._
import javax.ws.rs.core.{Context, Response}

Expand All @@ -40,7 +41,7 @@ import org.apache.spark.ui.SparkUI
* HistoryServerSuite.
*/
@Path("/v1")
private[v1] class ApiRootResource extends UIRootFromServletContext {
private[v1] class ApiRootResource extends ApiRequestContext {

@Path("applications")
def getApplicationList(): ApplicationListResource = {
Expand All @@ -56,21 +57,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getJobs(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): AllJobsResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
withSparkUI(appId, Some(attemptId)) { ui =>
new AllJobsResource(ui)
}
}

@Path("applications/{appId}/jobs")
def getJobs(@PathParam("appId") appId: String): AllJobsResource = {
uiRoot.withSparkUI(appId, None) { ui =>
withSparkUI(appId, None) { ui =>
new AllJobsResource(ui)
}
}

@Path("applications/{appId}/jobs/{jobId: \\d+}")
def getJob(@PathParam("appId") appId: String): OneJobResource = {
uiRoot.withSparkUI(appId, None) { ui =>
withSparkUI(appId, None) { ui =>
new OneJobResource(ui)
}
}
Expand All @@ -79,21 +80,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getJob(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): OneJobResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
withSparkUI(appId, Some(attemptId)) { ui =>
new OneJobResource(ui)
}
}

@Path("applications/{appId}/executors")
def getExecutors(@PathParam("appId") appId: String): ExecutorListResource = {
uiRoot.withSparkUI(appId, None) { ui =>
withSparkUI(appId, None) { ui =>
new ExecutorListResource(ui)
}
}

@Path("applications/{appId}/allexecutors")
def getAllExecutors(@PathParam("appId") appId: String): AllExecutorListResource = {
uiRoot.withSparkUI(appId, None) { ui =>
withSparkUI(appId, None) { ui =>
new AllExecutorListResource(ui)
}
}
Expand All @@ -102,7 +103,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getExecutors(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): ExecutorListResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
withSparkUI(appId, Some(attemptId)) { ui =>
new ExecutorListResource(ui)
}
}
Expand All @@ -111,15 +112,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getAllExecutors(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): AllExecutorListResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
withSparkUI(appId, Some(attemptId)) { ui =>
new AllExecutorListResource(ui)
}
}


@Path("applications/{appId}/stages")
def getStages(@PathParam("appId") appId: String): AllStagesResource = {
uiRoot.withSparkUI(appId, None) { ui =>
withSparkUI(appId, None) { ui =>
new AllStagesResource(ui)
}
}
Expand All @@ -128,14 +128,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getStages(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): AllStagesResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
withSparkUI(appId, Some(attemptId)) { ui =>
new AllStagesResource(ui)
}
}

@Path("applications/{appId}/stages/{stageId: \\d+}")
def getStage(@PathParam("appId") appId: String): OneStageResource = {
uiRoot.withSparkUI(appId, None) { ui =>
withSparkUI(appId, None) { ui =>
new OneStageResource(ui)
}
}
Expand All @@ -144,14 +144,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getStage(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): OneStageResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
withSparkUI(appId, Some(attemptId)) { ui =>
new OneStageResource(ui)
}
}

@Path("applications/{appId}/storage/rdd")
def getRdds(@PathParam("appId") appId: String): AllRDDResource = {
uiRoot.withSparkUI(appId, None) { ui =>
withSparkUI(appId, None) { ui =>
new AllRDDResource(ui)
}
}
Expand All @@ -160,14 +160,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getRdds(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): AllRDDResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
withSparkUI(appId, Some(attemptId)) { ui =>
new AllRDDResource(ui)
}
}

@Path("applications/{appId}/storage/rdd/{rddId: \\d+}")
def getRdd(@PathParam("appId") appId: String): OneRDDResource = {
uiRoot.withSparkUI(appId, None) { ui =>
withSparkUI(appId, None) { ui =>
new OneRDDResource(ui)
}
}
Expand All @@ -176,7 +176,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getRdd(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): OneRDDResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
withSparkUI(appId, Some(attemptId)) { ui =>
new OneRDDResource(ui)
}
}
Expand Down Expand Up @@ -234,19 +234,6 @@ private[spark] trait UIRoot {
.status(Response.Status.SERVICE_UNAVAILABLE)
.build()
}

/**
* Get the spark UI with the given appID, and apply a function
* to it. If there is no such app, throw an appropriate exception
*/
def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = {
val appKey = attemptId.map(appId + "/" + _).getOrElse(appId)
getSparkUI(appKey) match {
case Some(ui) =>
f(ui)
case None => throw new NotFoundException("no such app: " + appId)
}
}
def securityManager: SecurityManager
}

Expand All @@ -263,13 +250,38 @@ private[v1] object UIRootFromServletContext {
}
}

private[v1] trait UIRootFromServletContext {
private[v1] trait ApiRequestContext {
@Context
protected var servletContext: ServletContext = _

@Context
var servletContext: ServletContext = _
protected var httpRequest: HttpServletRequest = _

def uiRoot: UIRoot = UIRootFromServletContext.getUiRoot(servletContext)


/**
* Get the spark UI with the given appID, and apply a function
* to it. If there is no such app, throw an appropriate exception
*/
def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = {
val appKey = attemptId.map(appId + "/" + _).getOrElse(appId)
uiRoot.getSparkUI(appKey) match {
case Some(ui) =>
val user = httpRequest.getRemoteUser()
if (!ui.securityManager.checkUIViewPermissions(user)) {
throw new ForbiddenException(raw"""user "$user" is not authorized""")
}
f(ui)
case None => throw new NotFoundException("no such app: " + appId)
}
}

}

private[v1] class ForbiddenException(msg: String) extends WebApplicationException(
Response.status(Response.Status.FORBIDDEN).entity(msg).build())

private[v1] class NotFoundException(msg: String) extends WebApplicationException(
new NoSuchElementException(msg),
Response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ import javax.ws.rs.core.Response
import javax.ws.rs.ext.Provider

@Provider
private[v1] class SecurityFilter extends ContainerRequestFilter with UIRootFromServletContext {
private[v1] class SecurityFilter extends ContainerRequestFilter with ApiRequestContext {
override def filter(req: ContainerRequestContext): Unit = {
val user = Option(req.getSecurityContext.getUserPrincipal).map { _.getName }.orNull
val user = httpRequest.getRemoteUser()
if (!uiRoot.securityManager.checkUIViewPermissions(user)) {
req.abortWith(
Response
.status(Response.Status.FORBIDDEN)
.entity(raw"""user "$user"is not authorized""")
.entity(raw"""user "$user" is not authorized""")
.build()
)
}
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ private[spark] object JettyUtils extends Logging {
response.setHeader("X-Frame-Options", xFrameOptionsValue)
response.getWriter.print(servletParams.extractFn(result))
} else {
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED)
response.setStatus(HttpServletResponse.SC_FORBIDDEN)
response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
response.sendError(HttpServletResponse.SC_UNAUTHORIZED,
response.sendError(HttpServletResponse.SC_FORBIDDEN,
"User is not authorized to access this page.")
}
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ import java.io.{File, FileInputStream, FileWriter, InputStream, IOException}
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.zip.ZipInputStream
import javax.servlet.http.{HttpServletRequest, HttpServletResponse}
import javax.servlet._
import javax.servlet.http.{HttpServletRequest, HttpServletRequestWrapper, HttpServletResponse}

import scala.concurrent.duration._
import scala.language.postfixOps
Expand Down Expand Up @@ -68,11 +69,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
private var server: HistoryServer = null
private var port: Int = -1

def init(): Unit = {
def init(extraConf: (String, String)*): Unit = {
val conf = new SparkConf()
.set("spark.history.fs.logDirectory", logDir)
.set("spark.history.fs.update.interval", "0")
.set("spark.testing", "true")
conf.setAll(extraConf)
provider = new FsHistoryProvider(conf)
provider.checkForLogs()
val securityManager = HistoryServer.createSecurityManager(conf)
Expand Down Expand Up @@ -566,6 +568,39 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers

}

test("ui and api authorization checks") {
val appId = "app-20161115172038-0000"
val owner = "jose"
val admin = "root"
val other = "alice"

stop()
init(
"spark.ui.filters" -> classOf[FakeAuthFilter].getName(),
"spark.history.ui.acls.enable" -> "true",
"spark.history.ui.admin.acls" -> admin)

val tests = Seq(
(owner, HttpServletResponse.SC_OK),
(admin, HttpServletResponse.SC_OK),
(other, HttpServletResponse.SC_FORBIDDEN),
// When the remote user is null, the code behaves as if auth were disabled.
(null, HttpServletResponse.SC_OK))

val port = server.boundPort
val testUrls = Seq(
s"http://localhost:$port/api/v1/applications/$appId/jobs",
s"http://localhost:$port/history/$appId/jobs/")

tests.foreach { case (user, expectedCode) =>
testUrls.foreach { url =>
val headers = if (user != null) Seq(FakeAuthFilter.FAKE_HTTP_USER -> user) else Nil
val sc = TestUtils.httpResponseCode(new URL(url), headers = headers)
assert(sc === expectedCode, s"Unexpected status code $sc for $url (user = $user)")
}
}
}

def getContentAndCode(path: String, port: Int = port): (Int, Option[String], Option[String]) = {
HistoryServerSuite.getContentAndCode(new URL(s"http://localhost:$port/api/v1/$path"))
}
Expand Down Expand Up @@ -648,3 +683,26 @@ object HistoryServerSuite {
}
}
}

/**
* A filter used for auth tests; sets the request's user to the value of the "HTTP_USER" header.
*/
class FakeAuthFilter extends Filter {

override def destroy(): Unit = { }

override def init(config: FilterConfig): Unit = { }

override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit = {
val hreq = req.asInstanceOf[HttpServletRequest]
val wrapped = new HttpServletRequestWrapper(hreq) {
override def getRemoteUser(): String = hreq.getHeader(FakeAuthFilter.FAKE_HTTP_USER)
}
chain.doFilter(wrapped, res)
}

}

object FakeAuthFilter {
val FAKE_HTTP_USER = "HTTP_USER"
}
4 changes: 4 additions & 0 deletions project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ object MimaExcludes {

// Exclude rules for 2.2.x
lazy val v22excludes = v21excludes ++ Seq(
// [SPARK-19652][UI] Do auth checks for REST API access.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.withSparkUI"),
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.status.api.v1.UIRootFromServletContext"),

// [SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.CountMinSketch.toByteArray"),

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ package org.apache.spark.status.api.v1.streaming

import javax.ws.rs.{Path, PathParam}

import org.apache.spark.status.api.v1.UIRootFromServletContext
import org.apache.spark.status.api.v1.ApiRequestContext

@Path("/v1")
private[v1] class ApiStreamingApp extends UIRootFromServletContext {
private[v1] class ApiStreamingApp extends ApiRequestContext {

@Path("applications/{appId}/streaming")
def getStreamingRoot(@PathParam("appId") appId: String): ApiStreamingRootResource = {
uiRoot.withSparkUI(appId, None) { ui =>
withSparkUI(appId, None) { ui =>
new ApiStreamingRootResource(ui)
}
}
Expand All @@ -35,7 +35,7 @@ private[v1] class ApiStreamingApp extends UIRootFromServletContext {
def getStreamingRoot(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): ApiStreamingRootResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
withSparkUI(appId, Some(attemptId)) { ui =>
new ApiStreamingRootResource(ui)
}
}
Expand Down