Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion core/src/main/java/org/apache/iceberg/rest/auth/AuthConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ default Long expiresAtMillis() {

@Value.Default
default boolean keepRefreshed() {
return true;
return OAuth2Properties.TOKEN_REFRESH_ENABLED_DEFAULT;
}

@Value.Default
default boolean exchangeEnabled() {
return OAuth2Properties.TOKEN_EXCHANGE_ENABLED_DEFAULT;
}

@Nullable
Expand Down Expand Up @@ -84,6 +89,11 @@ static AuthConfig fromProperties(Map<String, String> properties) {
properties,
OAuth2Properties.TOKEN_REFRESH_ENABLED,
OAuth2Properties.TOKEN_REFRESH_ENABLED_DEFAULT))
.exchangeEnabled(
PropertyUtil.propertyAsBoolean(
properties,
OAuth2Properties.TOKEN_EXCHANGE_ENABLED,
OAuth2Properties.TOKEN_EXCHANGE_ENABLED_DEFAULT))
.expiresAtMillis(expiresAtMillis(properties))
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ private OAuth2Properties() {}

public static final boolean TOKEN_REFRESH_ENABLED_DEFAULT = true;

/**
* Some IDPs do not support token exchange which is the first approach used for acquiring a new
* token. Disabling this will allow fallback to the client credential flow without initiating a
* token exchange flow.
*/
public static final String TOKEN_EXCHANGE_ENABLED = "token-exchange-enabled";

public static final boolean TOKEN_EXCHANGE_ENABLED_DEFAULT = true;

/** Additional scope for OAuth2. */
public static final String SCOPE = "scope";

Expand Down
69 changes: 34 additions & 35 deletions core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -157,29 +157,37 @@ public static Map<String, String> buildOptionalParam(

private static OAuthTokenResponse refreshToken(
RESTClient client,
AuthConfig config,
Map<String, String> headers,
String subjectToken,
String subjectTokenType,
String scope,
String oauth2ServerUri,
Map<String, String> optionalOAuthParams) {
Map<String, String> request =
tokenExchangeRequest(
subjectToken,
subjectTokenType,
scope != null ? ImmutableList.of(scope) : ImmutableList.of(),
optionalOAuthParams);
if (config.exchangeEnabled()) {
Map<String, String> request =
tokenExchangeRequest(
subjectToken,
subjectTokenType,
config.scope() != null ? ImmutableList.of(config.scope()) : ImmutableList.of(),
optionalOAuthParams);

OAuthTokenResponse response =
client.postForm(
oauth2ServerUri,
request,
OAuthTokenResponse.class,
headers,
ErrorHandlers.oauthErrorHandler());
response.validate();
OAuthTokenResponse response =
client.postForm(
config.oauth2ServerUri(),
request,
OAuthTokenResponse.class,
headers,
ErrorHandlers.oauthErrorHandler());
response.validate();

return response;
} else {
if (null == config.credential()) {
return null;
}

return response;
return fetchToken(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to check whether we have a credential here? I think it's possible that we started just with a token and so we'd end up with null credential. We might also want to add a test for this in TestRESTCatalog, since we already have some other tests where e.g. we either start with a credential or with a token only

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add a check, but that's really a configuration issue since if you only have a token, then your only option is exchange flow. Disabling exchange and not having a client credential would also mean that you need to disable refresh (you won't be able to get a new token regardless).

client, Map.of(), config.credential(), config.scope(), config.oauth2ServerUri());
}
}

public static OAuthTokenResponse exchangeToken(
Expand Down Expand Up @@ -593,32 +601,23 @@ private OAuthTokenResponse refreshCurrentToken(RESTClient client) {
return refreshExpiredToken(client);
} else {
// attempt a normal refresh
return refreshToken(
client,
headers(),
token(),
tokenType(),
scope(),
oauth2ServerUri(),
optionalOAuthParams());
return refreshToken(client, config, headers(), token(), tokenType(), optionalOAuthParams());
}
}

private OAuthTokenResponse refreshExpiredToken(RESTClient client) {
if (credential() != null) {
if (credential() == null) {
return null;
}

if (config.exchangeEnabled()) {
Map<String, String> basicHeaders =
RESTUtil.merge(headers(), basicAuthHeaders(credential()));
return refreshToken(
client,
basicHeaders,
token(),
tokenType(),
scope(),
oauth2ServerUri(),
optionalOAuthParams());
client, config, basicHeaders, token(), tokenType(), optionalOAuthParams());
} else {
return fetchToken(client, Map.of(), credential(), scope(), oauth2ServerUri());
}

return null;
}

/**
Expand Down
143 changes: 143 additions & 0 deletions core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java
Original file line number Diff line number Diff line change
Expand Up @@ -1547,6 +1547,87 @@ public void testCatalogRefreshedTokenIsUsed(String oauth2ServerUri) {
});
}

@ParameterizedTest
@ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"})
public void testCatalogTokenRefreshExchangeDisabled(String oauth2ServerUri) {
Map<String, String> emptyHeaders = ImmutableMap.of();
Map<String, String> catalogHeaders =
ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog");

Map<String, String> refreshRequest =
ImmutableMap.of(
"grant_type", "client_credentials",
"client_id", "catalog",
"client_secret", "secret",
"scope", "catalog");

RESTCatalogAdapter adapter = Mockito.spy(new RESTCatalogAdapter(backendCatalog));

Answer<OAuthTokenResponse> addOneSecondExpiration =
invocation -> {
OAuthTokenResponse response = (OAuthTokenResponse) invocation.callRealMethod();
return OAuthTokenResponse.builder()
.withToken(response.token())
.withTokenType(response.tokenType())
.withIssuedTokenType(response.issuedTokenType())
.addScopes(response.scopes())
.setExpirationInSeconds(1)
.build();
};

Mockito.doAnswer(addOneSecondExpiration)
.when(adapter)
.postForm(eq(oauth2ServerUri), anyMap(), eq(OAuthTokenResponse.class), anyMap(), any());

Map<String, String> contextCredentials = ImmutableMap.of();
SessionCatalog.SessionContext context =
new SessionCatalog.SessionContext(
UUID.randomUUID().toString(), "user", contextCredentials, ImmutableMap.of());

RESTCatalog catalog = new RESTCatalog(context, (config) -> adapter);
catalog.initialize(
"prod",
ImmutableMap.of(
CatalogProperties.URI,
"ignored",
OAuth2Properties.CREDENTIAL,
"catalog:secret",
OAuth2Properties.TOKEN_EXCHANGE_ENABLED,
"false",
OAuth2Properties.OAUTH2_SERVER_URI,
oauth2ServerUri));

Awaitility.await()
.atMost(5, TimeUnit.SECONDS)
.untilAsserted(
() -> {
// call client credentials with no initial auth
Mockito.verify(adapter)
.execute(
reqMatcher(HTTPMethod.POST, oauth2ServerUri, emptyHeaders),
eq(OAuthTokenResponse.class),
any(),
any());

// use the client credential token for config
Mockito.verify(adapter)
.execute(
reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders),
eq(ConfigResponse.class),
any(),
any());

// verify the new token request is issued
Mockito.verify(adapter)
.execute(
reqMatcher(
HTTPMethod.POST, oauth2ServerUri, emptyHeaders, Map.of(), refreshRequest),
eq(OAuthTokenResponse.class),
any(),
any());
});
}

@Test
public void testCatalogExpiredBearerTokenRefreshWithoutCredential() {
// expires at epoch second = 1
Expand Down Expand Up @@ -1667,6 +1748,68 @@ public void testCatalogExpiredBearerTokenIsRefreshedWithCredential(String oauth2
any());
}

@ParameterizedTest
@ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"})
public void testCatalogExpiredTokenCredentialRefreshWithExchangeDisabled(String oauth2ServerUri) {
// expires at epoch second = 1
String token =
".eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJleHAiOjF9.";
Map<String, String> emptyHeaders = ImmutableMap.of();
Map<String, String> catalogHeaders =
ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog");
RESTCatalogAdapter adapter = Mockito.spy(new RESTCatalogAdapter(backendCatalog));

String credential = "catalog:12345";
Map<String, String> contextCredentials = ImmutableMap.of("token", token);
SessionCatalog.SessionContext context =
new SessionCatalog.SessionContext(
UUID.randomUUID().toString(), "user", contextCredentials, ImmutableMap.of());

RESTCatalog catalog = new RESTCatalog(context, (config) -> adapter);
catalog.initialize(
"prod",
ImmutableMap.of(
CatalogProperties.URI,
"ignored",
OAuth2Properties.CREDENTIAL,
credential,
OAuth2Properties.TOKEN_EXCHANGE_ENABLED,
"false",
OAuth2Properties.OAUTH2_SERVER_URI,
oauth2ServerUri));

// call client credentials with no initial auth
Map<String, String> clientCredentialsRequest =
ImmutableMap.of(
"grant_type", "client_credentials",
"client_id", "catalog",
"client_secret", "12345",
"scope", "catalog");

Mockito.verify(adapter)
.execute(
reqMatcher(
HTTPMethod.POST, oauth2ServerUri, emptyHeaders, Map.of(), clientCredentialsRequest),
eq(OAuthTokenResponse.class),
any(),
any());

Mockito.verify(adapter)
.execute(
reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders),
eq(ConfigResponse.class),
any(),
any());

Mockito.verify(adapter)
.execute(
reqMatcher(
HTTPMethod.POST, oauth2ServerUri, emptyHeaders, Map.of(), clientCredentialsRequest),
eq(OAuthTokenResponse.class),
any(),
any());
}

@Test
public void testCatalogValidBearerTokenIsNotRefreshed() {
// expires at epoch second = 19999999999
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,20 @@
*/
package org.apache.iceberg.rest.auth;

import static org.apache.hadoop.hdfs.web.oauth2.OAuth2Constants.BEARER;
import static org.apache.hadoop.hdfs.web.oauth2.OAuth2Constants.CLIENT_CREDENTIALS;
import static org.apache.hadoop.hdfs.web.oauth2.OAuth2Constants.GRANT_TYPE;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.ArgumentMatchers.argThat;

import java.io.IOException;
import java.util.Map;
import org.apache.iceberg.rest.RESTClient;
import org.apache.iceberg.rest.responses.OAuthTokenResponse;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

public class TestOAuth2Util {
@Test
Expand Down Expand Up @@ -90,4 +101,38 @@ public void testExpiresAt() {
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";
assertThat(OAuth2Util.expiresAtMillis(token)).isNull();
}

@Test
public void testCredentialFlowForSessionRefresh() throws IOException {
AuthConfig authConfig =
ImmutableAuthConfig.builder()
.keepRefreshed(true)
.exchangeEnabled(false)
.token("test_token")
.credential("testClientId:testClientSecret")
.oauth2ServerUri("/v1/token")
.expiresAtMillis(System.currentTimeMillis())
.build();

OAuthTokenResponse response =
OAuthTokenResponse.builder().withToken("refreshed_token").withTokenType(BEARER).build();

try (RESTClient client = Mockito.mock(RESTClient.class);
OAuth2Util.AuthSession session = new OAuth2Util.AuthSession(Map.of(), authConfig); ) {
Mockito.when(client.postForm(any(), anyMap(), any(), anyMap(), any())).thenReturn(response);

session.refresh(client);

Mockito.verify(client)
.postForm(
any(),
argThat(
formData ->
formData.containsKey(GRANT_TYPE)
&& formData.get(GRANT_TYPE).equals(CLIENT_CREDENTIALS)),
Mockito.eq(OAuthTokenResponse.class),
anyMap(),
any());
}
}
}