diff --git a/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java b/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java index f284822161ca..cdbdfb3d869e 100644 --- a/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java +++ b/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java @@ -118,6 +118,11 @@ public String oauth2ServerUri() { return properties().getOrDefault(OAuth2Properties.OAUTH2_SERVER_URI, ResourcePaths.tokens()); } + @Value.Lazy + public Map optionalOAuthParams() { + return OAuth2Util.buildOptionalParam(properties()); + } + /** A Bearer token supplier which will be used for interaction with the server. */ @Value.Default public Supplier token() { @@ -207,7 +212,13 @@ private AuthSession authSession() { token, expiresAtMillis(properties()), new AuthSession( - ImmutableMap.of(), token, null, credential(), SCOPE, oauth2ServerUri()))); + ImmutableMap.of(), + token, + null, + credential(), + SCOPE, + oauth2ServerUri(), + optionalOAuthParams()))); } if (credentialProvided()) { @@ -217,11 +228,22 @@ private AuthSession authSession() { id -> { AuthSession session = new AuthSession( - ImmutableMap.of(), null, null, credential(), SCOPE, oauth2ServerUri()); + ImmutableMap.of(), + null, + null, + credential(), + SCOPE, + oauth2ServerUri(), + optionalOAuthParams()); long startTimeMillis = System.currentTimeMillis(); OAuthTokenResponse authResponse = OAuth2Util.fetchToken( - httpClient(), session.headers(), credential(), SCOPE, oauth2ServerUri()); + httpClient(), + session.headers(), + credential(), + SCOPE, + oauth2ServerUri(), + optionalOAuthParams()); return AuthSession.fromTokenResponse( httpClient(), tokenRefreshExecutor(), authResponse, startTimeMillis, session); }); diff --git a/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java b/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java index a34f738c318e..96aa14b128da 100644 --- a/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java +++ b/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java @@ -179,6 +179,7 @@ public void initialize(String name, Map unresolved) { OAuthTokenResponse authResponse; String credential = props.get(OAuth2Properties.CREDENTIAL); String scope = props.getOrDefault(OAuth2Properties.SCOPE, OAuth2Properties.CATALOG_SCOPE); + Map optionalOAuthParams = OAuth2Util.buildOptionalParam(props); String oauth2ServerUri = props.getOrDefault(OAuth2Properties.OAUTH2_SERVER_URI, ResourcePaths.tokens()); try (RESTClient initClient = clientBuilder.apply(props)) { @@ -186,7 +187,8 @@ public void initialize(String name, Map unresolved) { RESTUtil.merge(configHeaders(props), OAuth2Util.authHeaders(initToken)); if (credential != null && !credential.isEmpty()) { authResponse = - OAuth2Util.fetchToken(initClient, initHeaders, credential, scope, oauth2ServerUri); + OAuth2Util.fetchToken( + initClient, initHeaders, credential, scope, oauth2ServerUri, optionalOAuthParams); Map authHeaders = RESTUtil.merge(initHeaders, OAuth2Util.authHeaders(authResponse.token())); config = fetchConfig(initClient, authHeaders, props); @@ -213,7 +215,9 @@ public void initialize(String name, Map unresolved) { this.paths = ResourcePaths.forCatalogProperties(mergedProps); String token = mergedProps.get(OAuth2Properties.TOKEN); - this.catalogAuth = new AuthSession(baseHeaders, null, null, credential, scope, oauth2ServerUri); + this.catalogAuth = + new AuthSession( + baseHeaders, null, null, credential, scope, oauth2ServerUri, optionalOAuthParams); if (authResponse != null) { this.catalogAuth = AuthSession.fromTokenResponse( diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Properties.java b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Properties.java index e1a9181d164d..295e24519129 100644 --- a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Properties.java +++ b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Properties.java @@ -49,6 +49,12 @@ private OAuth2Properties() {} /** Additional scope for OAuth2. */ public static final String SCOPE = "scope"; + /** Optional param audience for OAuth2. */ + public static final String AUDIENCE = "audience"; + + /** Optional param resource for OAuth2. */ + public static final String RESOURCE = "resource"; + /** Scope for OAuth2 flows. */ public static final String CATALOG_SCOPE = "catalog"; diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java index ad1821a3f2b6..9e36694508d9 100644 --- a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java +++ b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java @@ -40,6 +40,7 @@ import org.apache.iceberg.relocated.com.google.common.base.Splitter; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.rest.ErrorHandlers; import org.apache.iceberg.rest.RESTClient; @@ -129,18 +130,40 @@ public static String toScope(Iterable scopes) { return SCOPE_JOINER.join(scopes); } + public static Map buildOptionalParam(Map properties) { + // these are some options oauth params based on specification + // for any new optional oauth param, define the constant and add the constant to this list + Set optionalParamKeys = + ImmutableSet.of(OAuth2Properties.AUDIENCE, OAuth2Properties.RESOURCE); + ImmutableMap.Builder optionalParamBuilder = ImmutableMap.builder(); + // add scope too, + optionalParamBuilder.put( + OAuth2Properties.SCOPE, + properties.getOrDefault(OAuth2Properties.SCOPE, OAuth2Properties.CATALOG_SCOPE)); + // add all other parameters + for (String key : optionalParamKeys) { + String value = properties.get(key); + if (value != null) { + optionalParamBuilder.put(key, value); + } + } + return optionalParamBuilder.buildKeepingLast(); + } + private static OAuthTokenResponse refreshToken( RESTClient client, Map headers, String subjectToken, String subjectTokenType, String scope, - String oauth2ServerUri) { + String oauth2ServerUri, + Map optionalOAuthParams) { Map request = tokenExchangeRequest( subjectToken, subjectTokenType, - scope != null ? ImmutableList.of(scope) : ImmutableList.of()); + scope != null ? ImmutableList.of(scope) : ImmutableList.of(), + optionalOAuthParams); OAuthTokenResponse response = client.postForm( @@ -162,14 +185,16 @@ public static OAuthTokenResponse exchangeToken( String actorToken, String actorTokenType, String scope, - String oauth2ServerUri) { + String oauth2ServerUri, + Map optionalParams) { Map request = tokenExchangeRequest( subjectToken, subjectTokenType, actorToken, actorTokenType, - scope != null ? ImmutableList.of(scope) : ImmutableList.of()); + scope != null ? ImmutableList.of(scope) : ImmutableList.of(), + optionalParams); OAuthTokenResponse response = client.postForm( @@ -199,7 +224,29 @@ public static OAuthTokenResponse exchangeToken( actorToken, actorTokenType, scope, - ResourcePaths.tokens()); + ResourcePaths.tokens(), + ImmutableMap.of()); + } + + public static OAuthTokenResponse exchangeToken( + RESTClient client, + Map headers, + String subjectToken, + String subjectTokenType, + String actorToken, + String actorTokenType, + String scope, + String oauth2ServerUri) { + return exchangeToken( + client, + headers, + subjectToken, + subjectTokenType, + actorToken, + actorTokenType, + scope, + oauth2ServerUri, + ImmutableMap.of()); } public static OAuthTokenResponse fetchToken( @@ -207,10 +254,13 @@ public static OAuthTokenResponse fetchToken( Map headers, String credential, String scope, - String oauth2ServerUri) { + String oauth2ServerUri, + Map optionalParams) { Map request = clientCredentialsRequest( - credential, scope != null ? ImmutableList.of(scope) : ImmutableList.of()); + credential, + scope != null ? ImmutableList.of(scope) : ImmutableList.of(), + optionalParams); OAuthTokenResponse response = client.postForm( @@ -227,12 +277,27 @@ public static OAuthTokenResponse fetchToken( public static OAuthTokenResponse fetchToken( RESTClient client, Map headers, String credential, String scope) { - return fetchToken(client, headers, credential, scope, ResourcePaths.tokens()); + return fetchToken( + client, headers, credential, scope, ResourcePaths.tokens(), ImmutableMap.of()); + } + + public static OAuthTokenResponse fetchToken( + RESTClient client, + Map headers, + String credential, + String scope, + String oauth2ServerUri) { + + return fetchToken(client, headers, credential, scope, oauth2ServerUri, ImmutableMap.of()); } private static Map tokenExchangeRequest( - String subjectToken, String subjectTokenType, List scopes) { - return tokenExchangeRequest(subjectToken, subjectTokenType, null, null, scopes); + String subjectToken, + String subjectTokenType, + List scopes, + Map optionalOAuthParams) { + return tokenExchangeRequest( + subjectToken, subjectTokenType, null, null, scopes, optionalOAuthParams); } private static Map tokenExchangeRequest( @@ -240,7 +305,8 @@ private static Map tokenExchangeRequest( String subjectTokenType, String actorToken, String actorTokenType, - List scopes) { + List scopes, + Map optionalParams) { Preconditions.checkArgument( VALID_TOKEN_TYPES.contains(subjectTokenType), "Invalid token type: %s", subjectTokenType); Preconditions.checkArgument( @@ -257,8 +323,9 @@ private static Map tokenExchangeRequest( formData.put(ACTOR_TOKEN, actorToken); formData.put(ACTOR_TOKEN_TYPE, actorTokenType); } + formData.putAll(optionalParams); - return formData.build(); + return formData.buildKeepingLast(); } private static Pair parseCredential(String credential) { @@ -278,13 +345,17 @@ private static Pair parseCredential(String credential) { } private static Map clientCredentialsRequest( - String credential, List scopes) { + String credential, List scopes, Map optionalOAuthParams) { Pair credentialPair = parseCredential(credential); - return clientCredentialsRequest(credentialPair.first(), credentialPair.second(), scopes); + return clientCredentialsRequest( + credentialPair.first(), credentialPair.second(), scopes, optionalOAuthParams); } private static Map clientCredentialsRequest( - String clientId, String clientSecret, List scopes) { + String clientId, + String clientSecret, + List scopes, + Map optionalOAuthParams) { ImmutableMap.Builder formData = ImmutableMap.builder(); formData.put(GRANT_TYPE, CLIENT_CREDENTIALS); if (clientId != null) { @@ -292,8 +363,9 @@ private static Map clientCredentialsRequest( } formData.put(CLIENT_SECRET, clientSecret); formData.put(SCOPE, toScope(scopes)); + formData.putAll(optionalOAuthParams); - return formData.build(); + return formData.buildKeepingLast(); } public static String tokenResponseToJson(OAuthTokenResponse response) { @@ -394,13 +466,16 @@ public static class AuthSession { private volatile boolean keepRefreshed = true; private final String oauth2ServerUri; + private Map optionalOAuthParams = ImmutableMap.of(); + public AuthSession( Map baseHeaders, String token, String tokenType, String credential, String scope, - String oauth2ServerUri) { + String oauth2ServerUri, + Map optionalOAuthParams) { this.headers = RESTUtil.merge(baseHeaders, authHeaders(token)); this.token = token; this.tokenType = tokenType; @@ -408,6 +483,7 @@ public AuthSession( this.credential = credential; this.scope = scope; this.oauth2ServerUri = oauth2ServerUri; + this.optionalOAuthParams = optionalOAuthParams; } /** @deprecated since 1.5.0, will be removed in 1.6.0 */ @@ -427,6 +503,25 @@ public AuthSession( this.oauth2ServerUri = ResourcePaths.tokens(); } + /** @deprecated since 1.6.0, will be removed in 1.7.0 */ + @Deprecated + public AuthSession( + Map baseHeaders, + String token, + String tokenType, + String credential, + String scope, + String oauth2ServerUri) { + this.headers = RESTUtil.merge(baseHeaders, authHeaders(token)); + this.token = token; + this.tokenType = tokenType; + this.expiresAtMillis = OAuth2Util.expiresAtMillis(token); + this.credential = credential; + this.scope = scope; + this.oauth2ServerUri = oauth2ServerUri; + this.optionalOAuthParams = ImmutableMap.of(); + } + public Map headers() { return headers; } @@ -459,6 +554,10 @@ public String oauth2ServerUri() { return oauth2ServerUri; } + public Map optionalOAuthParams() { + return optionalOAuthParams; + } + @VisibleForTesting static void setTokenRefreshNumRetries(int retries) { tokenRefreshNumRetries = retries; @@ -471,7 +570,13 @@ static void setTokenRefreshNumRetries(int retries) { */ public static AuthSession empty() { return new AuthSession( - ImmutableMap.of(), null, null, null, OAuth2Properties.CATALOG_SCOPE, null); + ImmutableMap.of(), + null, + null, + null, + OAuth2Properties.CATALOG_SCOPE, + null, + ImmutableMap.of()); } /** @@ -526,14 +631,16 @@ private OAuthTokenResponse refreshCurrentToken(RESTClient client) { return refreshExpiredToken(client); } else { // attempt a normal refresh - return refreshToken(client, headers(), token, tokenType, scope, oauth2ServerUri); + return refreshToken( + client, headers(), token, tokenType, scope, oauth2ServerUri, optionalOAuthParams); } } private OAuthTokenResponse refreshExpiredToken(RESTClient client) { if (credential != null) { Map basicHeaders = RESTUtil.merge(headers(), basicAuthHeaders(credential)); - return refreshToken(client, basicHeaders, token, tokenType, scope, oauth2ServerUri); + return refreshToken( + client, basicHeaders, token, tokenType, scope, oauth2ServerUri, optionalOAuthParams); } return null; @@ -590,7 +697,8 @@ public static AuthSession fromAccessToken( OAuth2Properties.ACCESS_TOKEN_TYPE, parent.credential(), parent.scope(), - parent.oauth2ServerUri()); + parent.oauth2ServerUri(), + parent.optionalOAuthParams()); long startTimeMillis = System.currentTimeMillis(); Long expiresAtMillis = session.expiresAtMillis(); @@ -629,7 +737,12 @@ public static AuthSession fromCredential( long startTimeMillis = System.currentTimeMillis(); OAuthTokenResponse response = fetchToken( - client, parent.headers(), credential, parent.scope(), parent.oauth2ServerUri()); + client, + parent.headers(), + credential, + parent.scope(), + parent.oauth2ServerUri(), + parent.optionalOAuthParams()); return fromTokenResponse(client, executor, response, startTimeMillis, parent, credential); } @@ -657,7 +770,8 @@ private static AuthSession fromTokenResponse( response.issuedTokenType(), credential, parent.scope(), - parent.oauth2ServerUri()); + parent.oauth2ServerUri(), + parent.optionalOAuthParams()); Long expiresAtMillis = session.expiresAtMillis(); if (null == expiresAtMillis && response.expiresInSeconds() != null) { @@ -687,7 +801,8 @@ public static AuthSession fromTokenExchange( parent.token(), parent.tokenType(), parent.scope(), - parent.oauth2ServerUri()); + parent.oauth2ServerUri(), + parent.optionalOAuthParams()); return fromTokenResponse(client, executor, response, startTimeMillis, parent); } } diff --git a/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java b/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java index db67cfd4bb3f..18d832b3cd46 100644 --- a/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java +++ b/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java @@ -678,7 +678,8 @@ public void testClientBearerToken(String oauth2ServerUri) { "urn:ietf:params:oauth:token-type:saml2", "saml2-token", "urn:ietf:params:oauth:token-type:saml1", "saml1-token"), ImmutableMap.of("Authorization", "Bearer client-bearer-token"), - oauth2ServerUri); + oauth2ServerUri, + ImmutableMap.of()); } @ParameterizedTest @@ -694,7 +695,8 @@ public void testClientCredential(String oauth2ServerUri) { "urn:ietf:params:oauth:token-type:saml2", "saml2-token", "urn:ietf:params:oauth:token-type:saml1", "saml1-token"), ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=user"), - oauth2ServerUri); + oauth2ServerUri, + ImmutableMap.of()); } @ParameterizedTest @@ -710,7 +712,8 @@ public void testClientIDToken(String oauth2ServerUri) { "urn:ietf:params:oauth:token-type:saml1", "saml1-token"), ImmutableMap.of( "Authorization", "Bearer token-exchange-token:sub=id-token,act=bearer-token"), - oauth2ServerUri); + oauth2ServerUri, + ImmutableMap.of()); } @ParameterizedTest @@ -725,7 +728,25 @@ public void testClientAccessToken(String oauth2ServerUri) { "urn:ietf:params:oauth:token-type:saml1", "saml1-token"), ImmutableMap.of( "Authorization", "Bearer token-exchange-token:sub=access-token,act=bearer-token"), - oauth2ServerUri); + oauth2ServerUri, + ImmutableMap.of()); + } + + @ParameterizedTest + @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) + public void testClientAccessTokenWithOptionalParams(String oauth2ServerUri) { + testClientAuth( + "bearer-token", + ImmutableMap.of( + "urn:ietf:params:oauth:token-type:access_token", "access-token", + "urn:ietf:params:oauth:token-type:jwt", "jwt-token", + "urn:ietf:params:oauth:token-type:saml2", "saml2-token", + "urn:ietf:params:oauth:token-type:saml1", "saml1-token"), + ImmutableMap.of( + "Authorization", "Bearer token-exchange-token:sub=access-token,act=bearer-token"), + oauth2ServerUri, + ImmutableMap.of( + "scope", "custom_scope", "audience", "test_audience", "resource", "test_resource")); } @ParameterizedTest @@ -739,7 +760,8 @@ public void testClientJWTToken(String oauth2ServerUri) { "urn:ietf:params:oauth:token-type:saml1", "saml1-token"), ImmutableMap.of( "Authorization", "Bearer token-exchange-token:sub=jwt-token,act=bearer-token"), - oauth2ServerUri); + oauth2ServerUri, + ImmutableMap.of()); } @ParameterizedTest @@ -752,7 +774,8 @@ public void testClientSAML2Token(String oauth2ServerUri) { "urn:ietf:params:oauth:token-type:saml1", "saml1-token"), ImmutableMap.of( "Authorization", "Bearer token-exchange-token:sub=saml2-token,act=bearer-token"), - oauth2ServerUri); + oauth2ServerUri, + ImmutableMap.of()); } @ParameterizedTest @@ -763,14 +786,16 @@ public void testClientSAML1Token(String oauth2ServerUri) { ImmutableMap.of("urn:ietf:params:oauth:token-type:saml1", "saml1-token"), ImmutableMap.of( "Authorization", "Bearer token-exchange-token:sub=saml1-token,act=bearer-token"), - oauth2ServerUri); + oauth2ServerUri, + ImmutableMap.of()); } private void testClientAuth( String catalogToken, Map credentials, Map expectedHeaders, - String oauth2ServerUri) { + String oauth2ServerUri, + Map optionalOAuthParams) { Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer " + catalogToken); RESTCatalogAdapter adapter = Mockito.spy(new RESTCatalogAdapter(backendCatalog)); @@ -780,15 +805,16 @@ private void testClientAuth( UUID.randomUUID().toString(), "user", credentials, ImmutableMap.of()); RESTCatalog catalog = new RESTCatalog(context, (config) -> adapter); - catalog.initialize( - "prod", - ImmutableMap.of( - CatalogProperties.URI, - "ignored", - "token", - catalogToken, - OAuth2Properties.OAUTH2_SERVER_URI, - oauth2ServerUri)); + + ImmutableMap.Builder propertyBuilder = ImmutableMap.builder(); + Map initializationProperties = + propertyBuilder + .put(CatalogProperties.URI, "ignored") + .put("token", catalogToken) + .put(OAuth2Properties.OAUTH2_SERVER_URI, oauth2ServerUri) + .putAll(optionalOAuthParams) + .build(); + catalog.initialize("prod", initializationProperties); Assertions.assertThat(catalog.tableExists(TableIdentifier.of("ns", "table"))).isFalse(); @@ -815,7 +841,6 @@ private void testClientAuth( eq(catalogHeaders), any()); } - Mockito.verify(adapter) .execute( eq(HTTPMethod.GET), @@ -825,6 +850,21 @@ private void testClientAuth( eq(LoadTableResponse.class), eq(expectedHeaders), any()); + if (!optionalOAuthParams.isEmpty()) { + Mockito.verify(adapter) + .execute( + eq(HTTPMethod.POST), + eq(oauth2ServerUri), + any(), + Mockito.argThat( + body -> + ((Map) body) + .keySet() + .containsAll(optionalOAuthParams.keySet())), + eq(OAuthTokenResponse.class), + eq(catalogHeaders), + any()); + } } @ParameterizedTest