diff --git a/presto-docs/src/main/sphinx/security/oauth2.rst b/presto-docs/src/main/sphinx/security/oauth2.rst index 65b676c42d394..838c499e883f9 100644 --- a/presto-docs/src/main/sphinx/security/oauth2.rst +++ b/presto-docs/src/main/sphinx/security/oauth2.rst @@ -38,6 +38,8 @@ Below are the key configuration properties for enabling OAuth2 authentication in http-server.authentication.oauth2.state-key=your-hmac-secret http-server.authentication.oauth2.additional-audiences=your-client-id,another-audience http-server.authentication.oauth2.user-mapping.pattern=(.*) + http-server.authentication.oauth2.userinfo-cache=false + http-server.authentication.oauth2.userinfo-cache-ttl=10m It is worth noting that ``configuration-based-authorizer.role-regex-map.file-path`` must be configured if authentication type is set to ``OAUTH2``. @@ -58,11 +60,13 @@ If your IdP uses a custom or self-signed certificate, import it into the Java tr Notes ----- -- **Issuer**: The base URL of your IdP’s OIDC discovery endpoint. +- **Issuer**: The base URL of your IdP's OIDC discovery endpoint. - **Client ID/Secret**: Registered credentials for Presto in your IdP. - **Scopes**: Must include ``openid``; others like ``email``, ``profile``, or ``groups`` are optional. -- **Principal Field**: The claim in the ID token used as the Presto username. +- **Principal Field**: The claim used as the Presto username. For OIDC flows (when ``openid`` scope is included), this is extracted from the ID token. If the claim is not present in the ID token, Presto will query the UserInfo endpoint as a fallback. For pure OAuth2 flows (without ``openid`` scope), the UserInfo endpoint is queried first, with the access token as a last resort. - **Groups Field**: Optional claim used for role-based access control. - **State Key**: A secret used to sign the OAuth2 state parameter (HMAC). - **Refresh Tokens**: Enable if your IdP supports issuing refresh tokens. +- **UserInfo Cache**: Enable caching of UserInfo endpoint responses to reduce load on the IdP and improve performance. When enabled, responses are cached using a SHA-256 hash of the access token as the key. Default is ``false``. +- **UserInfo Cache TTL**: Time-to-live for cached UserInfo entries. Only applicable when ``userinfo-cache`` is enabled. Default is ``10m`` (10 minutes). Minimum value is ``1m``. - **Callback**: When configuring your IdP the callback URI must be set to ``[presto]/oauth2/callback`` diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JweTokenSerializer.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JweTokenSerializer.java index 45ded8e0ed84c..08f2526592da7 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JweTokenSerializer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JweTokenSerializer.java @@ -138,17 +138,26 @@ public String serialize(TokenPair tokenPair) { requireNonNull(tokenPair, "tokenPair is null"); - Optional> accessTokenClaims = client.getClaims(tokenPair.getAccessToken()); - if (!accessTokenClaims.isPresent()) { - throw new IllegalArgumentException("Claims are missing"); + // Try to get claims from the TokenPair first (from ID token or UserInfo) + // This is the correct source per OIDC specification + Optional> claims = tokenPair.getClaims(); + + // Fallback to access token claims for backward compatibility + if (!claims.isPresent()) { + claims = client.getClaims(tokenPair.getAccessToken()); + if (!claims.isPresent()) { + throw new IllegalArgumentException("Claims are missing from both ID token/UserInfo and access token"); + } } - Map claims = accessTokenClaims.get(); - if (!claims.containsKey(principalField)) { - throw new IllegalArgumentException(format("%s field is missing", principalField)); + + Map claimsMap = claims.get(); + if (!claimsMap.containsKey(principalField)) { + throw new IllegalArgumentException(format("%s field is missing from claims", principalField)); } + JwtBuilder jwt = newJwtBuilder() .setExpiration(Date.from(clock.instant().plusMillis(tokenExpiration.toMillis()))) - .claim(principalField, claims.get(principalField).toString()) + .claim(principalField, claimsMap.get(principalField).toString()) .setAudience(audience) .setIssuer(issuer) .claim(ACCESS_TOKEN_KEY, tokenPair.getAccessToken()) diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/NimbusOAuth2Client.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/NimbusOAuth2Client.java index 86b2fb96b7d8f..d2e73ac7c65cd 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/NimbusOAuth2Client.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/NimbusOAuth2Client.java @@ -16,6 +16,8 @@ import com.facebook.airlift.log.Logger; import com.facebook.airlift.units.Duration; import com.facebook.presto.server.security.oauth2.OAuth2ServerConfigProvider.OAuth2ServerConfig; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Ordering; import com.nimbusds.jose.JOSEException; @@ -93,7 +95,6 @@ public class NimbusOAuth2Client implements OAuth2Client { private static final Logger LOG = Logger.get(NimbusAirliftHttpClient.class); - private final Issuer issuer; private final ClientID clientId; private final ClientSecretBasic clientAuth; @@ -111,10 +112,16 @@ public class NimbusOAuth2Client private JWTProcessor accessTokenProcessor; private AuthorizationCodeFlow flow; + // Cache for UserInfo endpoint responses (initialized conditionally based on config) + // Key: SHA-256 hash of access token, Value: UserInfo claims + private final Cache userInfoCache; + private final boolean userinfoCacheEnabled; + @Inject public NimbusOAuth2Client(OAuth2Config oauthConfig, OAuth2ServerConfigProvider serverConfigurationProvider, NimbusHttpClient httpClient) { requireNonNull(oauthConfig, "oauthConfig is null"); + issuer = new Issuer(oauthConfig.getIssuer()); clientId = new ClientID(oauthConfig.getClientId()); clientAuth = new ClientSecretBasic(clientId, new Secret(oauthConfig.getClientSecret())); @@ -128,6 +135,21 @@ public NimbusOAuth2Client(OAuth2Config oauthConfig, OAuth2ServerConfigProvider s this.serverConfigurationProvider = requireNonNull(serverConfigurationProvider, "serverConfigurationProvider is null"); this.httpClient = requireNonNull(httpClient, "httpClient is null"); + + // Initialize UserInfo cache based on configuration + this.userinfoCacheEnabled = oauthConfig.isUserinfoCacheEnabled(); + if (this.userinfoCacheEnabled) { + Duration cacheTtl = oauthConfig.getUserinfoCacheTtl(); + this.userInfoCache = CacheBuilder.newBuilder() + .maximumSize(1000) // Max 1k entries + .expireAfterWrite(cacheTtl.toMillis(), java.util.concurrent.TimeUnit.MILLISECONDS) + .build(); + LOG.info("UserInfo cache enabled with TTL of %s", cacheTtl); + } + else { + this.userInfoCache = null; + LOG.info("UserInfo cache disabled"); + } } @Override @@ -148,12 +170,14 @@ public void load() DefaultJWTProcessor processor = new DefaultJWTProcessor<>(); processor.setJWSKeySelector(jwsKeySelector); + // user specified principal claim may not be a required claim in the access token + // manually validate it with the response of the userinfo endpoint DefaultJWTClaimsVerifier accessTokenVerifier = new DefaultJWTClaimsVerifier<>( accessTokenAudiences, new JWTClaimsSet.Builder() .issuer(config.getAccessTokenIssuer().orElse(issuer.getValue())) .build(), - ImmutableSet.of(principalField), + ImmutableSet.of(), ImmutableSet.of()); accessTokenVerifier.setMaxClockSkew((int) maxClockSkew.roundTo(SECONDS)); processor.setJWTClaimsSetVerifier(accessTokenVerifier); @@ -244,11 +268,44 @@ private Response toResponse(Tokens tokens, Optional existingRefreshToken { AccessToken accessToken = tokens.getAccessToken(); RefreshToken refreshToken = tokens.getRefreshToken(); - JWTClaimsSet claims = getJWTClaimsSet(accessToken.getValue()).orElseThrow(() -> new ChallengeFailedException("invalid access token")); + + // For pure OAuth2 (no ID token), try UserInfo first, then access token + JWTClaimsSet claims = getUserClaims(accessToken.getValue()) + .orElseThrow(() -> new ChallengeFailedException("Cannot retrieve user claims")); + return new Response( accessToken.getValue(), determineExpiration(getExpiration(accessToken), claims.getExpirationTime()), - buildRefreshToken(refreshToken, existingRefreshToken)); + buildRefreshToken(refreshToken, existingRefreshToken), + claims.getClaims()); + } + + /** + * Retrieves user claims for pure OAuth2 flow (no ID token). + * Tries UserInfo endpoint first, then falls back to access token parsing. + * + * @param accessToken the access token + * @return Optional containing claims + */ + private Optional getUserClaims(String accessToken) + { + // Try UserInfo endpoint first (preferred for OAuth2) + if (userinfoUrl.isPresent()) { + Optional userInfoClaims = queryUserInfo(accessToken); + if (userInfoClaims.isPresent() && userInfoClaims.get().getClaim(principalField) != null) { + return userInfoClaims; + } + } + + // Fallback: try parsing access token (only if it's a JWT) + Optional accessTokenClaims = parseAccessToken(accessToken); + if (accessTokenClaims.isPresent() && accessTokenClaims.get().getClaim(principalField) != null) { + LOG.warn("Using access token for principal extraction - this is not recommended. Consider using OIDC with ID tokens."); + return accessTokenClaims; + } + + LOG.error("Cannot find principal field '%s' in UserInfo or access token", principalField); + return Optional.empty(); } } @@ -306,11 +363,42 @@ private Response toResponse(OIDCTokens tokens, Optional existingRefreshT { AccessToken accessToken = tokens.getAccessToken(); RefreshToken refreshToken = tokens.getRefreshToken(); - JWTClaimsSet claims = getJWTClaimsSet(accessToken.getValue()).orElseThrow(() -> new ChallengeFailedException("invalid access token")); + + // FIXED: Get claims from ID token (primary source per OIDC spec) + JWTClaimsSet claims = getJWTClaimsSetFromIdToken(tokens.getIDToken()) + .orElseThrow(() -> new ChallengeFailedException("invalid ID token")); + + // Fallback to UserInfo if principal field not in ID token + if (claims.getClaim(principalField) == null) { + LOG.debug("Principal field '%s' not found in ID token, querying UserInfo endpoint", principalField); + claims = queryUserInfo(accessToken.getValue()) + .orElseThrow(() -> new ChallengeFailedException( + String.format("Principal field '%s' not found in ID token or UserInfo", principalField))); + } + return new Response( accessToken.getValue(), determineExpiration(getExpiration(accessToken), claims.getExpirationTime()), - buildRefreshToken(refreshToken, existingRefreshToken)); + buildRefreshToken(refreshToken, existingRefreshToken), + claims.getClaims()); + } + + /** + * Extracts JWT claims from a validated ID token. + * This is the correct source for user identity per OIDC specification. + * + * @param idToken the validated ID token + * @return Optional containing claims if successful + */ + private Optional getJWTClaimsSetFromIdToken(com.nimbusds.jwt.JWT idToken) + { + try { + return Optional.of(idToken.getJWTClaimsSet()); + } + catch (java.text.ParseException e) { + LOG.error(e, "Failed to parse ID token claims"); + return Optional.empty(); + } } private void validateTokens(OIDCTokens tokens, Optional nonce) @@ -368,39 +456,142 @@ private T getTokenResponse(AuthorizationGrant au return tokenResponse; } + /** + * Retrieves JWT claims for the given access token. + * + * IMPORTANT: This method should NOT be used for extracting the principal field. + * Per OIDC specification, the principal should come from the ID token, not the access token. + * This method is kept for backward compatibility and the getClaims() API method. + * + * @param accessToken the access token value + * @return Optional containing claims from access token or UserInfo endpoint + */ private Optional getJWTClaimsSet(String accessToken) { + // Try parsing access token as JWT + Optional claims = parseAccessToken(accessToken); + if (claims.isPresent()) { + return claims; + } + + // Fallback to UserInfo endpoint if (userinfoUrl.isPresent()) { return queryUserInfo(accessToken); } - return parseAccessToken(accessToken); + + return Optional.empty(); } + /** + * Queries the UserInfo endpoint to retrieve user claims. + * This should be used as a fallback when the principal field is not in the ID token, + * or for pure OAuth2 flows without ID tokens. + * + * Results are cached if caching is enabled to reduce redundant API calls. + * + * @param accessToken the OAuth2 access token for authentication + * @return Optional containing JWTClaimsSet if successful, empty otherwise + */ private Optional queryUserInfo(String accessToken) { + // Validate input + if (accessToken == null || accessToken.trim().isEmpty()) { + LOG.error("Invalid access token provided to queryUserInfo"); + return Optional.empty(); + } + + if (!userinfoUrl.isPresent()) { + LOG.debug("UserInfo URL not configured, cannot query user information"); + return Optional.empty(); + } + + String cacheKey = null; + + // Check cache if enabled + if (userinfoCacheEnabled) { + cacheKey = computeCacheKey(accessToken); + + JWTClaimsSet cachedClaims = userInfoCache.getIfPresent(cacheKey); + if (cachedClaims != null) { + LOG.debug("UserInfo cache hit for token hash: %s", cacheKey.substring(0, 8)); + return Optional.of(cachedClaims); + } + + LOG.debug("UserInfo cache miss for token hash: %s", cacheKey.substring(0, 8)); + } + + // Query UserInfo endpoint try { - UserInfoResponse response = httpClient.execute(new UserInfoRequest(userinfoUrl.get(), new BearerAccessToken(accessToken)), this::parse); - if (!response.indicatesSuccess()) { - LOG.error("Received bad response from userinfo endpoint: " + response.toErrorResponse().getErrorObject()); - return Optional.empty(); + JWTClaimsSet claims = fetchUserInfoClaims(accessToken); + + // Store in cache if enabled + if (userinfoCacheEnabled && cacheKey != null) { + userInfoCache.put(cacheKey, claims); + LOG.debug("UserInfo cached for token hash: %s", cacheKey.substring(0, 8)); } - return Optional.of(response.toSuccessResponse().getUserInfo().toJWTClaimsSet()); + + return Optional.of(claims); } - catch (ParseException | RuntimeException e) { - LOG.error(e, "Received bad response from userinfo endpoint"); + catch (ParseException e) { + LOG.error(e, "Failed to parse UserInfo response from %s", userinfoUrl.get()); + return Optional.empty(); + } + catch (RuntimeException e) { + LOG.error(e, "Failed to query UserInfo endpoint %s", userinfoUrl.get()); return Optional.empty(); } } + /** + * Computes a cache key from the access token using SHA-256 hashing. + * This prevents storing sensitive tokens directly in the cache. + * + * @param accessToken the access token to hash + * @return SHA-256 hash of the token as a hex string + */ + private String computeCacheKey(String accessToken) + { + return sha256() + .hashString(accessToken, UTF_8) + .toString(); + } + + /** + * Fetches user information claims from the UserInfo endpoint. + * + * @param accessToken the OAuth2 access token for authentication + * @return JWTClaimsSet containing user information + * @throws ParseException if the response cannot be parsed + * @throws RuntimeException if the HTTP request fails + */ + private JWTClaimsSet fetchUserInfoClaims(String accessToken) throws ParseException + { + UserInfoResponse response = httpClient.execute( + new UserInfoRequest(userinfoUrl.get(), new BearerAccessToken(accessToken)), + this::parse); + + if (!response.indicatesSuccess()) { + UserInfoErrorResponse errorResponse = response.toErrorResponse(); + LOG.error("Received error from UserInfo endpoint: %s", errorResponse.getErrorObject()); + throw new RuntimeException("UserInfo endpoint returned error: " + errorResponse.getErrorObject()); + } + + return response.toSuccessResponse().getUserInfo().toJWTClaimsSet(); + } + // Using this parsing method for our /userinfo response from the IdP in order to allow for different principal // fields as defined, and in the absence of the `sub` claim. This is a "hack" solution to alter the claims // present in the response before calling the parser provided by the oidc sdk, which fails hard if the - // `sub` claim is missing. Note we also have to offload audience verification to this method since it - // is not handled in the library + // `sub` claim is missing. public UserInfoResponse parse(HTTPResponse httpResponse) throws ParseException { - JSONObject body = httpResponse.getContentAsJSONObject(); + // Check status code first and only process payload if successful + if (httpResponse.getStatusCode() != 200) { + return UserInfoErrorResponse.parse(httpResponse); + } + + JSONObject body = httpResponse.getBodyAsJSONObject(); String principal = (String) body.get(principalField); if (principal == null) { @@ -413,28 +604,29 @@ public UserInfoResponse parse(HTTPResponse httpResponse) } Object audClaim = body.get("aud"); - List audiences; + // only validate aud claim if it exists + if (audClaim != null) { + List audiences; - if (audClaim instanceof String) { - audiences = List.of((String) audClaim); - } - else if (audClaim instanceof List) { - audiences = ((List) audClaim).stream() - .filter(String.class::isInstance) - .map(String.class::cast) - .collect(toImmutableList()); - } - else { - throw new ParseException("Unsupported or missing 'aud' claim type in /userinfo response"); - } + if (audClaim instanceof String) { + audiences = List.of((String) audClaim); + } + else if (audClaim instanceof List) { + audiences = ((List) audClaim).stream() + .filter(String.class::isInstance) + .map(String.class::cast) + .collect(toImmutableList()); + } + else { + throw new ParseException("Unsupported 'aud' claim type in /userinfo response"); + } - if (!(audiences.contains(clientId.getValue()) || !Collections.disjoint(audiences, accessTokenAudiences))) { - throw new ParseException("Invalid audience in /userinfo response"); + if (!audiences.contains(clientId.getValue()) && Collections.disjoint(audiences, accessTokenAudiences)) { + throw new ParseException("Invalid audience in /userinfo response"); + } } - return (httpResponse.getStatusCode() == 200) - ? UserInfoSuccessResponse.parse(httpResponse) - : UserInfoErrorResponse.parse(httpResponse); + return UserInfoSuccessResponse.parse(httpResponse); } private Optional parseAccessToken(String accessToken) @@ -443,7 +635,7 @@ private Optional parseAccessToken(String accessToken) return Optional.of(accessTokenProcessor.process(accessToken, null)); } catch (java.text.ParseException | BadJOSEException | JOSEException e) { - LOG.error(e, "Failed to parse JWT access token"); + LOG.debug(e, "Failed to parse JWT access token"); return Optional.empty(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Authenticator.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Authenticator.java index 83c2de302c287..b752b16e6089f 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Authenticator.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Authenticator.java @@ -75,10 +75,17 @@ public Principal authenticate(HttpServletRequest request) throws AuthenticationE if (tokenPair.getExpiration().before(Date.from(Instant.now()))) { throw needAuthentication(request, Optional.of(token), "Invalid Credentials"); } - Optional> claims = client.getClaims(tokenPair.getAccessToken()); + // Try to get claims from TokenPair first (from ID token or UserInfo) + // This is the correct source per OIDC specification + Optional> claims = tokenPair.getClaims(); + + // Fallback to access token claims for backward compatibility if (!claims.isPresent()) { - throw needAuthentication(request, Optional.ofNullable(token), "Invalid Credentials"); + claims = client.getClaims(tokenPair.getAccessToken()); + if (!claims.isPresent()) { + throw needAuthentication(request, Optional.ofNullable(token), "Invalid Credentials"); + } } String principal = (String) claims.get().get(principalField); if (StringUtils.isEmpty(principal)) { diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Client.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Client.java index fd68a2a5a2e06..7e9324381b8ad 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Client.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Client.java @@ -60,14 +60,20 @@ class Response { private final String accessToken; private final Instant expiration; - private final Optional refreshToken; + private final Map claims; public Response(String accessToken, Instant expiration, Optional refreshToken) + { + this(accessToken, expiration, refreshToken, null); + } + + public Response(String accessToken, Instant expiration, Optional refreshToken, Map claims) { this.accessToken = requireNonNull(accessToken, "accessToken is null"); this.expiration = requireNonNull(expiration, "expiration is null"); this.refreshToken = requireNonNull(refreshToken, "refreshToken is null"); + this.claims = claims; } public String getAccessToken() @@ -84,5 +90,16 @@ public Optional getRefreshToken() { return refreshToken; } + + /** + * Returns the user claims from ID token (for OIDC) or UserInfo endpoint (for OAuth2). + * These claims should be used for extracting the principal field, not the access token. + * + * @return Optional containing claims map, or empty if not available + */ + public Optional> getClaims() + { + return Optional.ofNullable(claims); + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Config.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Config.java index b1b0f9513b2f7..fe72cbdc0d847 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Config.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Config.java @@ -49,6 +49,8 @@ public class OAuth2Config private Optional userMappingFile = Optional.empty(); private boolean enableRefreshTokens; private boolean enableDiscovery = true; + private boolean userinfoCacheEnabled; + private Duration userinfoCacheTtl = new Duration(10, TimeUnit.MINUTES); public Optional getStateKey() { @@ -250,4 +252,32 @@ public OAuth2Config setEnableDiscovery(boolean enableDiscovery) this.enableDiscovery = enableDiscovery; return this; } + + public boolean isUserinfoCacheEnabled() + { + return userinfoCacheEnabled; + } + + @Config("http-server.authentication.oauth2.userinfo-cache") + @ConfigDescription("Enable caching of userinfo endpoint responses") + public OAuth2Config setUserinfoCacheEnabled(boolean userinfoCacheEnabled) + { + this.userinfoCacheEnabled = userinfoCacheEnabled; + return this; + } + + @MinDuration("1m") + @NotNull + public Duration getUserinfoCacheTtl() + { + return userinfoCacheTtl; + } + + @Config("http-server.authentication.oauth2.userinfo-cache-ttl") + @ConfigDescription("TTL for userinfo cache entries") + public OAuth2Config setUserinfoCacheTtl(Duration userinfoCacheTtl) + { + this.userinfoCacheTtl = userinfoCacheTtl; + return this; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2ServiceModule.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2ServiceModule.java index 31b8740edbaf4..1fd47b9433b11 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2ServiceModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2ServiceModule.java @@ -27,7 +27,7 @@ import static com.facebook.airlift.configuration.ConfigBinder.configBinder; import static com.facebook.airlift.http.client.HttpClientBinder.httpClientBinder; import static com.facebook.airlift.jaxrs.JaxrsBinder.jaxrsBinder; -import static com.facebook.presto.server.security.oauth2.TokenPairSerializer.ACCESS_TOKEN_ONLY_SERIALIZER; +import static com.facebook.presto.server.security.oauth2.TokenPairSerializer.ACCESS_TOKEN_CLAIMS_ONLY_SERIALIZER; import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; public class OAuth2ServiceModule @@ -67,7 +67,7 @@ private void enableRefreshTokens(Binder binder) private void disableRefreshTokens(Binder binder) { - binder.bind(TokenPairSerializer.class).toInstance(ACCESS_TOKEN_ONLY_SERIALIZER); + binder.bind(TokenPairSerializer.class).toInstance(ACCESS_TOKEN_CLAIMS_ONLY_SERIALIZER); newOptionalBinder(binder, Key.get(Duration.class, ForRefreshTokens.class)); } diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/TokenPairSerializer.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/TokenPairSerializer.java index a92e7fdf00e06..62c41f2e85d54 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/TokenPairSerializer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/TokenPairSerializer.java @@ -14,28 +14,82 @@ package com.facebook.presto.server.security.oauth2; +import com.facebook.airlift.log.Logger; import com.facebook.presto.server.security.oauth2.OAuth2Client.Response; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import jakarta.annotation.Nullable; +import java.io.IOException; +import java.util.Base64; import java.util.Date; +import java.util.HashMap; +import java.util.Map; import java.util.Optional; import static java.lang.Long.MAX_VALUE; +import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Objects.requireNonNull; public interface TokenPairSerializer { - TokenPairSerializer ACCESS_TOKEN_ONLY_SERIALIZER = new TokenPairSerializer() + /** + * Serializer that stores access token and claims in a simple JSON format. + * Used when refresh tokens are disabled but we still need to preserve claims for authentication. + */ + TokenPairSerializer ACCESS_TOKEN_CLAIMS_ONLY_SERIALIZER = new TokenPairSerializer() { + private static final Logger LOG = Logger.get(TokenPairSerializer.class); + private final ObjectMapper objectMapper = new ObjectMapper(); + @Override public TokenPair deserialize(String token) { + // Try to decode from Base64 and parse as JSON (new format with claims) + try { + // First try to decode from Base64 + byte[] decodedBytes = Base64.getDecoder().decode(token); + String decodedJson = new String(decodedBytes, UTF_8); + + @SuppressWarnings("unchecked") + Map data = objectMapper.readValue(decodedJson, Map.class); + if (data.containsKey("accessToken") && data.containsKey("claims")) { + String accessToken = (String) data.get("accessToken"); + @SuppressWarnings("unchecked") + Map claims = (Map) data.get("claims"); + LOG.debug("Deserialized token with claims from new Base64-encoded JSON format"); + return new TokenPair(accessToken, new Date(MAX_VALUE), Optional.empty(), Optional.of(claims)); + } + } + catch (IllegalArgumentException | IOException e) { + // Not Base64-encoded JSON, treat as plain access token (backward compatibility) + LOG.debug("Token is not in new Base64-encoded JSON format, treating as plain access token (old format or migration in progress)"); + } + + // Fallback: treat as plain access token (backward compatibility) + LOG.debug("Using plain access token format (no claims available, will need to query on authentication)"); return TokenPair.accessToken(token); } @Override public String serialize(TokenPair tokenPair) { + // If claims are present, serialize as Base64-encoded JSON with access token and claims + if (tokenPair.getClaims().isPresent()) { + try { + Map data = new HashMap<>(); + data.put("accessToken", tokenPair.getAccessToken()); + data.put("claims", tokenPair.getClaims().get()); + String json = objectMapper.writeValueAsString(data); + // Base64 encode to ensure cookie compatibility + return Base64.getEncoder().encodeToString(json.getBytes(UTF_8)); + } + catch (JsonProcessingException e) { + throw new IllegalStateException("Failed to serialize token with claims", e); + } + } + + // Fallback: serialize as plain access token (backward compatibility) return tokenPair.getAccessToken(); } }; @@ -49,12 +103,19 @@ class TokenPair private final String accessToken; private final Date expiration; private final Optional refreshToken; + private final Optional> claims; private TokenPair(String accessToken, Date expiration, Optional refreshToken) + { + this(accessToken, expiration, refreshToken, Optional.empty()); + } + + private TokenPair(String accessToken, Date expiration, Optional refreshToken, Optional> claims) { this.accessToken = requireNonNull(accessToken, "accessToken is nul"); this.expiration = requireNonNull(expiration, "expiration is null"); this.refreshToken = requireNonNull(refreshToken, "refreshToken is null"); + this.claims = requireNonNull(claims, "claims is null"); } public static TokenPair accessToken(String accessToken) @@ -65,7 +126,11 @@ public static TokenPair accessToken(String accessToken) public static TokenPair fromOAuth2Response(Response tokens) { requireNonNull(tokens, "tokens is null"); - return new TokenPair(tokens.getAccessToken(), Date.from(tokens.getExpiration()), tokens.getRefreshToken()); + return new TokenPair( + tokens.getAccessToken(), + Date.from(tokens.getExpiration()), + tokens.getRefreshToken(), + tokens.getClaims()); } public static TokenPair accessAndRefreshTokens(String accessToken, Date expiration, @Nullable String refreshToken) @@ -88,6 +153,17 @@ public Optional getRefreshToken() return refreshToken; } + /** + * Returns the user claims from ID token (for OIDC) or UserInfo endpoint (for OAuth2). + * These claims should be used for extracting the principal field, not the access token. + * + * @return Optional containing claims map, or empty if not available + */ + public Optional> getClaims() + { + return claims; + } + public static TokenPair withAccessAndRefreshTokens(String accessToken, Date expiration, @Nullable String refreshToken) { return new TokenPair(accessToken, expiration, Optional.ofNullable(refreshToken)); diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/BaseOAuth2AuthenticationFilterTest.java b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/BaseOAuth2AuthenticationFilterTest.java index 4540c655cb655..ba7b8507db737 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/BaseOAuth2AuthenticationFilterTest.java +++ b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/BaseOAuth2AuthenticationFilterTest.java @@ -316,7 +316,9 @@ private void assertOAuth2Cookie(HttpCookie cookie) protected void validateAccessToken(String cookieValue) { - Request request = new Request.Builder().url("https://localhost:" + hydraIdP.getAuthPort() + "/userinfo").addHeader(AUTHORIZATION, "Bearer " + cookieValue).build(); + // Extract access token from cookie value in base64-encoded JSON format + String accessToken = extractAccessToken(cookieValue); + Request request = new Request.Builder().url("https://localhost:" + hydraIdP.getAuthPort() + "/userinfo").addHeader(AUTHORIZATION, "Bearer " + accessToken).build(); try (Response response = httpClient.newCall(request).execute()) { assertThat(response.body()).isNotNull(); DefaultClaims claims = new DefaultClaims(JsonCodec.mapJsonCodec(String.class, Object.class).fromJson(response.body().bytes())); @@ -327,6 +329,19 @@ protected void validateAccessToken(String cookieValue) } } + private String extractAccessToken(String cookieValue) + { + // Decode Base64-encoded JSON to extract access token + byte[] decodedBytes = java.util.Base64.getDecoder().decode(cookieValue); + String decodedJson = new String(decodedBytes, java.nio.charset.StandardCharsets.UTF_8); + java.util.Map data = JsonCodec.mapJsonCodec(String.class, Object.class).fromJson(decodedJson); + String accessToken = (String) data.get("accessToken"); + if (accessToken == null) { + throw new IllegalStateException("Cookie value does not contain 'accessToken' field: " + decodedJson); + } + return accessToken; + } + private void assertUICallWithCookie(String cookieValue) throws IOException { diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestNimbusOAuth2ClientUserInfoParser.java b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestNimbusOAuth2ClientUserInfoParser.java new file mode 100644 index 0000000000000..fa86ec14886dc --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestNimbusOAuth2ClientUserInfoParser.java @@ -0,0 +1,377 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.http.client.HttpClient; +import com.facebook.airlift.http.client.testing.TestingHttpClient; +import com.facebook.airlift.units.Duration; +import com.nimbusds.oauth2.sdk.ParseException; +import com.nimbusds.oauth2.sdk.http.HTTPResponse; +import com.nimbusds.openid.connect.sdk.UserInfoErrorResponse; +import com.nimbusds.openid.connect.sdk.UserInfoResponse; +import com.nimbusds.openid.connect.sdk.UserInfoSuccessResponse; +import net.minidev.json.JSONArray; +import net.minidev.json.JSONObject; +import org.testng.annotations.Test; + +import java.net.URI; +import java.util.Optional; + +import static java.util.concurrent.TimeUnit.MINUTES; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestNimbusOAuth2ClientUserInfoParser +{ + private static final String CLIENT_ID = "test-client-id"; + private static final String ADDITIONAL_AUDIENCE = "additional-audience"; + + private NimbusOAuth2Client createClient(String principalField, String additionalAudiences) + { + OAuth2Config config = new OAuth2Config() + .setIssuer("https://issuer.example.com") + .setClientId(CLIENT_ID) + .setClientSecret("test-secret") + .setPrincipalField(principalField) + .setAdditionalAudiences(additionalAudiences) + .setMaxClockSkew(new Duration(1, MINUTES)); + + OAuth2ServerConfigProvider configProvider = new OAuth2ServerConfigProvider() + { + @Override + public OAuth2ServerConfig get() + { + return new OAuth2ServerConfig( + Optional.empty(), + URI.create("https://issuer.example.com/auth"), + URI.create("https://issuer.example.com/token"), + URI.create("https://issuer.example.com/jwks"), + Optional.empty()); + } + }; + + HttpClient httpClient = new TestingHttpClient(request -> { + throw new UnsupportedOperationException("HTTP client should not be called in these tests"); + }); + NimbusHttpClient nimbusHttpClient = new NimbusAirliftHttpClient(httpClient); + + return new NimbusOAuth2Client(config, configProvider, nimbusHttpClient); + } + + @Test + public void testParseSuccessWithStandardSubClaim() + throws Exception + { + NimbusOAuth2Client client = createClient("sub", ADDITIONAL_AUDIENCE); + HTTPResponse httpResponse = createSuccessResponse(); + JSONObject body = new JSONObject(); + body.put("sub", "user@example.com"); + body.put("email", "user@example.com"); + httpResponse.setBody(body.toJSONString()); + + UserInfoResponse response = client.parse(httpResponse); + + assertThat(response.indicatesSuccess()).isTrue(); + UserInfoSuccessResponse successResponse = response.toSuccessResponse(); + assertThat(successResponse.getUserInfo().getSubject().getValue()).isEqualTo("user@example.com"); + } + + @Test + public void testParseSuccessWithCustomPrincipalField() + throws Exception + { + NimbusOAuth2Client client = createClient("email", ADDITIONAL_AUDIENCE); + HTTPResponse httpResponse = createSuccessResponse(); + httpResponse.setContentType("application/json"); + JSONObject body = new JSONObject(); + body.put("email", "custom@example.com"); + body.put("name", "Custom User"); + httpResponse.setBody(body.toJSONString()); + + UserInfoResponse response = client.parse(httpResponse); + + assertThat(response.indicatesSuccess()).isTrue(); + UserInfoSuccessResponse successResponse = response.toSuccessResponse(); + // The parser should add "sub" claim with the value from "email" + assertThat(successResponse.getUserInfo().getSubject().getValue()).isEqualTo("custom@example.com"); + } + + @Test + public void testParseSuccessWithCustomPrincipalFieldAndExistingSub() + throws Exception + { + NimbusOAuth2Client client = createClient("email", ADDITIONAL_AUDIENCE); + HTTPResponse httpResponse = createSuccessResponse(); + JSONObject body = new JSONObject(); + body.put("sub", "original-sub"); + body.put("email", "custom@example.com"); + httpResponse.setBody(body.toJSONString()); + + UserInfoResponse response = client.parse(httpResponse); + + assertThat(response.indicatesSuccess()).isTrue(); + UserInfoSuccessResponse successResponse = response.toSuccessResponse(); + // Should keep the original "sub" claim + assertThat(successResponse.getUserInfo().getSubject().getValue()).isEqualTo("original-sub"); + } + + @Test + public void testParseMissingPrincipalField() + throws Exception + { + NimbusOAuth2Client client = createClient("sub", ADDITIONAL_AUDIENCE); + HTTPResponse httpResponse = createSuccessResponse(); + JSONObject body = new JSONObject(); + body.put("email", "user@example.com"); + // Missing "sub" field + httpResponse.setBody(body.toJSONString()); + + assertThatThrownBy(() -> client.parse(httpResponse)) + .isInstanceOf(ParseException.class) + .hasMessageContaining("/userinfo response missing principal field sub"); + } + + @Test + public void testParseMissingCustomPrincipalField() + throws Exception + { + NimbusOAuth2Client client = createClient("custom_field", ADDITIONAL_AUDIENCE); + HTTPResponse httpResponse = createSuccessResponse(); + JSONObject body = new JSONObject(); + body.put("sub", "user@example.com"); + // Missing "custom_field" + httpResponse.setBody(body.toJSONString()); + + assertThatThrownBy(() -> client.parse(httpResponse)) + .isInstanceOf(ParseException.class) + .hasMessageContaining("/userinfo response missing principal field custom_field"); + } + + @Test + public void testParseWithValidAudienceString() + throws Exception + { + NimbusOAuth2Client client = createClient("sub", ADDITIONAL_AUDIENCE); + HTTPResponse httpResponse = createSuccessResponse(); + JSONObject body = new JSONObject(); + body.put("sub", "user@example.com"); + body.put("aud", CLIENT_ID); + httpResponse.setBody(body.toJSONString()); + + UserInfoResponse response = client.parse(httpResponse); + + assertThat(response.indicatesSuccess()).isTrue(); + } + + @Test + public void testParseWithValidAudienceArray() + throws Exception + { + NimbusOAuth2Client client = createClient("sub", ADDITIONAL_AUDIENCE); + HTTPResponse httpResponse = createSuccessResponse(); + JSONObject body = new JSONObject(); + body.put("sub", "user@example.com"); + JSONArray audiences = new JSONArray(); + audiences.add(CLIENT_ID); + audiences.add("other-audience"); + body.put("aud", audiences); + httpResponse.setBody(body.toJSONString()); + + UserInfoResponse response = client.parse(httpResponse); + + assertThat(response.indicatesSuccess()).isTrue(); + } + + @Test + public void testParseWithValidAdditionalAudience() + throws Exception + { + NimbusOAuth2Client client = createClient("sub", ADDITIONAL_AUDIENCE); + HTTPResponse httpResponse = createSuccessResponse(); + JSONObject body = new JSONObject(); + body.put("sub", "user@example.com"); + body.put("aud", ADDITIONAL_AUDIENCE); + httpResponse.setBody(body.toJSONString()); + + UserInfoResponse response = client.parse(httpResponse); + + assertThat(response.indicatesSuccess()).isTrue(); + } + + @Test + public void testParseWithInvalidAudienceString() + throws Exception + { + NimbusOAuth2Client client = createClient("sub", ADDITIONAL_AUDIENCE); + HTTPResponse httpResponse = createSuccessResponse(); + JSONObject body = new JSONObject(); + body.put("sub", "user@example.com"); + body.put("aud", "invalid-audience"); + httpResponse.setBody(body.toJSONString()); + + assertThatThrownBy(() -> client.parse(httpResponse)) + .isInstanceOf(ParseException.class) + .hasMessageContaining("Invalid audience in /userinfo response"); + } + + @Test + public void testParseWithInvalidAudienceArray() + throws Exception + { + NimbusOAuth2Client client = createClient("sub", ADDITIONAL_AUDIENCE); + HTTPResponse httpResponse = createSuccessResponse(); + JSONObject body = new JSONObject(); + body.put("sub", "user@example.com"); + JSONArray audiences = new JSONArray(); + audiences.add("invalid-audience-1"); + audiences.add("invalid-audience-2"); + body.put("aud", audiences); + httpResponse.setBody(body.toJSONString()); + + assertThatThrownBy(() -> client.parse(httpResponse)) + .isInstanceOf(ParseException.class) + .hasMessageContaining("Invalid audience in /userinfo response"); + } + + @Test + public void testParseWithMixedAudienceArrayContainingValid() + throws Exception + { + NimbusOAuth2Client client = createClient("sub", ADDITIONAL_AUDIENCE); + HTTPResponse httpResponse = createSuccessResponse(); + JSONObject body = new JSONObject(); + body.put("sub", "user@example.com"); + JSONArray audiences = new JSONArray(); + audiences.add("invalid-audience"); + audiences.add(CLIENT_ID); + body.put("aud", audiences); + httpResponse.setBody(body.toJSONString()); + + UserInfoResponse response = client.parse(httpResponse); + + assertThat(response.indicatesSuccess()).isTrue(); + } + + @Test + public void testParseWithUnsupportedAudienceType() + throws Exception + { + NimbusOAuth2Client client = createClient("sub", ADDITIONAL_AUDIENCE); + HTTPResponse httpResponse = createSuccessResponse(); + JSONObject body = new JSONObject(); + body.put("sub", "user@example.com"); + body.put("aud", 12345); // Integer instead of String or Array + httpResponse.setBody(body.toJSONString()); + + assertThatThrownBy(() -> client.parse(httpResponse)) + .isInstanceOf(ParseException.class) + .hasMessageContaining("Unsupported 'aud' claim type in /userinfo response"); + } + + @Test + public void testParseWithNoAudienceClaim() + throws Exception + { + NimbusOAuth2Client client = createClient("sub", ADDITIONAL_AUDIENCE); + HTTPResponse httpResponse = createSuccessResponse(); + JSONObject body = new JSONObject(); + body.put("sub", "user@example.com"); + // No "aud" claim - should be allowed + httpResponse.setBody(body.toJSONString()); + + UserInfoResponse response = client.parse(httpResponse); + + assertThat(response.indicatesSuccess()).isTrue(); + } + + @Test + public void testParseWithAudienceArrayContainingNonStrings() + throws Exception + { + NimbusOAuth2Client client = createClient("sub", ADDITIONAL_AUDIENCE); + HTTPResponse httpResponse = createSuccessResponse(); + JSONObject body = new JSONObject(); + body.put("sub", "user@example.com"); + JSONArray audiences = new JSONArray(); + audiences.add(CLIENT_ID); + audiences.add(123); // Non-string element should be filtered out + audiences.add(null); // Null element should be filtered out + body.put("aud", audiences); + httpResponse.setBody(body.toJSONString()); + + UserInfoResponse response = client.parse(httpResponse); + + assertThat(response.indicatesSuccess()).isTrue(); + } + + @Test + public void testParseErrorResponse() + throws Exception + { + NimbusOAuth2Client client = createClient("sub", ADDITIONAL_AUDIENCE); + HTTPResponse httpResponse = new HTTPResponse(401); + httpResponse.setContentType("application/json"); + JSONObject body = new JSONObject(); + body.put("error", "invalid_token"); + body.put("error_description", "The access token is invalid"); + httpResponse.setBody(body.toJSONString()); + + UserInfoResponse response = client.parse(httpResponse); + assertThat(response.indicatesSuccess()).isFalse(); + UserInfoErrorResponse errorResponse = response.toErrorResponse(); + assertThat(errorResponse.getErrorObject().getCode()).isEqualTo("invalid_token"); + assertThat(errorResponse.getErrorObject().getDescription()).isEqualTo("The access token is invalid"); + } + + @Test + public void testParseWithEmptyAudienceArray() + throws Exception + { + NimbusOAuth2Client client = createClient("sub", ADDITIONAL_AUDIENCE); + HTTPResponse httpResponse = createSuccessResponse(); + JSONObject body = new JSONObject(); + body.put("sub", "user@example.com"); + body.put("aud", new JSONArray()); // Empty array + httpResponse.setBody(body.toJSONString()); + + assertThatThrownBy(() -> client.parse(httpResponse)) + .isInstanceOf(ParseException.class) + .hasMessageContaining("Invalid audience in /userinfo response"); + } + + @Test + public void testParseWithMultipleAdditionalAudiences() + throws Exception + { + NimbusOAuth2Client client = createClient("sub", "aud1,aud2,aud3"); + HTTPResponse httpResponse = createSuccessResponse(); + JSONObject body = new JSONObject(); + body.put("sub", "user@example.com"); + body.put("aud", "aud2"); // One of the additional audiences + httpResponse.setBody(body.toJSONString()); + + UserInfoResponse response = client.parse(httpResponse); + + assertThat(response.indicatesSuccess()).isTrue(); + } + + private HTTPResponse createSuccessResponse() throws ParseException + { + HTTPResponse response = new HTTPResponse(200); + response.setContentType("application/json"); + return response; + } +} + +// Made with Bob diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOAuth2Config.java b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOAuth2Config.java index bb2735b95e7f6..94f70bb76f4b0 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOAuth2Config.java +++ b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOAuth2Config.java @@ -47,7 +47,9 @@ public void testDefaults() .setUserMappingPattern(null) .setUserMappingFile(null) .setEnableRefreshTokens(false) - .setEnableDiscovery(true)); + .setEnableDiscovery(true) + .setUserinfoCacheEnabled(false) + .setUserinfoCacheTtl(new Duration(10, MINUTES))); } @Test @@ -70,6 +72,8 @@ public void testExplicitPropertyMappings() .put("http-server.authentication.oauth2.user-mapping.file", userMappingFile.toString()) .put("http-server.authentication.oauth2.refresh-tokens", "true") .put("http-server.authentication.oauth2.oidc.discovery", "false") + .put("http-server.authentication.oauth2.userinfo-cache", "true") + .put("http-server.authentication.oauth2.userinfo-cache-ttl", "5m") .build(); OAuth2Config expected = new OAuth2Config() @@ -86,7 +90,9 @@ public void testExplicitPropertyMappings() .setUserMappingPattern("(.*)@something") .setUserMappingFile(userMappingFile.toFile()) .setEnableRefreshTokens(true) - .setEnableDiscovery(false); + .setEnableDiscovery(false) + .setUserinfoCacheEnabled(true) + .setUserinfoCacheTtl(new Duration(5, MINUTES)); assertFullMapping(properties, expected); }