diff --git a/core/trino-main/src/main/java/io/trino/server/security/jwt/InvalidClaimException.java b/core/trino-main/src/main/java/io/trino/server/security/jwt/InvalidClaimException.java new file mode 100644 index 000000000000..a8042489982c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/security/jwt/InvalidClaimException.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.security.jwt; + +import io.jsonwebtoken.JwtException; + +public class InvalidClaimException + extends JwtException +{ + public InvalidClaimException(String message) + { + super(message); + } + + public InvalidClaimException(String message, Throwable cause) + { + super(message, cause); + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java index 83e4f98f4cad..d1cda022b68a 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java +++ b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java @@ -13,6 +13,7 @@ */ package io.trino.server.security.jwt; +import io.jsonwebtoken.Claims; import io.jsonwebtoken.JwtParser; import io.jsonwebtoken.JwtParserBuilder; import io.jsonwebtoken.SigningKeyResolver; @@ -26,10 +27,13 @@ import javax.inject.Inject; import javax.ws.rs.container.ContainerRequestContext; +import java.util.Collection; import java.util.Optional; +import static io.jsonwebtoken.Claims.AUDIENCE; import static io.trino.server.security.UserMapping.createUserMapping; import static io.trino.server.security.jwt.JwtUtil.newJwtParserBuilder; +import static java.lang.String.format; public class JwtAuthenticator extends AbstractBearerAuthenticator @@ -37,11 +41,13 @@ public class JwtAuthenticator private final JwtParser jwtParser; private final String principalField; private final UserMapping userMapping; + private final Optional requiredAudience; @Inject public JwtAuthenticator(JwtAuthenticatorConfig config, @ForJwt SigningKeyResolver signingKeyResolver) { principalField = config.getPrincipalField(); + requiredAudience = Optional.ofNullable(config.getRequiredAudience()); JwtParserBuilder jwtParser = newJwtParserBuilder() .setSigningKeyResolver(signingKeyResolver); @@ -49,9 +55,6 @@ public JwtAuthenticator(JwtAuthenticatorConfig config, @ForJwt SigningKeyResolve if (config.getRequiredIssuer() != null) { jwtParser.requireIssuer(config.getRequiredIssuer()); } - if (config.getRequiredAudience() != null) { - jwtParser.requireAudience(config.getRequiredAudience()); - } this.jwtParser = jwtParser.build(); userMapping = createUserMapping(config.getUserMappingPattern(), config.getUserMappingFile()); } @@ -60,9 +63,10 @@ public JwtAuthenticator(JwtAuthenticatorConfig config, @ForJwt SigningKeyResolve protected Optional createIdentity(String token) throws UserMappingException { - Optional principal = Optional.ofNullable(jwtParser.parseClaimsJws(token) - .getBody() - .get(principalField, String.class)); + Claims claims = jwtParser.parseClaimsJws(token).getBody(); + validateAudience(claims); + + Optional principal = Optional.ofNullable(claims.get(principalField, String.class)); if (principal.isEmpty()) { return Optional.empty(); } @@ -71,6 +75,32 @@ protected Optional createIdentity(String token) .build()); } + private void validateAudience(Claims claims) + { + if (requiredAudience.isEmpty()) { + return; + } + + Object tokenAudience = claims.get(AUDIENCE); + if (tokenAudience == null) { + throw new InvalidClaimException(format("Expected %s claim to be: %s, but was not present in the JWT claims.", AUDIENCE, requiredAudience.get())); + } + + if (tokenAudience instanceof String) { + if (!requiredAudience.get().equals((String) tokenAudience)) { + throw new InvalidClaimException(format("Invalid Audience: %s. Allowed audiences: %s", tokenAudience, requiredAudience.get())); + } + } + else if (tokenAudience instanceof Collection) { + if (((Collection) tokenAudience).stream().map(String.class::cast).noneMatch(aud -> requiredAudience.get().equals(aud))) { + throw new InvalidClaimException(format("Invalid Audience: %s. Allowed audiences: %s", tokenAudience, requiredAudience.get())); + } + } + else { + throw new InvalidClaimException(format("Invalid Audience: %s", tokenAudience)); + } + } + @Override protected AuthenticationException needAuthentication(ContainerRequestContext request, Optional currentToken, String message) { diff --git a/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java b/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java index a62dfb32875e..bf3c3ffef554 100644 --- a/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java +++ b/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java @@ -97,6 +97,7 @@ import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; import static io.airlift.jaxrs.JaxrsBinder.jaxrsBinder; +import static io.jsonwebtoken.Claims.AUDIENCE; import static io.jsonwebtoken.security.Keys.hmacShaKeyFor; import static io.trino.client.OkHttpUtil.setupSsl; import static io.trino.client.ProtocolHeaders.TRINO_HEADERS; @@ -148,6 +149,9 @@ public class TestResourceSecurity private static final String HMAC_KEY = Resources.getResource("hmac_key.txt").getPath(); private static final String JWK_KEY_ID = "test-rsa"; private static final String GROUPS_CLAIM = "groups"; + private static final String TRINO_AUDIENCE = "trino-client"; + private static final String ADDITIONAL_AUDIENCE = "https://external-service.com"; + private static final String UNTRUSTED_CLIENT_AUDIENCE = "https://untrusted.com"; private static final PrivateKey JWK_PRIVATE_KEY; private static final PublicKey JWK_PUBLIC_KEY; private static final ObjectMapper json = new ObjectMapper(); @@ -460,11 +464,13 @@ public void testCertAuthenticator() public void testJwtAuthenticator() throws Exception { - verifyJwtAuthenticator(Optional.empty()); - verifyJwtAuthenticator(Optional.of("custom-principal")); + verifyJwtAuthenticator(Optional.empty(), Optional.empty()); + verifyJwtAuthenticator(Optional.of("custom-principal"), Optional.empty()); + verifyJwtAuthenticator(Optional.empty(), Optional.of(TRINO_AUDIENCE)); + verifyJwtAuthenticator(Optional.empty(), Optional.of(ImmutableList.of(TRINO_AUDIENCE, ADDITIONAL_AUDIENCE))); } - private void verifyJwtAuthenticator(Optional principalField) + private void verifyJwtAuthenticator(Optional principalField, Optional audience) throws Exception { try (TestingTrinoServer server = TestingTrinoServer.builder() @@ -473,6 +479,7 @@ private void verifyJwtAuthenticator(Optional principalField) .put("http-server.authentication.type", "jwt") .put("http-server.authentication.jwt.key-file", HMAC_KEY) .put("http-server.authentication.jwt.principal-field", principalField.orElse("sub")) + .put("http-server.authentication.jwt.required-audience", TRINO_AUDIENCE) .buildOrThrow()) .build()) { server.getInstance(Key.get(AccessControlManager.class)).addSystemAccessControl(TestSystemAccessControl.NO_IMPERSONATION); @@ -485,10 +492,17 @@ private void verifyJwtAuthenticator(Optional principalField) .signWith(hmac) .setExpiration(Date.from(ZonedDateTime.now().plusMinutes(5).toInstant())); if (principalField.isPresent()) { - tokenBuilder.claim(principalField.get(), "test-user"); + tokenBuilder.claim(principalField.get(), TEST_USER); } else { - tokenBuilder.setSubject("test-user"); + tokenBuilder.setSubject(TEST_USER); + } + + if (audience.isPresent()) { + tokenBuilder.claim(AUDIENCE, audience.get()); + } + else { + tokenBuilder.setAudience(TRINO_AUDIENCE); } String token = tokenBuilder.compact(); @@ -538,6 +552,78 @@ public void testJwtWithJwkAuthenticator() } } + @Test + public void testJwtAuthenticatorWithInvalidAudience() + throws Exception + { + try (TestingTrinoServer server = TestingTrinoServer.builder() + .setProperties(ImmutableMap.builder() + .putAll(SECURE_PROPERTIES) + .put("http-server.authentication.type", "jwt") + .put("http-server.authentication.jwt.key-file", HMAC_KEY) + .put("http-server.authentication.jwt.required-audience", TRINO_AUDIENCE) + .buildOrThrow()) + .build()) { + server.getInstance(Key.get(AccessControlManager.class)).addSystemAccessControl(TestSystemAccessControl.NO_IMPERSONATION); + HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); + + SecretKey hmac = hmacShaKeyFor(Base64.getDecoder().decode(Files.readString(Paths.get(HMAC_KEY)).trim())); + JwtBuilder tokenBuilder = newJwtBuilder() + .signWith(hmac) + .setExpiration(Date.from(ZonedDateTime.now().plusMinutes(5).toInstant())) + .claim(AUDIENCE, ImmutableList.of(ADDITIONAL_AUDIENCE, UNTRUSTED_CLIENT_AUDIENCE)); + String token = tokenBuilder.compact(); + + OkHttpClient clientWithJwt = this.client.newBuilder() + .build(); + assertResponseCode(clientWithJwt, getAuthorizedUserLocation(httpServerInfo.getHttpsUri()), SC_UNAUTHORIZED, Headers.of(AUTHORIZATION, "Bearer " + token)); + } + } + + @Test + public void testJwtAuthenticatorWithNoRequiredAudience() + throws Exception + { + verifyJwtAuthenticatorWithoutRequiredAudience(Optional.empty()); + verifyJwtAuthenticatorWithoutRequiredAudience(Optional.of(ImmutableList.of(TRINO_AUDIENCE, ADDITIONAL_AUDIENCE))); + } + + private void verifyJwtAuthenticatorWithoutRequiredAudience(Optional audience) + throws Exception + { + try (TestingTrinoServer server = TestingTrinoServer.builder() + .setProperties(ImmutableMap.builder() + .putAll(SECURE_PROPERTIES) + .put("http-server.authentication.type", "jwt") + .put("http-server.authentication.jwt.key-file", HMAC_KEY) + .buildOrThrow()) + .build()) { + server.getInstance(Key.get(AccessControlManager.class)).addSystemAccessControl(TestSystemAccessControl.NO_IMPERSONATION); + HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); + + assertAuthenticationDisabled(httpServerInfo.getHttpUri()); + + SecretKey hmac = hmacShaKeyFor(Base64.getDecoder().decode(Files.readString(Paths.get(HMAC_KEY)).trim())); + JwtBuilder tokenBuilder = newJwtBuilder() + .signWith(hmac) + .setExpiration(Date.from(ZonedDateTime.now().plusMinutes(5).toInstant())) + .setSubject(TEST_USER); + + if (audience.isPresent()) { + tokenBuilder.claim(AUDIENCE, audience.get()); + } + + String token = tokenBuilder.compact(); + + OkHttpClient clientWithJwt = client.newBuilder() + .authenticator((route, response) -> response.request().newBuilder() + .header(AUTHORIZATION, "Bearer " + token) + .build()) + .build(); + assertAuthenticationAutomatic(httpServerInfo.getHttpsUri(), clientWithJwt); + } + } + @Test public void testOAuth2Authenticator() throws Exception