diff --git a/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2Authenticator.java b/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2Authenticator.java index e96832660c54..60cd65e021ee 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2Authenticator.java +++ b/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2Authenticator.java @@ -14,6 +14,7 @@ package io.trino.server.security.oauth2; import com.google.common.collect.ImmutableSet; +import io.airlift.log.Logger; import io.trino.server.security.AbstractBearerAuthenticator; import io.trino.server.security.AuthenticationException; import io.trino.server.security.UserMapping; @@ -42,6 +43,7 @@ public class OAuth2Authenticator extends AbstractBearerAuthenticator { + private static final Logger log = Logger.get(OAuth2Authenticator.class); private final OAuth2Client client; private final String principalField; private final Optional groupsField; @@ -64,7 +66,12 @@ public OAuth2Authenticator(OAuth2Client client, OAuth2Config config, TokenRefres protected Optional createIdentity(String token) throws UserMappingException { - TokenPair tokenPair = tokenPairSerializer.deserialize(token); + Optional deserializeToken = deserializeToken(token); + if (deserializeToken.isEmpty()) { + return Optional.empty(); + } + + TokenPair tokenPair = deserializeToken.get(); if (tokenPair.getExpiration().before(Date.from(Instant.now()))) { return Optional.empty(); } @@ -80,11 +87,22 @@ protected Optional createIdentity(String token) return Optional.of(builder.build()); } + private Optional deserializeToken(String token) + { + try { + return Optional.of(tokenPairSerializer.deserialize(token)); + } + catch (RuntimeException ex) { + log.debug(ex, "Failed to deserialize token"); + return Optional.empty(); + } + } + @Override protected AuthenticationException needAuthentication(ContainerRequestContext request, Optional currentToken, String message) { return currentToken - .map(tokenPairSerializer::deserialize) + .flatMap(this::deserializeToken) .flatMap(tokenRefresher::refreshToken) .map(refreshId -> request.getUriInfo().getBaseUri().resolve(getTokenUri(refreshId))) .map(tokenUri -> new AuthenticationException(message, format("Bearer x_token_server=\"%s\"", tokenUri))) diff --git a/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java b/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java index a62dfb32875e..d97558d3e94d 100644 --- a/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java +++ b/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java @@ -816,6 +816,48 @@ public void testJwtAndOAuth2AuthenticatorsSeparation() } } + @Test + public void testJwtWithRefreshTokensForOAuth2Enabled() + throws Exception + { + TestingHttpServer jwkServer = createTestingJwkServer(); + jwkServer.start(); + try (TokenServer tokenServer = new TokenServer(Optional.empty()); + TestingTrinoServer server = TestingTrinoServer.builder() + .setProperties( + ImmutableMap.builder() + .putAll(SECURE_PROPERTIES) + .put("http-server.authentication.type", "oauth2,jwt") + .put("http-server.authentication.jwt.key-file", jwkServer.getBaseUrl().toString()) + .putAll(ImmutableMap.builder() + .putAll(getOAuth2Properties(tokenServer)) + .put("http-server.authentication.oauth2.refresh-tokens", "true") + .buildOrThrow()) + .put("web-ui.enabled", "true") + .buildOrThrow()) + .setAdditionalModule(oauth2Module(tokenServer)) + .build()) { + server.getInstance(Key.get(AccessControlManager.class)).addSystemAccessControl(TestSystemAccessControl.NO_IMPERSONATION); + HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); + + assertAuthenticationDisabled(httpServerInfo.getHttpUri()); + + String token = newJwtBuilder() + .signWith(JWK_PRIVATE_KEY) + .setHeaderParam(JwsHeader.KEY_ID, JWK_KEY_ID) + .setSubject("test-user") + .setExpiration(Date.from(ZonedDateTime.now().plusMinutes(5).toInstant())) + .compact(); + + OkHttpClient clientWithJwt = client.newBuilder() + .authenticator((route, response) -> response.request().newBuilder() + .header(AUTHORIZATION, "Bearer " + token) + .build()) + .build(); + assertAuthenticationAutomatic(httpServerInfo.getHttpsUri(), clientWithJwt); + } + } + private static Module oauth2Module(TokenServer tokenServer) { return binder -> {