diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 26b18564be77..5a5c5a403f20 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -214,6 +214,15 @@ private[spark] class SecurityManager( */ def aclsEnabled(): Boolean = aclsOn + /** + * Checks whether the given user is an admin. This gives the user both view and + * modify permissions, and also allows the user to impersonate other users when + * making UI requests. + */ + def checkAdminPermissions(user: String): Boolean = { + isUserInACL(user, adminAcls, adminAclsGroups) + } + /** * Checks the given user against the view acl and groups list to see if they have * authorization to view the UI. If the UI acls are disabled @@ -227,13 +236,7 @@ private[spark] class SecurityManager( def checkUIViewPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " viewAcls=" + viewAcls.mkString(",") + " viewAclsGroups=" + viewAclsGroups.mkString(",")) - if (!aclsEnabled || user == null || viewAcls.contains(user) || - viewAcls.contains(WILDCARD_ACL) || viewAclsGroups.contains(WILDCARD_ACL)) { - return true - } - val currentUserGroups = Utils.getCurrentUserGroups(sparkConf, user) - logDebug("userGroups=" + currentUserGroups.mkString(",")) - viewAclsGroups.exists(currentUserGroups.contains(_)) + isUserInACL(user, viewAcls, viewAclsGroups) } /** @@ -249,13 +252,7 @@ private[spark] class SecurityManager( def checkModifyPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " modifyAcls=" + modifyAcls.mkString(",") + " modifyAclsGroups=" + modifyAclsGroups.mkString(",")) - if (!aclsEnabled || user == null || modifyAcls.contains(user) || - modifyAcls.contains(WILDCARD_ACL) || modifyAclsGroups.contains(WILDCARD_ACL)) { - return true - } - val currentUserGroups = Utils.getCurrentUserGroups(sparkConf, user) - logDebug("userGroups=" + currentUserGroups) - modifyAclsGroups.exists(currentUserGroups.contains(_)) + isUserInACL(user, modifyAcls, modifyAclsGroups) } /** @@ -371,6 +368,23 @@ private[spark] class SecurityManager( } } + private def isUserInACL( + user: String, + aclUsers: Set[String], + aclGroups: Set[String]): Boolean = { + if (user == null || + !aclsEnabled || + aclUsers.contains(WILDCARD_ACL) || + aclUsers.contains(user) || + aclGroups.contains(WILDCARD_ACL)) { + true + } else { + val userGroups = Utils.getCurrentUserGroups(sparkConf, user) + logDebug(s"user $user is in groups ${userGroups.mkString(",")}") + aclGroups.exists(userGroups.contains(_)) + } + } + // Default SecurityManager only has a single secret key, so ignore appId. override def getSaslUser(appId: String): String = getSaslUser() override def getSecretKey(appId: String): String = getSecretKey() diff --git a/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala b/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala index fc9b50f14a08..1c0dd7dee222 100644 --- a/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala +++ b/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala @@ -53,9 +53,24 @@ private class HttpSecurityFilter( val hres = res.asInstanceOf[HttpServletResponse] hres.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - if (!securityMgr.checkUIViewPermissions(hreq.getRemoteUser())) { + val requestUser = hreq.getRemoteUser() + + // The doAs parameter allows proxy servers (e.g. Knox) to impersonate other users. For + // that to be allowed, the authenticated user needs to be an admin. + val effectiveUser = Option(hreq.getParameter("doAs")) + .map { proxy => + if (requestUser != proxy && !securityMgr.checkAdminPermissions(requestUser)) { + hres.sendError(HttpServletResponse.SC_FORBIDDEN, + s"User $requestUser is not allowed to impersonate others.") + return + } + proxy + } + .getOrElse(requestUser) + + if (!securityMgr.checkUIViewPermissions(effectiveUser)) { hres.sendError(HttpServletResponse.SC_FORBIDDEN, - "User is not authorized to access this page.") + s"User $effectiveUser is not authorized to access this page.") return } @@ -77,12 +92,13 @@ private class HttpSecurityFilter( hres.setHeader("Strict-Transport-Security", _)) } - chain.doFilter(new XssSafeRequest(hreq), res) + chain.doFilter(new XssSafeRequest(hreq, effectiveUser), res) } } -private class XssSafeRequest(req: HttpServletRequest) extends HttpServletRequestWrapper(req) { +private class XssSafeRequest(req: HttpServletRequest, effectiveUser: String) + extends HttpServletRequestWrapper(req) { private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r @@ -92,6 +108,8 @@ private class XssSafeRequest(req: HttpServletRequest) extends HttpServletRequest }.toMap } + override def getRemoteUser(): String = effectiveUser + override def getParameterMap(): JMap[String, Array[String]] = parameterMap.asJava override def getParameterNames(): Enumeration[String] = { diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index 20421456cefb..b31a6b4e2f9a 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -226,6 +226,8 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { securityManager.setAdminAcls(Seq("user6")) securityManager.setViewAcls(Set[String]("user8"), Seq("user9")) securityManager.setModifyAcls(Set("user11"), Seq("user9")) + assert(securityManager.checkAdminPermissions("user6")) + assert(!securityManager.checkAdminPermissions("user8")) assert(securityManager.checkModifyPermissions("user6")) assert(securityManager.checkModifyPermissions("user11")) assert(securityManager.checkModifyPermissions("user9")) @@ -252,6 +254,7 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(securityManager.aclsEnabled()) // group1,group2,group3 match + assert(securityManager.checkAdminPermissions("user1")) assert(securityManager.checkModifyPermissions("user1")) assert(securityManager.checkUIViewPermissions("user1")) @@ -261,8 +264,9 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { securityManager.setViewAclsGroups(Nil) securityManager.setModifyAclsGroups(Nil) - assert(securityManager.checkModifyPermissions("user1") === false) - assert(securityManager.checkUIViewPermissions("user1") === false) + assert(!securityManager.checkAdminPermissions("user1")) + assert(!securityManager.checkModifyPermissions("user1")) + assert(!securityManager.checkUIViewPermissions("user1")) // change modify groups so they match securityManager.setModifyAclsGroups(Seq("group3")) diff --git a/core/src/test/scala/org/apache/spark/ui/HttpSecurityFilterSuite.scala b/core/src/test/scala/org/apache/spark/ui/HttpSecurityFilterSuite.scala index 098d012eed88..c435852a4670 100644 --- a/core/src/test/scala/org/apache/spark/ui/HttpSecurityFilterSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/HttpSecurityFilterSuite.scala @@ -46,16 +46,8 @@ class HttpSecurityFilterSuite extends SparkFunSuite { val conf = new SparkConf() val filter = new HttpSecurityFilter(conf, new SecurityManager(conf)) - def newRequest(): HttpServletRequest = { - val req = mock(classOf[HttpServletRequest]) - when(req.getParameterMap()).thenReturn(Map.empty[String, Array[String]].asJava) - req - } - def doRequest(k: String, v: String): HttpServletRequest = { - val req = newRequest() - when(req.getParameterMap()).thenReturn(Map(k -> Array(v)).asJava) - + val req = mockRequest(params = Map(k -> Array(v))) val chain = mock(classOf[FilterChain]) val res = mock(classOf[HttpServletResponse]) filter.doFilter(req, res, chain) @@ -97,7 +89,7 @@ class HttpSecurityFilterSuite extends SparkFunSuite { .set(UI_VIEW_ACLS, Seq("alice")) val secMgr = new SecurityManager(conf) - val req = mockEmptyRequest() + val req = mockRequest() val res = mock(classOf[HttpServletResponse]) val chain = mock(classOf[FilterChain]) @@ -128,7 +120,7 @@ class HttpSecurityFilterSuite extends SparkFunSuite { .set(UI_X_CONTENT_TYPE_OPTIONS, true) .set(UI_STRICT_TRANSPORT_SECURITY, "tsec") val secMgr = new SecurityManager(conf) - val req = mockEmptyRequest() + val req = mockRequest() val res = mock(classOf[HttpServletResponse]) val chain = mock(classOf[FilterChain]) @@ -147,8 +139,43 @@ class HttpSecurityFilterSuite extends SparkFunSuite { } } - private def mockEmptyRequest(): HttpServletRequest = { - val params: Map[String, Array[String]] = Map.empty + test("doAs impersonation") { + val conf = new SparkConf(false) + .set(ACLS_ENABLE, true) + .set(ADMIN_ACLS, Seq("admin")) + .set(UI_VIEW_ACLS, Seq("proxy")) + + val secMgr = new SecurityManager(conf) + val req = mockRequest() + val res = mock(classOf[HttpServletResponse]) + val chain = mock(classOf[FilterChain]) + val filter = new HttpSecurityFilter(conf, secMgr) + + // First try with a non-admin so that the admin check is verified. This ensures that + // the admin check is setting the expected error, since the impersonated user would + // have permissions to process the request. + when(req.getParameter("doAs")).thenReturn("proxy") + when(req.getRemoteUser()).thenReturn("bob") + filter.doFilter(req, res, chain) + verify(res, times(1)).sendError(meq(HttpServletResponse.SC_FORBIDDEN), any()) + + when(req.getRemoteUser()).thenReturn("admin") + filter.doFilter(req, res, chain) + verify(chain, times(1)).doFilter(any(), any()) + + // Check that impersonation was actually performed by checking the wrapped request. + val captor = ArgumentCaptor.forClass(classOf[HttpServletRequest]) + verify(chain).doFilter(captor.capture(), any()) + val wrapped = captor.getValue() + assert(wrapped.getRemoteUser() === "proxy") + + // Impersonating a user without view permissions should cause an error. + when(req.getParameter("doAs")).thenReturn("alice") + filter.doFilter(req, res, chain) + verify(res, times(2)).sendError(meq(HttpServletResponse.SC_FORBIDDEN), any()) + } + + private def mockRequest(params: Map[String, Array[String]] = Map()): HttpServletRequest = { val req = mock(classOf[HttpServletRequest]) when(req.getParameterMap()).thenReturn(params.asJava) req