Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.server.security.oauth2;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a typo in the commit message.

user couldn't log in by using standard jwt token.

Can you please explain why in the commit message too?


import com.google.common.collect.ImmutableSet;
import io.airlift.log.Logger;
import io.trino.server.security.AbstractBearerAuthenticator;
import io.trino.server.security.AuthenticationException;
import io.trino.server.security.UserMapping;
Expand Down Expand Up @@ -42,6 +43,7 @@
public class OAuth2Authenticator
extends AbstractBearerAuthenticator
{
private static final Logger log = Logger.get(OAuth2Authenticator.class);
private final OAuth2Client client;
private final String principalField;
private final Optional<String> groupsField;
Expand All @@ -64,7 +66,12 @@ public OAuth2Authenticator(OAuth2Client client, OAuth2Config config, TokenRefres
protected Optional<Identity> createIdentity(String token)
throws UserMappingException
{
TokenPair tokenPair = tokenPairSerializer.deserialize(token);
Optional<TokenPair> deserializeToken = deserializeToken(token);
if (deserializeToken.isEmpty()) {
return Optional.empty();
}

TokenPair tokenPair = deserializeToken.get();
if (tokenPair.getExpiration().before(Date.from(Instant.now()))) {
return Optional.empty();
}
Expand All @@ -80,11 +87,22 @@ protected Optional<Identity> createIdentity(String token)
return Optional.of(builder.build());
}

private Optional<TokenPair> deserializeToken(String token)
{
try {
return Optional.of(tokenPairSerializer.deserialize(token));
}
catch (RuntimeException ex) {
Copy link
Copy Markdown
Contributor

@andrewdibiasio6 andrewdibiasio6 Aug 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to just catch java.lang.IllegalArgumentException for now?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we narrow down the catching ? What if the token fails because it is damaged or expired ?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok so this catch is so wide on purpose. Right now there are plenty of reasons why token deserialization might fail. We might be having invalid token (on purpose - like in this bug scenario), token might have expired - meaning that both access token and refresh token are either expired or soon to be expired, encryption keys might have been rolled out (after restart) etc. All of the above suggests that we need to discard the token, but we should not request any specific handling of these errors. Rather I think we should just discard this token and issue another challenge to the client - so business as usual. I could narrow catching to IllegalArgumentException (with allowing the possibility that some other valid case might not be handled atm), but then I would like to add this to official contract of the TokenPairSerializer interface as a throws. Catching implementation specific exceptions is not a best design decision.

Comment thread
Praveen2112 marked this conversation as resolved.
Outdated
log.debug(ex, "Failed to deserialize token");
return Optional.empty();
}
}

@Override
protected AuthenticationException needAuthentication(ContainerRequestContext request, Optional<String> currentToken, String message)
{
return currentToken
.map(tokenPairSerializer::deserialize)
.flatMap(this::deserializeToken)
.flatMap(tokenRefresher::refreshToken)
.map(refreshId -> request.getUriInfo().getBaseUri().resolve(getTokenUri(refreshId)))
.map(tokenUri -> new AuthenticationException(message, format("Bearer x_token_server=\"%s\"", tokenUri)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,48 @@ public void testJwtAndOAuth2AuthenticatorsSeparation()
}
}

@Test
public void testJwtWithRefreshTokensForOAuth2Enabled()
throws Exception
{
TestingHttpServer jwkServer = createTestingJwkServer();
jwkServer.start();
try (TokenServer tokenServer = new TokenServer(Optional.empty());
TestingTrinoServer server = TestingTrinoServer.builder()
.setProperties(
ImmutableMap.<String, String>builder()
.putAll(SECURE_PROPERTIES)
.put("http-server.authentication.type", "oauth2,jwt")
.put("http-server.authentication.jwt.key-file", jwkServer.getBaseUrl().toString())
.putAll(ImmutableMap.<String, String>builder()
.putAll(getOAuth2Properties(tokenServer))
.put("http-server.authentication.oauth2.refresh-tokens", "true")
.buildOrThrow())
.put("web-ui.enabled", "true")
.buildOrThrow())
.setAdditionalModule(oauth2Module(tokenServer))
.build()) {
server.getInstance(Key.get(AccessControlManager.class)).addSystemAccessControl(TestSystemAccessControl.NO_IMPERSONATION);
HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class));

assertAuthenticationDisabled(httpServerInfo.getHttpUri());

String token = newJwtBuilder()
.signWith(JWK_PRIVATE_KEY)
.setHeaderParam(JwsHeader.KEY_ID, JWK_KEY_ID)
.setSubject("test-user")
.setExpiration(Date.from(ZonedDateTime.now().plusMinutes(5).toInstant()))
.compact();

OkHttpClient clientWithJwt = client.newBuilder()
.authenticator((route, response) -> response.request().newBuilder()
.header(AUTHORIZATION, "Bearer " + token)
.build())
.build();
assertAuthenticationAutomatic(httpServerInfo.getHttpsUri(), clientWithJwt);
}
}

private static Module oauth2Module(TokenServer tokenServer)
{
return binder -> {
Expand Down