diff --git a/presto-main/pom.xml b/presto-main/pom.xml index 0a11ea5afdf16..2134ebca1e3a8 100644 --- a/presto-main/pom.xml +++ b/presto-main/pom.xml @@ -136,6 +136,11 @@ http-server + + com.facebook.airlift + security + + com.facebook.airlift jaxrs @@ -364,7 +369,6 @@ io.jsonwebtoken jjwt-jackson - runtime diff --git a/presto-main/src/main/java/com/facebook/presto/dispatcher/DispatchManager.java b/presto-main/src/main/java/com/facebook/presto/dispatcher/DispatchManager.java index edd0a86f25ab5..6fef7dc46d2f9 100644 --- a/presto-main/src/main/java/com/facebook/presto/dispatcher/DispatchManager.java +++ b/presto-main/src/main/java/com/facebook/presto/dispatcher/DispatchManager.java @@ -31,7 +31,6 @@ import com.facebook.presto.server.SessionContext; import com.facebook.presto.server.SessionPropertyDefaults; import com.facebook.presto.server.SessionSupplier; -import com.facebook.presto.server.security.SecurityConfig; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.analyzer.AnalyzerOptions; @@ -39,7 +38,6 @@ import com.facebook.presto.spi.resourceGroups.SelectionContext; import com.facebook.presto.spi.resourceGroups.SelectionCriteria; import com.facebook.presto.spi.security.AccessControl; -import com.facebook.presto.spi.security.AuthorizedIdentity; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; import com.facebook.presto.transaction.TransactionManager; import com.google.common.util.concurrent.AbstractFuture; @@ -57,8 +55,6 @@ import java.util.concurrent.Executor; import static com.facebook.presto.SystemSessionProperties.getAnalyzerType; -import static com.facebook.presto.security.AccessControlUtils.checkPermissions; -import static com.facebook.presto.security.AccessControlUtils.getAuthorizedIdentity; import static com.facebook.presto.spi.StandardErrorCode.QUERY_TEXT_TOO_LARGE; import static com.facebook.presto.util.AnalyzerUtil.createAnalyzerOptions; import static com.google.common.base.Preconditions.checkArgument; @@ -93,7 +89,6 @@ public class DispatchManager private final QueryManagerStats stats = new QueryManagerStats(); - private final SecurityConfig securityConfig; private final QueryPreparerProviderManager queryPreparerProviderManager; /** @@ -130,7 +125,6 @@ public DispatchManager( QueryManagerConfig queryManagerConfig, DispatchExecutor dispatchExecutor, ClusterStatusSender clusterStatusSender, - SecurityConfig securityConfig, Optional clusterQueryTrackerService) { this.queryIdGenerator = requireNonNull(queryIdGenerator, "queryIdGenerator is null"); @@ -152,8 +146,6 @@ public DispatchManager( this.clusterStatusSender = requireNonNull(clusterStatusSender, "clusterStatusSender is null"); this.queryTracker = new QueryTracker<>(queryManagerConfig, dispatchExecutor.getScheduledExecutor(), clusterQueryTrackerService); - - this.securityConfig = requireNonNull(securityConfig, "securityConfig is null"); } /** @@ -275,14 +267,8 @@ private void createQueryInternal(QueryId queryId, String slug, int retryCoun throw new PrestoException(QUERY_TEXT_TOO_LARGE, format("Query text length (%s) exceeds the maximum length (%s)", queryLength, maxQueryLength)); } - // check permissions if needed - checkPermissions(accessControl, securityConfig, queryId, sessionContext); - - // get authorized identity if possible - Optional authorizedIdentity = getAuthorizedIdentity(accessControl, securityConfig, queryId, sessionContext); - // decode session - session = sessionSupplier.createSession(queryId, sessionContext, warningCollectorFactory, authorizedIdentity); + session = sessionSupplier.createSession(queryId, sessionContext, warningCollectorFactory); // prepare query AnalyzerOptions analyzerOptions = createAnalyzerOptions(session, session.getWarningCollector()); diff --git a/presto-main/src/main/java/com/facebook/presto/server/HttpRequestSessionContext.java b/presto-main/src/main/java/com/facebook/presto/server/HttpRequestSessionContext.java index d98e6e4310403..6fc8b9efbd667 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/HttpRequestSessionContext.java +++ b/presto-main/src/main/java/com/facebook/presto/server/HttpRequestSessionContext.java @@ -20,6 +20,7 @@ import com.facebook.presto.metadata.SessionPropertyManager; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.security.AuthorizedIdentity; import com.facebook.presto.spi.security.Identity; import com.facebook.presto.spi.security.SelectedRole; import com.facebook.presto.spi.session.ResourceEstimates; @@ -76,6 +77,7 @@ import static com.facebook.presto.client.PrestoHeaders.PRESTO_TRACE_TOKEN; import static com.facebook.presto.client.PrestoHeaders.PRESTO_TRANSACTION_ID; import static com.facebook.presto.client.PrestoHeaders.PRESTO_USER; +import static com.facebook.presto.server.security.ServletSecurityUtils.authorizedIdentity; import static com.facebook.presto.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE; import static com.google.common.base.Strings.emptyToNull; import static com.google.common.base.Strings.isNullOrEmpty; @@ -99,6 +101,7 @@ public final class HttpRequestSessionContext private final String schema; private final Identity identity; + private final Optional authorizedIdentity; private final List certificates; private final String source; @@ -155,6 +158,7 @@ public HttpRequestSessionContext(HttpServletRequest servletRequest, SqlParserOpt ImmutableMap.of(), Optional.empty(), Optional.empty()); + authorizedIdentity = authorizedIdentity(servletRequest); X509Certificate[] certs = (X509Certificate[]) servletRequest.getAttribute(X509_ATTRIBUTE); if (certs != null && certs.length > 0) { @@ -404,6 +408,12 @@ public Identity getIdentity() return identity; } + @Override + public Optional getAuthorizedIdentity() + { + return authorizedIdentity; + } + @Override public List getCertificates() { diff --git a/presto-main/src/main/java/com/facebook/presto/server/NoOpSessionSupplier.java b/presto-main/src/main/java/com/facebook/presto/server/NoOpSessionSupplier.java index d27196aaf88f8..4b3ce0c3e7c04 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/NoOpSessionSupplier.java +++ b/presto-main/src/main/java/com/facebook/presto/server/NoOpSessionSupplier.java @@ -16,9 +16,6 @@ import com.facebook.presto.Session; import com.facebook.presto.execution.warnings.WarningCollectorFactory; import com.facebook.presto.spi.QueryId; -import com.facebook.presto.spi.security.AuthorizedIdentity; - -import java.util.Optional; /** * Used on workers. @@ -27,7 +24,7 @@ public class NoOpSessionSupplier implements SessionSupplier { @Override - public Session createSession(QueryId queryId, SessionContext context, WarningCollectorFactory warningCollectorFactory, Optional authorizedIdentity) + public Session createSession(QueryId queryId, SessionContext context, WarningCollectorFactory warningCollectorFactory) { throw new UnsupportedOperationException(); } diff --git a/presto-main/src/main/java/com/facebook/presto/server/QuerySessionSupplier.java b/presto-main/src/main/java/com/facebook/presto/server/QuerySessionSupplier.java index 0fecfaeb5ed34..2d4062f56abe6 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/QuerySessionSupplier.java +++ b/presto-main/src/main/java/com/facebook/presto/server/QuerySessionSupplier.java @@ -19,6 +19,7 @@ import com.facebook.presto.common.type.TimeZoneKey; import com.facebook.presto.execution.warnings.WarningCollectorFactory; import com.facebook.presto.metadata.SessionPropertyManager; +import com.facebook.presto.server.security.SecurityConfig; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.function.SqlFunctionId; @@ -39,6 +40,8 @@ import static com.facebook.presto.Session.SessionBuilder; import static com.facebook.presto.SystemSessionProperties.WARNING_HANDLING; import static com.facebook.presto.common.type.TimeZoneKey.getTimeZoneKey; +import static com.facebook.presto.security.AccessControlUtils.checkPermissions; +import static com.facebook.presto.security.AccessControlUtils.getAuthorizedIdentity; import static java.util.Map.Entry; import static java.util.Objects.requireNonNull; @@ -52,44 +55,30 @@ public class QuerySessionSupplier private final AccessControl accessControl; private final SessionPropertyManager sessionPropertyManager; private final Optional forcedSessionTimeZone; + private final SecurityConfig securityConfig; @Inject public QuerySessionSupplier( TransactionManager transactionManager, AccessControl accessControl, SessionPropertyManager sessionPropertyManager, - SqlEnvironmentConfig config) + SqlEnvironmentConfig sqlEnvironmentConfig, + SecurityConfig securityConfig) { this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); this.accessControl = requireNonNull(accessControl, "accessControl is null"); this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null"); - requireNonNull(config, "config is null"); - this.forcedSessionTimeZone = requireNonNull(config.getForcedSessionTimeZone(), "forcedSessionTimeZone is null"); + requireNonNull(sqlEnvironmentConfig, "sqlEnvironmentConfig is null"); + this.forcedSessionTimeZone = requireNonNull(sqlEnvironmentConfig.getForcedSessionTimeZone(), "forcedSessionTimeZone is null"); + this.securityConfig = requireNonNull(securityConfig, "securityConfig is null"); } @Override - public Session createSession(QueryId queryId, SessionContext context, WarningCollectorFactory warningCollectorFactory, Optional authorizedIdentity) + public Session createSession(QueryId queryId, SessionContext context, WarningCollectorFactory warningCollectorFactory) { - Identity identity = context.getIdentity(); - if (authorizedIdentity.isPresent()) { - identity = new Identity( - identity.getUser(), - identity.getPrincipal(), - identity.getRoles(), - identity.getExtraCredentials(), - identity.getExtraAuthenticators(), - Optional.of(authorizedIdentity.get().getUserName()), - authorizedIdentity.get().getReasonForSelect()); - log.info(String.format( - "For query %s, given user is %s, authorized user is %s", - queryId.getId(), - identity.getUser(), - authorizedIdentity.get().getUserName())); - } - SessionBuilder sessionBuilder = Session.builder(sessionPropertyManager) .setQueryId(queryId) - .setIdentity(identity) + .setIdentity(authenticateIdentity(queryId, context)) .setSource(context.getSource()) .setCatalog(context.getCatalog()) .setSchema(context.getSchema()) @@ -145,4 +134,20 @@ else if (context.getTimeZoneId() != null) { } return session; } + + private Identity authenticateIdentity(QueryId queryId, SessionContext context) + { + checkPermissions(accessControl, securityConfig, queryId, context); + Optional authorizedIdentity = context.getAuthorizedIdentity(); + authorizedIdentity = authorizedIdentity.isPresent() ? authorizedIdentity : getAuthorizedIdentity(accessControl, securityConfig, queryId, context); + + return authorizedIdentity.map(identity -> new Identity( + context.getIdentity().getUser(), + context.getIdentity().getPrincipal(), + context.getIdentity().getRoles(), + context.getIdentity().getExtraCredentials(), + context.getIdentity().getExtraAuthenticators(), + Optional.of(identity.getUserName()), + identity.getReasonForSelect())).orElseGet(context::getIdentity); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/SessionContext.java b/presto-main/src/main/java/com/facebook/presto/server/SessionContext.java index 705707b5c30d5..d40bb15482377 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/SessionContext.java +++ b/presto-main/src/main/java/com/facebook/presto/server/SessionContext.java @@ -17,6 +17,7 @@ import com.facebook.presto.common.transaction.TransactionId; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.security.AuthorizedIdentity; import com.facebook.presto.spi.security.Identity; import com.facebook.presto.spi.session.ResourceEstimates; import com.facebook.presto.spi.tracing.Tracer; @@ -34,6 +35,11 @@ public interface SessionContext { Identity getIdentity(); + default Optional getAuthorizedIdentity() + { + return Optional.empty(); + } + default List getCertificates() { return ImmutableList.of(); diff --git a/presto-main/src/main/java/com/facebook/presto/server/SessionSupplier.java b/presto-main/src/main/java/com/facebook/presto/server/SessionSupplier.java index 5e4367462cbaf..d021f25724407 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/SessionSupplier.java +++ b/presto-main/src/main/java/com/facebook/presto/server/SessionSupplier.java @@ -16,11 +16,8 @@ import com.facebook.presto.Session; import com.facebook.presto.execution.warnings.WarningCollectorFactory; import com.facebook.presto.spi.QueryId; -import com.facebook.presto.spi.security.AuthorizedIdentity; - -import java.util.Optional; public interface SessionSupplier { - Session createSession(QueryId queryId, SessionContext context, WarningCollectorFactory warningCollectorFactory, Optional authorizedIdentity); + Session createSession(QueryId queryId, SessionContext context, WarningCollectorFactory warningCollectorFactory); } diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/JsonWebTokenAuthenticator.java b/presto-main/src/main/java/com/facebook/presto/server/security/JsonWebTokenAuthenticator.java new file mode 100644 index 0000000000000..a8a209762f81d --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/JsonWebTokenAuthenticator.java @@ -0,0 +1,265 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security; + +import com.facebook.airlift.http.server.AuthenticationException; +import com.facebook.airlift.http.server.Authenticator; +import com.facebook.airlift.http.server.BasicPrincipal; +import com.facebook.airlift.security.pem.PemReader; +import com.facebook.presto.spi.security.AuthorizedIdentity; +import com.google.common.base.CharMatcher; +import com.google.common.collect.ImmutableMap; +import io.jsonwebtoken.Claims; +import io.jsonwebtoken.Jws; +import io.jsonwebtoken.JwsHeader; +import io.jsonwebtoken.JwtException; +import io.jsonwebtoken.JwtParser; +import io.jsonwebtoken.Jwts; +import io.jsonwebtoken.SignatureAlgorithm; +import io.jsonwebtoken.SignatureException; +import io.jsonwebtoken.SigningKeyResolver; +import io.jsonwebtoken.UnsupportedJwtException; +import io.jsonwebtoken.jackson.io.JacksonDeserializer; + +import javax.crypto.spec.SecretKeySpec; +import javax.inject.Inject; +import javax.servlet.http.HttpServletRequest; + +import java.io.File; +import java.io.IOException; +import java.security.Key; +import java.security.Principal; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.function.Function; + +import static com.facebook.presto.server.security.ServletSecurityUtils.AUTHORIZED_IDENTITY_ATTRIBUTE; +import static com.facebook.presto.server.security.ServletSecurityUtils.setAuthorizedIdentity; +import static com.google.common.base.CharMatcher.inRange; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.nullToEmpty; +import static com.google.common.io.Files.asCharSource; +import static com.google.common.net.HttpHeaders.AUTHORIZATION; +import static java.lang.String.format; +import static java.nio.charset.StandardCharsets.US_ASCII; +import static java.util.Base64.getMimeDecoder; +import static java.util.Objects.requireNonNull; + +public class JsonWebTokenAuthenticator + implements Authenticator +{ + private static final String DEFAULT_KEY = "default-key"; + private static final CharMatcher INVALID_KID_CHARS = inRange('a', 'z').or(inRange('A', 'Z')).or(inRange('0', '9')).or(CharMatcher.anyOf("_-")).negate(); + private static final String KEY_ID_VARIABLE = "${KID}"; + + private final JwtParser jwtParser; + private final Function, Key> keyLoader; + + @Inject + public JsonWebTokenAuthenticator(JsonWebTokenConfig config) + { + requireNonNull(config, "config is null"); + + if (config.getKeyFile().contains(KEY_ID_VARIABLE)) { + keyLoader = new DynamicKeyLoader(config.getKeyFile()); + } + else { + keyLoader = new StaticKeyLoader(config.getKeyFile()); + } + + JwtParser jwtParser = Jwts.parserBuilder() + .deserializeJsonWith(new JacksonDeserializer<>(ImmutableMap.of(AUTHORIZED_IDENTITY_ATTRIBUTE, AuthorizedIdentity.class))) + .setSigningKeyResolver(new SigningKeyResolver() + { + // interface uses raw types and this can not be fixed here + @SuppressWarnings("rawtypes") + @Override + public Key resolveSigningKey(JwsHeader header, Claims claims) + { + return keyLoader.apply(header); + } + + @SuppressWarnings("rawtypes") + @Override + public Key resolveSigningKey(JwsHeader header, String plaintext) + { + return keyLoader.apply(header); + } + }) + .build(); + + if (config.getRequiredIssuer() != null) { + jwtParser.requireIssuer(config.getRequiredIssuer()); + } + if (config.getRequiredAudience() != null) { + jwtParser.requireAudience(config.getRequiredAudience()); + } + this.jwtParser = jwtParser; + } + + @Override + public Principal authenticate(HttpServletRequest request) + throws AuthenticationException + { + String header = nullToEmpty(request.getHeader(AUTHORIZATION)); + + int space = header.indexOf(' '); + if ((space < 0) || !header.substring(0, space).equalsIgnoreCase("bearer")) { + throw needAuthentication(null); + } + String token = header.substring(space + 1).trim(); + if (token.isEmpty()) { + throw needAuthentication(null); + } + + try { + Jws claimsJws = jwtParser.parseClaimsJws(token); + + AuthorizedIdentity authorizedIdentity = claimsJws.getBody().get(AUTHORIZED_IDENTITY_ATTRIBUTE, AuthorizedIdentity.class); + if (authorizedIdentity != null) { + setAuthorizedIdentity(request, authorizedIdentity); + } + + String subject = claimsJws.getBody().getSubject(); + return new BasicPrincipal(subject); + } + catch (JwtException e) { + throw needAuthentication(e.getMessage()); + } + catch (RuntimeException e) { + throw new RuntimeException("Authentication error", e); + } + } + + private static AuthenticationException needAuthentication(String message) + { + return new AuthenticationException(message, "Bearer realm=\"Presto\", token_type=\"JWT\""); + } + + private static class StaticKeyLoader + implements Function, Key> + { + private final LoadedKey key; + + public StaticKeyLoader(String keyFile) + { + requireNonNull(keyFile, "keyFile is null"); + checkArgument(!keyFile.contains(KEY_ID_VARIABLE)); + this.key = loadKeyFile(new File(keyFile)); + } + + @Override + public Key apply(JwsHeader header) + { + SignatureAlgorithm algorithm = SignatureAlgorithm.forName(header.getAlgorithm()); + return key.getKey(algorithm); + } + } + + private static class DynamicKeyLoader + implements Function, Key> + { + private final String keyFile; + private final ConcurrentMap keys = new ConcurrentHashMap<>(); + + public DynamicKeyLoader(String keyFile) + { + requireNonNull(keyFile, "keyFile is null"); + checkArgument(keyFile.contains(KEY_ID_VARIABLE)); + this.keyFile = keyFile; + } + + @Override + public Key apply(JwsHeader header) + { + String keyId = getKeyId(header); + SignatureAlgorithm algorithm = SignatureAlgorithm.forName(header.getAlgorithm()); + return keys.computeIfAbsent(keyId, this::loadKey).getKey(algorithm); + } + + private static String getKeyId(JwsHeader header) + { + String keyId = header.getKeyId(); + if (keyId == null) { + // allow for migration from system not using kid + return DEFAULT_KEY; + } + keyId = INVALID_KID_CHARS.replaceFrom(keyId, '_'); + return keyId; + } + + private LoadedKey loadKey(String keyId) + { + return loadKeyFile(new File(keyFile.replace(KEY_ID_VARIABLE, keyId))); + } + } + + private static LoadedKey loadKeyFile(File file) + { + if (!file.canRead()) { + throw new SignatureException("Unknown signing key ID"); + } + + // try to load the key as a PEM encoded public key + try { + return new LoadedKey(PemReader.loadPublicKey(file)); + } + catch (Exception ignored) { + } + + // try to load the key as a base64 encoded HMAC key + try { + String base64Key = asCharSource(file, US_ASCII).read(); + byte[] rawKey = getMimeDecoder().decode(base64Key.getBytes(US_ASCII)); + return new LoadedKey(rawKey); + } + catch (IOException ignored) { + } + + throw new SignatureException("Unknown signing key id"); + } + + private static class LoadedKey + { + private final Key publicKey; + private final byte[] hmacKey; + + public LoadedKey(Key publicKey) + { + this.publicKey = requireNonNull(publicKey, "publicKey is null"); + this.hmacKey = null; + } + + public LoadedKey(byte[] hmacKey) + { + this.hmacKey = requireNonNull(hmacKey, "hmacKey is null"); + this.publicKey = null; + } + + public Key getKey(SignatureAlgorithm algorithm) + { + if (algorithm.isHmac()) { + if (hmacKey == null) { + throw new UnsupportedJwtException(format("JWT is signed with %s, but no HMAC key is configured", algorithm)); + } + return new SecretKeySpec(hmacKey, algorithm.getJcaName()); + } + + if (publicKey == null) { + throw new UnsupportedJwtException(format("JWT is signed with %s, but no key is configured", algorithm)); + } + return publicKey; + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/JsonWebTokenConfig.java b/presto-main/src/main/java/com/facebook/presto/server/security/JsonWebTokenConfig.java new file mode 100644 index 0000000000000..1eaddacf1a1cb --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/JsonWebTokenConfig.java @@ -0,0 +1,62 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security; + +import com.facebook.airlift.configuration.Config; + +import javax.validation.constraints.NotNull; + +public class JsonWebTokenConfig +{ + private String keyFile; + private String requiredIssuer; + private String requiredAudience; + + @NotNull + public String getKeyFile() + { + return keyFile; + } + + @Config("http.authentication.jwt.key-file") + public JsonWebTokenConfig setKeyFile(String keyFile) + { + this.keyFile = keyFile; + return this; + } + + public String getRequiredIssuer() + { + return requiredIssuer; + } + + @Config("http.authentication.jwt.required-issuer") + public JsonWebTokenConfig setRequiredIssuer(String requiredIssuer) + { + this.requiredIssuer = requiredIssuer; + return this; + } + + public String getRequiredAudience() + { + return requiredAudience; + } + + @Config("http.authentication.jwt.required-audience") + public JsonWebTokenConfig setRequiredAudience(String requiredAudience) + { + this.requiredAudience = requiredAudience; + return this; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/ServerSecurityModule.java b/presto-main/src/main/java/com/facebook/presto/server/security/ServerSecurityModule.java index 6da0ab45778f0..ebeaf5e8b6cfa 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/ServerSecurityModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/ServerSecurityModule.java @@ -16,8 +16,6 @@ import com.facebook.airlift.configuration.AbstractConfigurationAwareModule; import com.facebook.airlift.http.server.Authenticator; import com.facebook.airlift.http.server.CertificateAuthenticator; -import com.facebook.airlift.http.server.JsonWebTokenAuthenticator; -import com.facebook.airlift.http.server.JsonWebTokenConfig; import com.facebook.airlift.http.server.KerberosAuthenticator; import com.facebook.airlift.http.server.KerberosConfig; import com.facebook.presto.server.security.SecurityConfig.AuthenticationType; diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/ServletSecurityUtils.java b/presto-main/src/main/java/com/facebook/presto/server/security/ServletSecurityUtils.java new file mode 100644 index 0000000000000..950f5a19d1ccc --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/ServletSecurityUtils.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security; + +import com.facebook.presto.spi.security.AuthorizedIdentity; + +import javax.servlet.http.HttpServletRequest; + +import java.util.Optional; + +public class ServletSecurityUtils +{ + public static final String AUTHORIZED_IDENTITY_ATTRIBUTE = "presto.authorized-identity"; + + private ServletSecurityUtils() {} + + public static void setAuthorizedIdentity(HttpServletRequest servletRequest, AuthorizedIdentity authorizedIdentity) + { + servletRequest.setAttribute(AUTHORIZED_IDENTITY_ATTRIBUTE, authorizedIdentity); + } + + public static Optional authorizedIdentity(HttpServletRequest servletRequest) + { + return Optional.ofNullable((AuthorizedIdentity) servletRequest.getAttribute(AUTHORIZED_IDENTITY_ATTRIBUTE)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/MockHttpServletRequest.java b/presto-main/src/test/java/com/facebook/presto/server/MockHttpServletRequest.java index ff27812a26a7a..d43b34dd4eef2 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/MockHttpServletRequest.java +++ b/presto-main/src/test/java/com/facebook/presto/server/MockHttpServletRequest.java @@ -14,7 +14,6 @@ package com.facebook.presto.server; import com.google.common.collect.ImmutableListMultimap; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.ListMultimap; import javax.servlet.AsyncContext; @@ -35,6 +34,7 @@ import java.security.Principal; import java.util.Collection; import java.util.Enumeration; +import java.util.HashMap; import java.util.Locale; import java.util.Map; @@ -53,7 +53,7 @@ public MockHttpServletRequest(ListMultimap headers, String remot { this.headers = ImmutableListMultimap.copyOf(requireNonNull(headers, "headers is null")); this.remoteAddress = requireNonNull(remoteAddress, "remoteAddress is null"); - this.attributes = ImmutableMap.copyOf(requireNonNull(attributes, "attributes is null")); + this.attributes = new HashMap<>(requireNonNull(attributes, "attributes is null")); } @Override @@ -371,7 +371,7 @@ public String getRemoteHost() @Override public void setAttribute(String name, Object o) { - throw new UnsupportedOperationException(); + attributes.put(name, o); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionContext.java b/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionContext.java index ee633cf2aeed8..221950a2de9d3 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionContext.java +++ b/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionContext.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.function.RoutineCharacteristics; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.security.AuthorizedIdentity; import com.facebook.presto.spi.security.Identity; import com.facebook.presto.spi.security.SelectedRole; import com.facebook.presto.sql.parser.IdentifierSymbol; @@ -55,6 +56,7 @@ import static com.facebook.presto.client.PrestoHeaders.PRESTO_USER; import static com.facebook.presto.common.type.StandardTypes.INTEGER; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.server.security.ServletSecurityUtils.AUTHORIZED_IDENTITY_ATTRIBUTE; import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT; @@ -211,6 +213,24 @@ public void testExtraCredentials() .build()); } + @Test + public void testAuthorizedIdentity() + { + AuthorizedIdentity authorizedIdentity = new AuthorizedIdentity("username", "reasonForSelect", false); + HttpServletRequest request = new MockHttpServletRequest( + ImmutableListMultimap.builder() + .put(PRESTO_USER, "testUser") + .put(PRESTO_SOURCE, "testSource") + .put(PRESTO_CATALOG, "testCatalog") + .put(PRESTO_SCHEMA, "testSchema") + .build(), + "testRemote", + ImmutableMap.of(AUTHORIZED_IDENTITY_ATTRIBUTE, authorizedIdentity)); + + HttpRequestSessionContext context = new HttpRequestSessionContext(request, new SqlParserOptions()); + assertEquals(context.getAuthorizedIdentity(), Optional.of(authorizedIdentity)); + } + protected static String urlEncode(String value) { try { diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestQuerySessionSupplier.java b/presto-main/src/test/java/com/facebook/presto/server/TestQuerySessionSupplier.java index 38265f3f46f37..e8c4d37ca8c79 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/TestQuerySessionSupplier.java +++ b/presto-main/src/test/java/com/facebook/presto/server/TestQuerySessionSupplier.java @@ -18,11 +18,13 @@ import com.facebook.presto.common.type.TimeZoneNotSupportedException; import com.facebook.presto.execution.warnings.WarningCollectorFactory; import com.facebook.presto.metadata.SessionPropertyManager; +import com.facebook.presto.server.security.SecurityConfig; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.security.AllowAllAccessControl; +import com.facebook.presto.spi.security.AuthorizedIdentity; import com.facebook.presto.sql.SqlEnvironmentConfig; import com.facebook.presto.sql.parser.SqlParserOptions; import com.google.common.collect.ImmutableListMultimap; @@ -33,7 +35,6 @@ import javax.servlet.http.HttpServletRequest; import java.util.Locale; -import java.util.Optional; import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.SystemSessionProperties.HASH_PARTITION_COUNT; @@ -54,6 +55,7 @@ import static com.facebook.presto.server.TestHttpRequestSessionContext.createFunctionAdd; import static com.facebook.presto.server.TestHttpRequestSessionContext.createSqlFunctionIdAdd; import static com.facebook.presto.server.TestHttpRequestSessionContext.urlEncode; +import static com.facebook.presto.server.security.ServletSecurityUtils.AUTHORIZED_IDENTITY_ATTRIBUTE; import static com.facebook.presto.transaction.InMemoryTransactionManager.createTestTransactionManager; import static java.lang.String.format; import static org.testng.Assert.assertEquals; @@ -64,6 +66,7 @@ public class TestQuerySessionSupplier private static final SqlInvokedFunction SQL_FUNCTION_ADD = createFunctionAdd(); private static final String SERIALIZED_SQL_FUNCTION_ID_ADD = jsonCodec(SqlFunctionId.class).toJson(SQL_FUNCTION_ID_ADD); private static final String SERIALIZED_SQL_FUNCTION_ADD = jsonCodec(SqlInvokedFunction.class).toJson(SQL_FUNCTION_ADD); + private static final AuthorizedIdentity AUTHORIZED_IDENTITY = new AuthorizedIdentity("userName", "reasonForSelect", false); private static final HttpServletRequest TEST_REQUEST = new MockHttpServletRequest( ImmutableListMultimap.builder() @@ -81,7 +84,7 @@ public class TestQuerySessionSupplier .put(PRESTO_SESSION_FUNCTION, format("%s=%s", urlEncode(SERIALIZED_SQL_FUNCTION_ID_ADD), urlEncode(SERIALIZED_SQL_FUNCTION_ADD))) .build(), "testRemote", - ImmutableMap.of()); + ImmutableMap.of(AUTHORIZED_IDENTITY_ATTRIBUTE, AUTHORIZED_IDENTITY)); @Test public void testCreateSession() @@ -91,7 +94,8 @@ public void testCreateSession() createTestTransactionManager(), new AllowAllAccessControl(), new SessionPropertyManager(), - new SqlEnvironmentConfig()); + new SqlEnvironmentConfig(), + new SecurityConfig()); WarningCollectorFactory warningCollectorFactory = new WarningCollectorFactory() { @Override @@ -100,7 +104,7 @@ public WarningCollector create(WarningHandlingLevel warningHandlingLevel) return WarningCollector.NOOP; } }; - Session session = sessionSupplier.createSession(new QueryId("test_query_id"), context, warningCollectorFactory, Optional.empty()); + Session session = sessionSupplier.createSession(new QueryId("test_query_id"), context, warningCollectorFactory); assertEquals(session.getQueryId(), new QueryId("test_query_id")); assertEquals(session.getUser(), "testUser"); @@ -122,6 +126,8 @@ public WarningCollector create(WarningHandlingLevel warningHandlingLevel) .put("query2", "select * from bar") .build()); assertEquals(session.getSessionFunctions(), ImmutableMap.of(SQL_FUNCTION_ID_ADD, SQL_FUNCTION_ADD)); + assertEquals(session.getIdentity().getSelectedUser().get(), AUTHORIZED_IDENTITY.getUserName()); + assertEquals(session.getIdentity().getReasonForSelect(), AUTHORIZED_IDENTITY.getReasonForSelect()); } @Test @@ -162,7 +168,8 @@ public void testInvalidTimeZone() createTestTransactionManager(), new AllowAllAccessControl(), new SessionPropertyManager(), - new SqlEnvironmentConfig()); + new SqlEnvironmentConfig(), + new SecurityConfig()); WarningCollectorFactory warningCollectorFactory = new WarningCollectorFactory() { @Override @@ -171,6 +178,6 @@ public WarningCollector create(WarningHandlingLevel warningHandlingLevel) return WarningCollector.NOOP; } }; - sessionSupplier.createSession(new QueryId("test_query_id"), context, warningCollectorFactory, Optional.empty()); + sessionSupplier.createSession(new QueryId("test_query_id"), context, warningCollectorFactory); } } diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/TestJsonWebTokenAuthenticator.java b/presto-main/src/test/java/com/facebook/presto/server/security/TestJsonWebTokenAuthenticator.java new file mode 100644 index 0000000000000..761ac37033c91 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/server/security/TestJsonWebTokenAuthenticator.java @@ -0,0 +1,101 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security; + +import com.facebook.airlift.http.server.AuthenticationException; +import com.facebook.presto.server.MockHttpServletRequest; +import com.facebook.presto.spi.security.AuthorizedIdentity; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.Files; +import io.jsonwebtoken.Jwts; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import javax.servlet.http.HttpServletRequest; + +import java.io.IOException; +import java.nio.file.Path; +import java.security.Principal; + +import static com.facebook.presto.server.security.ServletSecurityUtils.AUTHORIZED_IDENTITY_ATTRIBUTE; +import static com.facebook.presto.server.security.ServletSecurityUtils.authorizedIdentity; +import static com.facebook.presto.testing.assertions.Assert.assertEquals; +import static com.google.common.io.Files.createTempDir; +import static com.google.common.io.MoreFiles.deleteRecursively; +import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; +import static com.google.common.net.HttpHeaders.AUTHORIZATION; +import static io.jsonwebtoken.JwsHeader.KEY_ID; +import static io.jsonwebtoken.SignatureAlgorithm.HS256; +import static io.jsonwebtoken.security.Keys.secretKeyFor; +import static java.nio.file.Files.readAllBytes; +import static java.util.Base64.getMimeDecoder; +import static java.util.Base64.getMimeEncoder; + +public class TestJsonWebTokenAuthenticator +{ + private static final String KEY_ID_FOO = "foo"; + private static final String TEST_PRINCIPAL = "testPrincipal"; + + private Path temporaryDirectory; + private Path keyFile; + private JsonWebTokenConfig jsonWebTokenConfig; + + @BeforeTest + public void setup() + throws IOException + { + temporaryDirectory = createTempDir().toPath(); + keyFile = temporaryDirectory.resolve(KEY_ID_FOO + ".key"); + byte[] key = getMimeEncoder().encode(secretKeyFor(HS256).getEncoded()); + Files.write(key, keyFile.toFile()); + jsonWebTokenConfig = new JsonWebTokenConfig().setKeyFile(keyFile.toAbsolutePath().toString()); + } + + @AfterTest(alwaysRun = true) + public void cleanup() + throws IOException + { + deleteRecursively(temporaryDirectory, ALLOW_INSECURE); + } + + @Test + public void testJsonWebTokenWithAuthorizedUserClaim() + throws IOException, AuthenticationException + { + AuthorizedIdentity authorizedIdentity = new AuthorizedIdentity("user", "reasonForSelect", false); + String jsonWebToken = createJsonWebToken(keyFile, TEST_PRINCIPAL, authorizedIdentity); + HttpServletRequest request = new MockHttpServletRequest( + ImmutableListMultimap.of(AUTHORIZATION, "Bearer " + jsonWebToken), + "remoteAddress", + ImmutableMap.of()); + Principal principal = new JsonWebTokenAuthenticator(jsonWebTokenConfig).authenticate(request); + + assertEquals(principal.getName(), TEST_PRINCIPAL); + assertEquals(authorizedIdentity(request).get(), authorizedIdentity); + } + + private static String createJsonWebToken(Path keyFile, String principal, AuthorizedIdentity authorizedIdentity) + throws IOException + { + byte[] key = getMimeDecoder().decode(readAllBytes(keyFile.toAbsolutePath())); + return Jwts.builder() + .signWith(HS256, key) + .setHeaderParam(KEY_ID, KEY_ID_FOO) + .setSubject(principal) + .claim(AUTHORIZED_IDENTITY_ATTRIBUTE, authorizedIdentity) + .compact(); + } +} diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkQueryExecutionFactory.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkQueryExecutionFactory.java index a89414633bd26..0cbbf080b21ed 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkQueryExecutionFactory.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkQueryExecutionFactory.java @@ -90,7 +90,6 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.spi.security.AccessControl; -import com.facebook.presto.spi.security.AuthorizedIdentity; import com.facebook.presto.spi.storage.TempStorage; import com.facebook.presto.sql.analyzer.BuiltInQueryPreparer; import com.facebook.presto.sql.analyzer.BuiltInQueryPreparer.BuiltInPreparedQuery; @@ -139,8 +138,6 @@ import static com.facebook.presto.execution.QueryState.FAILED; import static com.facebook.presto.execution.QueryState.PLANNING; import static com.facebook.presto.execution.StageInfo.getAllStages; -import static com.facebook.presto.security.AccessControlUtils.checkPermissions; -import static com.facebook.presto.security.AccessControlUtils.getAuthorizedIdentity; import static com.facebook.presto.server.protocol.QueryResourceUtil.toStatementStats; import static com.facebook.presto.spark.PrestoSparkSessionProperties.isAdaptiveQueryExecutionEnabled; import static com.facebook.presto.spark.SparkErrorCode.MALFORMED_QUERY_FILE; @@ -612,13 +609,7 @@ public IPrestoSparkQueryExecution create( credentialsProviders, authenticatorProviders); - // check permissions if needed - checkPermissions(accessControl, securityConfig, queryId, sessionContext); - - // get authorized identity if possible - Optional authorizedIdentity = getAuthorizedIdentity(accessControl, securityConfig, queryId, sessionContext); - - Session session = sessionSupplier.createSession(queryId, sessionContext, warningCollectorFactory, authorizedIdentity); + Session session = sessionSupplier.createSession(queryId, sessionContext, warningCollectorFactory); session = sessionPropertyDefaults.newSessionWithDefaultProperties(session, Optional.empty(), Optional.empty()); if (!executionStrategies.isEmpty()) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/security/AuthorizedIdentity.java b/presto-spi/src/main/java/com/facebook/presto/spi/security/AuthorizedIdentity.java index fa483419cf123..9b0e5fa5da9ac 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/security/AuthorizedIdentity.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/security/AuthorizedIdentity.java @@ -13,6 +13,10 @@ */ package com.facebook.presto.spi.security; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; import java.util.Optional; import static java.util.Objects.requireNonNull; @@ -23,13 +27,18 @@ public class AuthorizedIdentity private final Optional reasonForSelect; private final Optional delegationCheckResult; - public AuthorizedIdentity(String userName, String reasonForSelect, Boolean delegationCheckResult) + @JsonCreator + public AuthorizedIdentity( + @JsonProperty("userName") String userName, + @JsonProperty("reasonForSelect") String reasonForSelect, + @JsonProperty("delegationCheckResult") Boolean delegationCheckResult) { this.userName = requireNonNull(userName, "userName is null"); this.reasonForSelect = Optional.ofNullable(reasonForSelect); this.delegationCheckResult = Optional.ofNullable(delegationCheckResult); } + @JsonProperty("userName") public String getUserName() { return userName; @@ -40,8 +49,39 @@ public Optional getReasonForSelect() return reasonForSelect; } + @JsonProperty("reasonForSelect") + public String getReasonForSelectValue() + { + return reasonForSelect.orElse(null); + } + public Optional getDelegationCheckResult() { return delegationCheckResult; } + + @JsonProperty("delegationCheckResult") + public Boolean getDelegationCheckResultValue() + { + return delegationCheckResult.orElse(null); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + AuthorizedIdentity that = (AuthorizedIdentity) o; + return Objects.equals(userName, that.userName) && Objects.equals(reasonForSelect, that.reasonForSelect) && Objects.equals(delegationCheckResult, that.delegationCheckResult); + } + + @Override + public int hashCode() + { + return Objects.hash(userName, reasonForSelect, delegationCheckResult); + } }