diff --git a/aws/src/main/java/org/apache/iceberg/aws/RESTSigV4Signer.java b/aws/src/main/java/org/apache/iceberg/aws/RESTSigV4Signer.java index c1c858416352..fc9de32cc6fd 100644 --- a/aws/src/main/java/org/apache/iceberg/aws/RESTSigV4Signer.java +++ b/aws/src/main/java/org/apache/iceberg/aws/RESTSigV4Signer.java @@ -51,7 +51,10 @@ *

See Signing AWS * API requests for details about the protocol. + * + * @deprecated since 1.9.0, will be removed in 1.10.0; use {@link RESTSigV4AuthManager} instead. */ +@Deprecated public class RESTSigV4Signer implements HttpRequestInterceptor { static final String EMPTY_BODY_SHA256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; diff --git a/aws/src/main/java/org/apache/iceberg/aws/s3/VendedCredentialsProvider.java b/aws/src/main/java/org/apache/iceberg/aws/s3/VendedCredentialsProvider.java index 769d187875db..da96f4cb8f48 100644 --- a/aws/src/main/java/org/apache/iceberg/aws/s3/VendedCredentialsProvider.java +++ b/aws/src/main/java/org/apache/iceberg/aws/s3/VendedCredentialsProvider.java @@ -28,11 +28,10 @@ import org.apache.iceberg.relocated.com.google.common.base.Strings; import org.apache.iceberg.rest.ErrorHandlers; import org.apache.iceberg.rest.HTTPClient; -import org.apache.iceberg.rest.HTTPHeaders; import org.apache.iceberg.rest.RESTClient; -import org.apache.iceberg.rest.auth.DefaultAuthSession; -import org.apache.iceberg.rest.auth.OAuth2Properties; -import org.apache.iceberg.rest.auth.OAuth2Util; +import org.apache.iceberg.rest.auth.AuthManager; +import org.apache.iceberg.rest.auth.AuthManagers; +import org.apache.iceberg.rest.auth.AuthSession; import org.apache.iceberg.rest.credentials.Credential; import org.apache.iceberg.rest.responses.LoadCredentialsResponse; import software.amazon.awssdk.auth.credentials.AwsCredentials; @@ -48,6 +47,8 @@ public class VendedCredentialsProvider implements AwsCredentialsProvider, SdkAut private volatile HTTPClient client; private final Map properties; private final CachedSupplier credentialCache; + private AuthManager authManager; + private AuthSession authSession; private VendedCredentialsProvider(Map properties) { Preconditions.checkArgument(null != properties, "Invalid properties: null"); @@ -66,8 +67,10 @@ public AwsCredentials resolveCredentials() { @Override public void close() { - IoUtils.closeQuietly(client, null); - credentialCache.close(); + IoUtils.closeQuietlyV2(authSession, null); + IoUtils.closeQuietlyV2(authManager, null); + IoUtils.closeQuietlyV2(client, null); + IoUtils.closeQuietlyV2(credentialCache, null); } public static VendedCredentialsProvider create(Map properties) { @@ -78,14 +81,10 @@ private RESTClient httpClient() { if (null == client) { synchronized (this) { if (null == client) { - DefaultAuthSession authSession = - DefaultAuthSession.of( - HTTPHeaders.of(OAuth2Util.authHeaders(properties.get(OAuth2Properties.TOKEN)))); - client = - HTTPClient.builder(properties) - .uri(properties.get(URI)) - .withAuthSession(authSession) - .build(); + authManager = AuthManagers.loadAuthManager("s3-credentials-refresh", properties); + HTTPClient httpClient = HTTPClient.builder(properties).uri(properties.get(URI)).build(); + authSession = authManager.catalogSession(httpClient, properties); + client = httpClient.withAuthSession(authSession); } } } diff --git a/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java b/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java index 8f5a00b49daf..647291d394cd 100644 --- a/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java +++ b/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java @@ -20,15 +20,12 @@ import com.github.benmanes.caffeine.cache.Cache; import com.github.benmanes.caffeine.cache.Caffeine; -import com.github.benmanes.caffeine.cache.RemovalListener; import java.io.IOException; import java.io.InputStream; import java.net.URI; -import java.time.Duration; import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.function.Supplier; @@ -42,13 +39,12 @@ import org.apache.iceberg.rest.HTTPClient; import org.apache.iceberg.rest.RESTClient; import org.apache.iceberg.rest.ResourcePaths; -import org.apache.iceberg.rest.auth.AuthConfig; +import org.apache.iceberg.rest.auth.AuthManager; +import org.apache.iceberg.rest.auth.AuthManagers; +import org.apache.iceberg.rest.auth.AuthSession; import org.apache.iceberg.rest.auth.OAuth2Properties; import org.apache.iceberg.rest.auth.OAuth2Util; -import org.apache.iceberg.rest.auth.OAuth2Util.AuthSession; -import org.apache.iceberg.rest.responses.OAuthTokenResponse; import org.apache.iceberg.util.PropertyUtil; -import org.apache.iceberg.util.ThreadPools; import org.immutables.value.Value; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -64,7 +60,7 @@ @Value.Immutable public abstract class S3V4RestSignerClient - extends AbstractAws4Signer { + extends AbstractAws4Signer implements AutoCloseable { private static final Logger LOG = LoggerFactory.getLogger(S3V4RestSignerClient.class); public static final String S3_SIGNER_URI = "s3.signer.uri"; @@ -81,13 +77,14 @@ public abstract class S3V4RestSignerClient private static final String SCOPE = "sign"; @SuppressWarnings("immutables:incompat") - private static volatile ScheduledExecutorService tokenRefreshExecutor; + private volatile AuthManager authManager; - @SuppressWarnings("immutables:incompat") - private static volatile RESTClient httpClient; + @SuppressWarnings({"immutables:incompat", "VisibilityModifier"}) + @VisibleForTesting + static volatile RESTClient httpClient; @SuppressWarnings("immutables:incompat") - private static volatile Cache authSessionCache; + private volatile AuthSession authSession; public abstract Map properties(); @@ -138,52 +135,6 @@ boolean keepTokenRefreshed() { OAuth2Properties.TOKEN_REFRESH_ENABLED_DEFAULT); } - @VisibleForTesting - ScheduledExecutorService tokenRefreshExecutor() { - if (!keepTokenRefreshed()) { - return null; - } - - if (null == tokenRefreshExecutor) { - synchronized (S3V4RestSignerClient.class) { - if (null == tokenRefreshExecutor) { - tokenRefreshExecutor = ThreadPools.newScheduledPool("s3-signer-token-refresh", 1); - } - } - } - - return tokenRefreshExecutor; - } - - private Cache authSessionCache() { - if (null == authSessionCache) { - synchronized (S3V4RestSignerClient.class) { - if (null == authSessionCache) { - long expirationIntervalMs = - PropertyUtil.propertyAsLong( - properties(), - CatalogProperties.AUTH_SESSION_TIMEOUT_MS, - CatalogProperties.AUTH_SESSION_TIMEOUT_MS_DEFAULT); - - authSessionCache = - Caffeine.newBuilder() - .expireAfterAccess(Duration.ofMillis(expirationIntervalMs)) - .removalListener( - (RemovalListener) - (id, auth, cause) -> { - if (null != auth) { - LOG.trace("Stopping refresh for AuthSession"); - auth.stopRefreshing(); - } - }) - .build(); - } - } - } - - return authSessionCache; - } - private RESTClient httpClient() { if (null == httpClient) { synchronized (S3V4RestSignerClient.class) { @@ -200,86 +151,40 @@ private RESTClient httpClient() { return httpClient; } - private AuthSession authSession() { - String token = token().get(); - if (null != token) { - return authSessionCache() - .get( - token, - id -> { - // this client will be reused for token refreshes; it must contain an empty auth - // session in order to avoid interfering with refreshed tokens - RESTClient refreshClient = - httpClient().withAuthSession(org.apache.iceberg.rest.auth.AuthSession.EMPTY); - return AuthSession.fromAccessToken( - refreshClient, - tokenRefreshExecutor(), - token, - expiresAtMillis(properties()), - new AuthSession( - ImmutableMap.of(), - AuthConfig.builder() - .token(token) - .credential(credential()) - .scope(SCOPE) - .oauth2ServerUri(oauth2ServerUri()) - .optionalOAuthParams(optionalOAuthParams()) - .build())); - }); - } - - if (credentialProvided()) { - return authSessionCache() - .get( - credential(), - id -> { - AuthSession session = - new AuthSession( - ImmutableMap.of(), - AuthConfig.builder() - .credential(credential()) - .scope(SCOPE) - .oauth2ServerUri(oauth2ServerUri()) - .optionalOAuthParams(optionalOAuthParams()) - .build()); - long startTimeMillis = System.currentTimeMillis(); - // this client will be reused for token refreshes; it must contain an empty auth - // session in order to avoid interfering with refreshed tokens - RESTClient refreshClient = - httpClient().withAuthSession(org.apache.iceberg.rest.auth.AuthSession.EMPTY); - OAuthTokenResponse authResponse = - OAuth2Util.fetchToken( - refreshClient, - session.headers(), - credential(), - SCOPE, - oauth2ServerUri(), - optionalOAuthParams()); - return AuthSession.fromTokenResponse( - refreshClient, tokenRefreshExecutor(), authResponse, startTimeMillis, session); - }); + @VisibleForTesting + AuthSession authSession() { + if (null == authSession) { + synchronized (S3V4RestSignerClient.class) { + if (null == authSession) { + authManager = AuthManagers.loadAuthManager("s3-signer", properties()); + ImmutableMap.Builder properties = + ImmutableMap.builder() + .putAll(properties()) + .putAll(optionalOAuthParams()) + .put(OAuth2Properties.OAUTH2_SERVER_URI, oauth2ServerUri()) + .put(OAuth2Properties.TOKEN_REFRESH_ENABLED, String.valueOf(keepTokenRefreshed())) + .put(OAuth2Properties.SCOPE, SCOPE); + String token = token().get(); + if (null != token) { + properties.put(OAuth2Properties.TOKEN, token); + } + + if (credentialProvided()) { + properties.put(OAuth2Properties.CREDENTIAL, credential()); + } + + authSession = authManager.tableSession(httpClient(), properties.buildKeepingLast()); + } + } } - return AuthSession.empty(); + return authSession; } private boolean credentialProvided() { return null != credential() && !credential().isEmpty(); } - private Long expiresAtMillis(Map properties) { - if (properties.containsKey(OAuth2Properties.TOKEN_EXPIRES_IN_MS)) { - long expiresInMillis = - PropertyUtil.propertyAsLong( - properties, - OAuth2Properties.TOKEN_EXPIRES_IN_MS, - OAuth2Properties.TOKEN_EXPIRES_IN_MS_DEFAULT); - return System.currentTimeMillis() + expiresInMillis; - } else { - return null; - } - } - @Value.Check protected void check() { Preconditions.checkArgument( @@ -377,6 +282,12 @@ public SdkHttpFullRequest sign( return mutableRequest.build(); } + @Override + public void close() throws Exception { + IoUtils.closeQuietlyV2(authSession, null); + IoUtils.closeQuietlyV2(authManager, null); + } + /** * Only add body for DeleteObjectsRequest. Refer to * https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObjects.html#API_DeleteObjects_RequestSyntax diff --git a/aws/src/test/java/org/apache/iceberg/aws/TestRESTSigV4Signer.java b/aws/src/test/java/org/apache/iceberg/aws/TestRESTSigV4Signer.java index cc8873e30ea7..462663711225 100644 --- a/aws/src/test/java/org/apache/iceberg/aws/TestRESTSigV4Signer.java +++ b/aws/src/test/java/org/apache/iceberg/aws/TestRESTSigV4Signer.java @@ -27,6 +27,10 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.rest.HTTPClient; +import org.apache.iceberg.rest.RESTClient; +import org.apache.iceberg.rest.auth.AuthManager; +import org.apache.iceberg.rest.auth.AuthManagers; +import org.apache.iceberg.rest.auth.AuthProperties; import org.apache.iceberg.rest.auth.AuthSession; import org.apache.iceberg.rest.auth.OAuth2Util; import org.apache.iceberg.rest.responses.ConfigResponse; @@ -39,12 +43,15 @@ import org.mockserver.model.Header; import org.mockserver.model.HttpRequest; import org.mockserver.model.HttpResponse; +import org.mockserver.model.Parameter; +import org.mockserver.model.ParameterBody; import org.mockserver.verify.VerificationTimes; import software.amazon.awssdk.auth.signer.internal.SignerConstant; public class TestRESTSigV4Signer { private static ClientAndServer mockServer; - private static HTTPClient client; + private static RESTClient client; + private static AuthManager authManager; @BeforeAll public static void beforeClass() { @@ -52,26 +59,35 @@ public static void beforeClass() { Map properties = ImmutableMap.of( - "rest.sigv4-enabled", - "true", + AuthProperties.AUTH_TYPE, + AuthProperties.AUTH_TYPE_SIGV4, // CI environment doesn't have credentials, but a value must be set for signing AwsProperties.REST_SIGNER_REGION, "us-west-2", AwsProperties.REST_ACCESS_KEY_ID, "id", AwsProperties.REST_SECRET_ACCESS_KEY, - "secret"); - client = + "secret", + // OAuth2 token to test relocation of conflicting auth header + "token", + "existing_token"); + + HTTPClient httpClient = HTTPClient.builder(properties) .uri("http://localhost:" + mockServer.getLocalPort()) - .withHeader(HttpHeaders.AUTHORIZATION, "Bearer existing_token") - .withAuthSession(AuthSession.EMPTY) - .build(); + .build() + .withAuthSession(AuthSession.EMPTY); + + authManager = AuthManagers.loadAuthManager("test", properties); + AuthSession authSession = authManager.catalogSession(httpClient, properties); + + client = httpClient.withAuthSession(authSession); } @AfterAll public static void afterClass() throws IOException { mockServer.stop(); + authManager.close(); client.close(); } @@ -90,11 +106,13 @@ public void signRequestWithoutBody() { .withHeader(Header.header(HttpHeaders.AUTHORIZATION, "AWS4-HMAC-SHA256.*")) // Require that conflicting auth header is relocated .withHeader( - Header.header(RESTSigV4Signer.RELOCATED_HEADER_PREFIX + HttpHeaders.AUTHORIZATION)) + Header.header( + RESTSigV4AuthSession.RELOCATED_HEADER_PREFIX + HttpHeaders.AUTHORIZATION, + "Bearer existing_token")) // Require the empty body checksum .withHeader( Header.header( - SignerConstant.X_AMZ_CONTENT_SHA256, RESTSigV4Signer.EMPTY_BODY_SHA256)); + SignerConstant.X_AMZ_CONTENT_SHA256, RESTSigV4AuthSession.EMPTY_BODY_SHA256)); mockServer .when(request) @@ -113,11 +131,18 @@ public void signRequestWithBody() { HttpRequest.request() .withMethod("POST") .withPath("/v1/oauth/token") + .withBody( + ParameterBody.params( + Parameter.param("client_id", "asdfasd"), + Parameter.param("client_secret", "asdfasdf"), + Parameter.param("scope", "catalog"))) // Require SigV4 Authorization .withHeader(Header.header(HttpHeaders.AUTHORIZATION, "AWS4-HMAC-SHA256.*")) // Require that conflicting auth header is relocated .withHeader( - Header.header(RESTSigV4Signer.RELOCATED_HEADER_PREFIX + HttpHeaders.AUTHORIZATION)) + Header.header( + RESTSigV4AuthSession.RELOCATED_HEADER_PREFIX + HttpHeaders.AUTHORIZATION, + "Bearer existing_token")) // Require a body checksum is set .withHeader(Header.header(SignerConstant.X_AMZ_CONTENT_SHA256)); diff --git a/aws/src/test/java/org/apache/iceberg/aws/s3/signer/TestS3RestSigner.java b/aws/src/test/java/org/apache/iceberg/aws/s3/signer/TestS3RestSigner.java index 313214c4e98f..34f5f2c710c8 100644 --- a/aws/src/test/java/org/apache/iceberg/aws/s3/signer/TestS3RestSigner.java +++ b/aws/src/test/java/org/apache/iceberg/aws/s3/signer/TestS3RestSigner.java @@ -19,6 +19,7 @@ package org.apache.iceberg.aws.s3.signer; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.type; import java.net.URI; import java.nio.file.Path; @@ -112,11 +113,6 @@ public static void beforeClass() throws Exception { @AfterAll public static void afterClass() throws Exception { - assertThat(validatingSigner.icebergSigner.tokenRefreshExecutor()) - .isInstanceOf(ScheduledThreadPoolExecutor.class); - - ScheduledThreadPoolExecutor executor = - ((ScheduledThreadPoolExecutor) validatingSigner.icebergSigner.tokenRefreshExecutor()); // token expiration is set to 10000s by the S3SignerServlet so there should be exactly one token // scheduled for refresh. Such a high token expiration value is explicitly selected to be much // larger than TestS3RestSigner would need to execute all tests. @@ -124,16 +120,23 @@ public static void afterClass() throws Exception { // there aren't other token refreshes being scheduled after every sign request and after // TestS3RestSigner completes all tests, there should be only this single token in the queue // that is scheduled for refresh - assertThat(executor.getPoolSize()).isEqualTo(1); - assertThat(executor.getQueue()) - .as("should only have a single token scheduled for refresh") - .hasSize(1); - assertThat(executor.getActiveCount()) - .as("should not have any token being refreshed") - .isEqualTo(0); - assertThat(executor.getCompletedTaskCount()) - .as("should not have any expired token that required a refresh") - .isEqualTo(0); + assertThat(validatingSigner.icebergSigner) + .extracting("authManager") + .extracting("refreshExecutor") + .asInstanceOf(type(ScheduledThreadPoolExecutor.class)) + .satisfies( + executor -> { + assertThat(executor.getPoolSize()).isEqualTo(1); + assertThat(executor.getQueue()) + .as("should only have a single token scheduled for refresh") + .hasSize(1); + assertThat(executor.getActiveCount()) + .as("should not have any token being refreshed") + .isEqualTo(0); + assertThat(executor.getCompletedTaskCount()) + .as("should not have any expired token that required a refresh") + .isEqualTo(0); + }); if (null != httpServer) { httpServer.stop(); diff --git a/aws/src/test/java/org/apache/iceberg/aws/s3/signer/TestS3V4RestSignerClient.java b/aws/src/test/java/org/apache/iceberg/aws/s3/signer/TestS3V4RestSignerClient.java new file mode 100644 index 000000000000..821e16443548 --- /dev/null +++ b/aws/src/test/java/org/apache/iceberg/aws/s3/signer/TestS3V4RestSignerClient.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.aws.s3.signer; + +import static org.apache.iceberg.aws.s3.signer.S3V4RestSignerClient.S3_SIGNER_URI; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.type; +import static org.mockito.Mockito.when; + +import java.util.Map; +import java.util.stream.Stream; +import org.apache.iceberg.rest.RESTClient; +import org.apache.iceberg.rest.auth.AuthProperties; +import org.apache.iceberg.rest.auth.AuthSession; +import org.apache.iceberg.rest.auth.OAuth2Properties; +import org.apache.iceberg.rest.auth.OAuth2Util; +import org.apache.iceberg.rest.responses.OAuthTokenResponse; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; + +class TestS3V4RestSignerClient { + + @BeforeAll + static void beforeAll() { + S3V4RestSignerClient.httpClient = Mockito.mock(RESTClient.class); + when(S3V4RestSignerClient.httpClient.withAuthSession(Mockito.any())) + .thenReturn(S3V4RestSignerClient.httpClient); + when(S3V4RestSignerClient.httpClient.postForm( + Mockito.anyString(), + Mockito.eq( + Map.of( + "grant_type", + "client_credentials", + "client_id", + "user", + "client_secret", + "12345", + "scope", + "sign")), + Mockito.eq(OAuthTokenResponse.class), + Mockito.anyMap(), + Mockito.any())) + .thenReturn( + OAuthTokenResponse.builder().withToken("token").withTokenType("Bearer").build()); + } + + @AfterAll + static void afterAll() { + S3V4RestSignerClient.httpClient = null; + } + + @ParameterizedTest + @MethodSource("validOAuth2Properties") + void authSessionOAuth2(Map properties, String expectedToken) throws Exception { + try (S3V4RestSignerClient client = + ImmutableS3V4RestSignerClient.builder().properties(properties).build(); + AuthSession authSession = client.authSession()) { + + if (expectedToken == null) { + assertThat(authSession).isInstanceOf(AuthSession.class); + } else { + assertThat(authSession) + .asInstanceOf(type(OAuth2Util.AuthSession.class)) + .extracting(OAuth2Util.AuthSession::headers) + .satisfies( + headers -> + assertThat(headers).containsEntry("Authorization", "Bearer " + expectedToken)); + } + } + } + + public static Stream validOAuth2Properties() { + return Stream.of( + // No OAuth2 data + Arguments.of(Map.of(S3_SIGNER_URI, "https://signer.com"), null), + // Token only + Arguments.of( + Map.of( + S3_SIGNER_URI, + "https://signer.com", + AuthProperties.AUTH_TYPE, + AuthProperties.AUTH_TYPE_OAUTH2, + OAuth2Properties.TOKEN, + "token"), + "token"), + // Credential only: expect a token to be fetched + Arguments.of( + Map.of( + S3_SIGNER_URI, + "https://signer.com", + AuthProperties.AUTH_TYPE, + AuthProperties.AUTH_TYPE_OAUTH2, + OAuth2Properties.CREDENTIAL, + "user:12345"), + "token"), + // Token and credential: should use token as is, not fetch a new one + Arguments.of( + Map.of( + S3_SIGNER_URI, + "https://signer.com", + AuthProperties.AUTH_TYPE, + AuthProperties.AUTH_TYPE_OAUTH2, + OAuth2Properties.TOKEN, + "token", + OAuth2Properties.CREDENTIAL, + "user:12345"), + "token")); + } +} diff --git a/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java b/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java index dfb1d5827ef9..72c8dc7ab228 100644 --- a/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java +++ b/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java @@ -40,7 +40,6 @@ import org.apache.hc.core5.http.Header; import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.HttpHost; -import org.apache.hc.core5.http.HttpRequestInterceptor; import org.apache.hc.core5.http.HttpStatus; import org.apache.hc.core5.http.ParseException; import org.apache.hc.core5.http.impl.EnglishReasonPhraseCatalog; @@ -48,8 +47,6 @@ import org.apache.hc.core5.http.io.entity.StringEntity; import org.apache.hc.core5.io.CloseMode; import org.apache.iceberg.IcebergBuild; -import org.apache.iceberg.common.DynConstructors; -import org.apache.iceberg.common.DynMethods; import org.apache.iceberg.exceptions.RESTException; import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; @@ -65,9 +62,6 @@ public class HTTPClient extends BaseHTTPClient { private static final Logger LOG = LoggerFactory.getLogger(HTTPClient.class); - private static final String SIGV4_ENABLED = "rest.sigv4-enabled"; - private static final String SIGV4_REQUEST_INTERCEPTOR_IMPL = - "org.apache.iceberg.aws.RESTSigV4Signer"; @VisibleForTesting static final String CLIENT_VERSION_HEADER = "X-Client-Version"; @VisibleForTesting @@ -96,7 +90,6 @@ private HTTPClient( CredentialsProvider proxyCredsProvider, Map baseHeaders, ObjectMapper objectMapper, - HttpRequestInterceptor requestInterceptor, Map properties, HttpClientConnectionManager connectionManager, AuthSession session) { @@ -109,10 +102,6 @@ private HTTPClient( clientBuilder.setConnectionManager(connectionManager); - if (requestInterceptor != null) { - clientBuilder.addRequestInterceptorLast(requestInterceptor); - } - int maxRetries = PropertyUtil.propertyAsInt(properties, REST_MAX_RETRIES, 5); clientBuilder.setRetryStrategy(new ExponentialHttpRequestRetryStrategy(maxRetries)); @@ -339,41 +328,6 @@ public void close() throws IOException { } } - @VisibleForTesting - static HttpRequestInterceptor loadInterceptorDynamically( - String impl, Map properties) { - HttpRequestInterceptor instance; - - DynConstructors.Ctor ctor; - try { - ctor = - DynConstructors.builder(HttpRequestInterceptor.class) - .loader(HTTPClient.class.getClassLoader()) - .impl(impl) - .buildChecked(); - } catch (NoSuchMethodException e) { - throw new IllegalArgumentException( - String.format( - "Cannot initialize RequestInterceptor, missing no-arg constructor: %s", impl), - e); - } - - try { - instance = ctor.newInstance(); - } catch (ClassCastException e) { - throw new IllegalArgumentException( - String.format("Cannot initialize, %s does not implement RequestInterceptor", impl), e); - } - - DynMethods.builder("initialize") - .hiddenImpl(impl, Map.class) - .orNoop() - .build(instance) - .invoke(properties); - - return instance; - } - static HttpClientConnectionManager configureConnectionManager(Map properties) { PoolingHttpClientConnectionManagerBuilder connectionManagerBuilder = PoolingHttpClientConnectionManagerBuilder.create(); @@ -489,12 +443,6 @@ public HTTPClient build() { withHeader(CLIENT_VERSION_HEADER, IcebergBuild.fullVersion()); withHeader(CLIENT_GIT_COMMIT_SHORT_HEADER, IcebergBuild.gitCommitShortId()); - HttpRequestInterceptor interceptor = null; - - if (PropertyUtil.propertyAsBoolean(properties, SIGV4_ENABLED, false)) { - interceptor = loadInterceptorDynamically(SIGV4_REQUEST_INTERCEPTOR_IMPL, properties); - } - if (this.proxyCredentialsProvider != null) { Preconditions.checkNotNull( proxy, "Invalid http client proxy for proxy credentials provider: null"); @@ -506,7 +454,6 @@ public HTTPClient build() { proxyCredentialsProvider, baseHeaders, mapper, - interceptor, properties, configureConnectionManager(properties), authSession); diff --git a/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java b/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java index d9542bf4a547..ab56b113d66e 100644 --- a/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java +++ b/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java @@ -18,23 +18,15 @@ */ package org.apache.iceberg.rest; -import com.github.benmanes.caffeine.cache.Cache; -import com.github.benmanes.caffeine.cache.Caffeine; -import com.github.benmanes.caffeine.cache.RemovalListener; import java.io.Closeable; import java.io.IOException; import java.io.UncheckedIOException; -import java.time.Duration; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Set; -import java.util.concurrent.Future; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; import java.util.function.BiFunction; import java.util.function.Function; -import java.util.function.Supplier; import org.apache.iceberg.BaseTable; import org.apache.iceberg.CatalogProperties; import org.apache.iceberg.CatalogUtil; @@ -70,11 +62,9 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Maps; -import org.apache.iceberg.rest.auth.AuthConfig; -import org.apache.iceberg.rest.auth.DefaultAuthSession; -import org.apache.iceberg.rest.auth.OAuth2Properties; -import org.apache.iceberg.rest.auth.OAuth2Util; -import org.apache.iceberg.rest.auth.OAuth2Util.AuthSession; +import org.apache.iceberg.rest.auth.AuthManager; +import org.apache.iceberg.rest.auth.AuthManagers; +import org.apache.iceberg.rest.auth.AuthSession; import org.apache.iceberg.rest.requests.CommitTransactionRequest; import org.apache.iceberg.rest.requests.CreateNamespaceRequest; import org.apache.iceberg.rest.requests.CreateTableRequest; @@ -92,12 +82,9 @@ import org.apache.iceberg.rest.responses.ListTablesResponse; import org.apache.iceberg.rest.responses.LoadTableResponse; import org.apache.iceberg.rest.responses.LoadViewResponse; -import org.apache.iceberg.rest.responses.OAuthTokenResponse; import org.apache.iceberg.rest.responses.UpdateNamespacePropertiesResponse; import org.apache.iceberg.util.EnvironmentUtil; -import org.apache.iceberg.util.Pair; import org.apache.iceberg.util.PropertyUtil; -import org.apache.iceberg.util.ThreadPools; import org.apache.iceberg.view.BaseView; import org.apache.iceberg.view.ImmutableSQLViewRepresentation; import org.apache.iceberg.view.ImmutableViewVersion; @@ -120,20 +107,6 @@ public class RESTSessionCatalog extends BaseViewSessionCatalog // server supports view endpoints but doesn't send the "endpoints" field in the ConfigResponse static final String VIEW_ENDPOINTS_SUPPORTED = "view-endpoints-supported"; public static final String REST_PAGE_SIZE = "rest-page-size"; - private static final List TOKEN_PREFERENCE_ORDER = - ImmutableList.of( - OAuth2Properties.ID_TOKEN_TYPE, - OAuth2Properties.ACCESS_TOKEN_TYPE, - OAuth2Properties.JWT_TOKEN_TYPE, - OAuth2Properties.SAML2_TOKEN_TYPE, - OAuth2Properties.SAML1_TOKEN_TYPE); - - // Auth-related properties that are allowed to be passed to the table session - private static final Set TABLE_SESSION_ALLOW_LIST = - ImmutableSet.builder() - .add(OAuth2Properties.TOKEN) - .addAll(TOKEN_PREFERENCE_ORDER) - .build(); // these default endpoints must not be updated in order to maintain backwards compatibility with // legacy servers @@ -169,11 +142,9 @@ public class RESTSessionCatalog extends BaseViewSessionCatalog private final Function, RESTClient> clientBuilder; private final BiFunction, FileIO> ioBuilder; - private Cache sessions = null; - private Cache tableSessions = null; private FileIOTracker fileIOTracker = null; private AuthSession catalogAuth = null; - private boolean keepTokenRefreshed = true; + private AuthManager authManager; private RESTClient client = null; private ResourcePaths paths = null; private SnapshotMode snapshotMode = null; @@ -185,9 +156,6 @@ public class RESTSessionCatalog extends BaseViewSessionCatalog private CloseableGroup closeables = null; private Set endpoints; - // a lazy thread pool for token refresh - private volatile ScheduledExecutorService refreshExecutor = null; - enum SnapshotMode { ALL, REFS; @@ -209,7 +177,6 @@ public RESTSessionCatalog( this.ioBuilder = ioBuilder; } - @SuppressWarnings("checkstyle:CyclomaticComplexity") @Override public void initialize(String name, Map unresolved) { Preconditions.checkArgument(unresolved != null, "Invalid configuration: null"); @@ -218,55 +185,18 @@ public void initialize(String name, Map unresolved) { // catalog service Map props = EnvironmentUtil.resolveAll(unresolved); - long startTimeMillis = - System.currentTimeMillis(); // keep track of the init start time for token refresh - String initToken = props.get(OAuth2Properties.TOKEN); - boolean hasInitToken = initToken != null; + this.authManager = AuthManagers.loadAuthManager(name, props); - // fetch auth and config to complete initialization ConfigResponse config; - OAuthTokenResponse authResponse; - String credential = props.get(OAuth2Properties.CREDENTIAL); - boolean hasCredential = credential != null && !credential.isEmpty(); - String scope = props.getOrDefault(OAuth2Properties.SCOPE, OAuth2Properties.CATALOG_SCOPE); - Map optionalOAuthParams = OAuth2Util.buildOptionalParam(props); - if (!props.containsKey(OAuth2Properties.OAUTH2_SERVER_URI) - && (hasInitToken || hasCredential) - && !PropertyUtil.propertyAsBoolean(props, "rest.sigv4-enabled", false)) { - LOG.warn( - "Iceberg REST client is missing the OAuth2 server URI configuration and defaults to {}/{}. " - + "This automatic fallback will be removed in a future Iceberg release." - + "It is recommended to configure the OAuth2 endpoint using the '{}' property to be prepared. " - + "This warning will disappear if the OAuth2 endpoint is explicitly configured. " - + "See https://github.com/apache/iceberg/issues/10537", - RESTUtil.stripTrailingSlash(props.get(CatalogProperties.URI)), - ResourcePaths.tokens(), - OAuth2Properties.OAUTH2_SERVER_URI); - } - String oauth2ServerUri = - props.getOrDefault(OAuth2Properties.OAUTH2_SERVER_URI, ResourcePaths.tokens()); - try (DefaultAuthSession initSession = - DefaultAuthSession.of(HTTPHeaders.of(OAuth2Util.authHeaders(initToken))); - RESTClient initClient = clientBuilder.apply(props).withAuthSession(initSession)) { - Map initHeaders = configHeaders(props); - if (hasCredential) { - authResponse = - OAuth2Util.fetchToken( - initClient, initHeaders, credential, scope, oauth2ServerUri, optionalOAuthParams); - Map authHeaders = - RESTUtil.merge(initHeaders, OAuth2Util.authHeaders(authResponse.token())); - config = fetchConfig(initClient, authHeaders, props); - } else { - authResponse = null; - config = fetchConfig(initClient, initHeaders, props); - } + try (RESTClient initClient = clientBuilder.apply(props); + AuthSession initSession = authManager.initSession(initClient, props)) { + config = fetchConfig(initClient.withAuthSession(initSession), initSession, props); } catch (IOException e) { throw new UncheckedIOException("Failed to close HTTP client", e); } // build the final configuration and set up the catalog's auth Map mergedProps = config.merge(props); - Map baseHeaders = configHeaders(mergedProps); if (config.endpoints().isEmpty()) { this.endpoints = @@ -280,39 +210,10 @@ public void initialize(String name, Map unresolved) { this.endpoints = ImmutableSet.copyOf(config.endpoints()); } - this.sessions = newSessionCache(mergedProps); - this.tableSessions = newSessionCache(mergedProps); - this.keepTokenRefreshed = - PropertyUtil.propertyAsBoolean( - mergedProps, - OAuth2Properties.TOKEN_REFRESH_ENABLED, - OAuth2Properties.TOKEN_REFRESH_ENABLED_DEFAULT); + this.client = clientBuilder.apply(mergedProps); this.paths = ResourcePaths.forCatalogProperties(mergedProps); - String token = mergedProps.get(OAuth2Properties.TOKEN); - this.catalogAuth = - new AuthSession( - baseHeaders, - AuthConfig.builder() - .credential(credential) - .scope(scope) - .oauth2ServerUri(oauth2ServerUri) - .optionalOAuthParams(optionalOAuthParams) - .build()); - - this.client = clientBuilder.apply(mergedProps).withAuthSession(catalogAuth); - - if (authResponse != null) { - this.catalogAuth = - AuthSession.fromTokenResponse( - client, tokenRefreshExecutor(name), authResponse, startTimeMillis, catalogAuth); - this.client = client.withAuthSession(catalogAuth); - } else if (token != null) { - this.catalogAuth = - AuthSession.fromAccessToken( - client, tokenRefreshExecutor(name), token, expiresAtMillis(mergedProps), catalogAuth); - this.client = client.withAuthSession(catalogAuth); - } + this.catalogAuth = authManager.catalogSession(client, mergedProps); this.pageSize = PropertyUtil.propertyAsNullableInt(mergedProps, REST_PAGE_SIZE); if (pageSize != null) { @@ -324,6 +225,8 @@ public void initialize(String name, Map unresolved) { this.fileIOTracker = new FileIOTracker(); this.closeables = new CloseableGroup(); + this.closeables.addCloseable(this.catalogAuth); + this.closeables.addCloseable(this.authManager); this.closeables.addCloseable(this.io); this.closeables.addCloseable(this.client); this.closeables.addCloseable(fileIOTracker); @@ -342,27 +245,6 @@ public void initialize(String name, Map unresolved) { super.initialize(name, mergedProps); } - private AuthSession session(SessionContext context) { - AuthSession session = - sessions.get( - context.sessionId(), - id -> { - Pair> newSession = - newSession(context.credentials(), context.properties(), catalogAuth); - if (null != newSession) { - return newSession.second().get(); - } - - return null; - }); - - return session != null ? session : catalogAuth; - } - - private Supplier> headers(SessionContext context) { - return session(context)::headers; - } - @Override public void setConf(Object newConf) { this.conf = newConf; @@ -384,13 +266,16 @@ public List listTables(SessionContext context, Namespace ns) { do { queryParams.put("pageToken", pageToken); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); ListTablesResponse response = - client.get( - paths.tables(ns), - queryParams, - ListTablesResponse.class, - headers(context), - ErrorHandlers.namespaceErrorHandler()); + client + .withAuthSession(contextualSession) + .get( + paths.tables(ns), + queryParams, + ListTablesResponse.class, + Map.of(), + ErrorHandlers.namespaceErrorHandler()); pageToken = response.nextPageToken(); tables.addAll(response.identifiers()); } while (pageToken != null); @@ -404,8 +289,10 @@ public boolean dropTable(SessionContext context, TableIdentifier identifier) { checkIdentifierIsValid(identifier); try { - client.delete( - paths.table(identifier), null, headers(context), ErrorHandlers.tableErrorHandler()); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(contextualSession) + .delete(paths.table(identifier), null, Map.of(), ErrorHandlers.tableErrorHandler()); return true; } catch (NoSuchTableException e) { return false; @@ -418,12 +305,15 @@ public boolean purgeTable(SessionContext context, TableIdentifier identifier) { checkIdentifierIsValid(identifier); try { - client.delete( - paths.table(identifier), - ImmutableMap.of("purgeRequested", "true"), - null, - headers(context), - ErrorHandlers.tableErrorHandler()); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(contextualSession) + .delete( + paths.table(identifier), + ImmutableMap.of("purgeRequested", "true"), + null, + Map.of(), + ErrorHandlers.tableErrorHandler()); return true; } catch (NoSuchTableException e) { return false; @@ -440,7 +330,10 @@ public void renameTable(SessionContext context, TableIdentifier from, TableIdent RenameTableRequest.builder().withSource(from).withDestination(to).build(); // for now, ignore the response because there is no way to return it - client.post(paths.rename(), request, null, headers(context), ErrorHandlers.tableErrorHandler()); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(contextualSession) + .post(paths.rename(), request, null, Map.of(), ErrorHandlers.tableErrorHandler()); } @Override @@ -448,7 +341,10 @@ public boolean tableExists(SessionContext context, TableIdentifier identifier) { try { checkIdentifierIsValid(identifier); if (endpoints.contains(Endpoint.V1_TABLE_EXISTS)) { - client.head(paths.table(identifier), headers(context), ErrorHandlers.tableErrorHandler()); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(contextualSession) + .head(paths.table(identifier), Map.of(), ErrorHandlers.tableErrorHandler()); return true; } else { // fallback in order to work with 1.7.x and older servers @@ -462,12 +358,15 @@ public boolean tableExists(SessionContext context, TableIdentifier identifier) { private LoadTableResponse loadInternal( SessionContext context, TableIdentifier identifier, SnapshotMode mode) { Endpoint.check(endpoints, Endpoint.V1_LOAD_TABLE); - return client.get( - paths.table(identifier), - mode.params(), - LoadTableResponse.class, - headers(context), - ErrorHandlers.tableErrorHandler()); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + return client + .withAuthSession(contextualSession) + .get( + paths.table(identifier), + mode.params(), + LoadTableResponse.class, + Map.of(), + ErrorHandlers.tableErrorHandler()); } @Override @@ -509,7 +408,10 @@ public Table loadTable(SessionContext context, TableIdentifier identifier) { } TableIdentifier finalIdentifier = loadedIdent; - AuthSession session = tableSession(response.config(), session(context)); + Map tableConf = response.config(); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + AuthSession tableSession = + authManager.tableSession(finalIdentifier, tableConf, contextualSession); TableMetadata tableMetadata; if (snapshotMode == SnapshotMode.REFS) { @@ -528,11 +430,12 @@ public Table loadTable(SessionContext context, TableIdentifier identifier) { tableMetadata = response.tableMetadata(); } + RESTClient tableClient = client.withAuthSession(tableSession); RESTTableOperations ops = new RESTTableOperations( - client, + tableClient, paths.table(finalIdentifier), - session::headers, + Map::of, tableFileIO(context, response.config()), tableMetadata, endpoints); @@ -543,7 +446,7 @@ public Table loadTable(SessionContext context, TableIdentifier identifier) { new BaseTable( ops, fullTableName(finalIdentifier), - metricsReporter(paths.metrics(finalIdentifier), session::headers)); + metricsReporter(paths.metrics(finalIdentifier), tableClient)); if (metadataType != null) { return MetadataTableUtils.createMetadataTableInstance(table, metadataType); } @@ -557,11 +460,10 @@ private void trackFileIO(RESTTableOperations ops) { } } - private MetricsReporter metricsReporter( - String metricsEndpoint, Supplier> headers) { + private MetricsReporter metricsReporter(String metricsEndpoint, RESTClient restClient) { if (reportingViaRestEnabled && endpoints.contains(Endpoint.V1_REPORT_METRICS)) { RESTMetricsReporter restMetricsReporter = - new RESTMetricsReporter(client, metricsEndpoint, headers); + new RESTMetricsReporter(restClient, metricsEndpoint, Map::of); return MetricsReporters.combine(reporter, restMetricsReporter); } else { return this.reporter; @@ -594,20 +496,25 @@ public Table registerTable( .metadataLocation(metadataFileLocation) .build(); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); LoadTableResponse response = - client.post( - paths.register(ident.namespace()), - request, - LoadTableResponse.class, - headers(context), - ErrorHandlers.tableErrorHandler()); - - AuthSession session = tableSession(response.config(), session(context)); + client + .withAuthSession(contextualSession) + .post( + paths.register(ident.namespace()), + request, + LoadTableResponse.class, + Map.of(), + ErrorHandlers.tableErrorHandler()); + + Map tableConf = response.config(); + AuthSession tableSession = authManager.tableSession(ident, tableConf, contextualSession); + RESTClient tableClient = client.withAuthSession(tableSession); RESTTableOperations ops = new RESTTableOperations( - client, + tableClient, paths.table(ident), - session::headers, + Map::of, tableFileIO(context, response.config()), response.tableMetadata(), endpoints); @@ -615,7 +522,7 @@ public Table registerTable( trackFileIO(ops); return new BaseTable( - ops, fullTableName(ident), metricsReporter(paths.metrics(ident), session::headers)); + ops, fullTableName(ident), metricsReporter(paths.metrics(ident), tableClient)); } @Override @@ -626,12 +533,15 @@ public void createNamespace( CreateNamespaceRequest.builder().withNamespace(namespace).setProperties(metadata).build(); // for now, ignore the response because there is no way to return it - client.post( - paths.namespaces(), - request, - CreateNamespaceResponse.class, - headers(context), - ErrorHandlers.namespaceErrorHandler()); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(contextualSession) + .post( + paths.namespaces(), + request, + CreateNamespaceResponse.class, + Map.of(), + ErrorHandlers.namespaceErrorHandler()); } @Override @@ -653,13 +563,16 @@ public List listNamespaces(SessionContext context, Namespace namespac do { queryParams.put("pageToken", pageToken); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); ListNamespacesResponse response = - client.get( - paths.namespaces(), - queryParams, - ListNamespacesResponse.class, - headers(context), - ErrorHandlers.namespaceErrorHandler()); + client + .withAuthSession(contextualSession) + .get( + paths.namespaces(), + queryParams, + ListNamespacesResponse.class, + Map.of(), + ErrorHandlers.namespaceErrorHandler()); pageToken = response.nextPageToken(); namespaces.addAll(response.namespaces()); } while (pageToken != null); @@ -672,8 +585,10 @@ public boolean namespaceExists(SessionContext context, Namespace namespace) { try { checkNamespaceIsValid(namespace); if (endpoints.contains(Endpoint.V1_NAMESPACE_EXISTS)) { - client.head( - paths.namespace(namespace), headers(context), ErrorHandlers.namespaceErrorHandler()); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(contextualSession) + .head(paths.namespace(namespace), Map.of(), ErrorHandlers.namespaceErrorHandler()); return true; } else { // fallback in order to work with 1.7.x and older servers @@ -690,12 +605,15 @@ public Map loadNamespaceMetadata(SessionContext context, Namespa checkNamespaceIsValid(ns); // TODO: rename to LoadNamespaceResponse? + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); GetNamespaceResponse response = - client.get( - paths.namespace(ns), - GetNamespaceResponse.class, - headers(context), - ErrorHandlers.namespaceErrorHandler()); + client + .withAuthSession(contextualSession) + .get( + paths.namespace(ns), + GetNamespaceResponse.class, + Map.of(), + ErrorHandlers.namespaceErrorHandler()); return response.properties(); } @@ -705,8 +623,10 @@ public boolean dropNamespace(SessionContext context, Namespace ns) { checkNamespaceIsValid(ns); try { - client.delete( - paths.namespace(ns), null, headers(context), ErrorHandlers.namespaceErrorHandler()); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(contextualSession) + .delete(paths.namespace(ns), null, Map.of(), ErrorHandlers.namespaceErrorHandler()); return true; } catch (NoSuchNamespaceException e) { return false; @@ -722,66 +642,27 @@ public boolean updateNamespaceMetadata( UpdateNamespacePropertiesRequest request = UpdateNamespacePropertiesRequest.builder().updateAll(updates).removeAll(removals).build(); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); UpdateNamespacePropertiesResponse response = - client.post( - paths.namespaceProperties(ns), - request, - UpdateNamespacePropertiesResponse.class, - headers(context), - ErrorHandlers.namespaceErrorHandler()); + client + .withAuthSession(contextualSession) + .post( + paths.namespaceProperties(ns), + request, + UpdateNamespacePropertiesResponse.class, + Map.of(), + ErrorHandlers.namespaceErrorHandler()); return !response.updated().isEmpty(); } - private ScheduledExecutorService tokenRefreshExecutor(String catalogName) { - if (!keepTokenRefreshed) { - return null; - } - - if (refreshExecutor == null) { - synchronized (this) { - if (refreshExecutor == null) { - this.refreshExecutor = ThreadPools.newScheduledPool(catalogName + "-token-refresh", 1); - } - } - } - - return refreshExecutor; - } - @Override public void close() throws IOException { - shutdownRefreshExecutor(); - if (closeables != null) { closeables.close(); } } - private void shutdownRefreshExecutor() { - if (refreshExecutor != null) { - ScheduledExecutorService service = refreshExecutor; - this.refreshExecutor = null; - - List tasks = service.shutdownNow(); - tasks.forEach( - task -> { - if (task instanceof Future) { - ((Future) task).cancel(true); - } - }); - - try { - if (!service.awaitTermination(1, TimeUnit.MINUTES)) { - LOG.warn("Timed out waiting for refresh executor to terminate"); - } - } catch (InterruptedException e) { - LOG.warn("Interrupted while waiting for refresh executor to terminate", e); - Thread.currentThread().interrupt(); - } - } - } - private class Builder implements Catalog.TableBuilder { private final TableIdentifier ident; private final Schema schema; @@ -859,20 +740,25 @@ public Table create() { .setProperties(propertiesBuilder.buildKeepingLast()) .build(); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); LoadTableResponse response = - client.post( - paths.tables(ident.namespace()), - request, - LoadTableResponse.class, - headers(context), - ErrorHandlers.tableErrorHandler()); - - AuthSession session = tableSession(response.config(), session(context)); + client + .withAuthSession(contextualSession) + .post( + paths.tables(ident.namespace()), + request, + LoadTableResponse.class, + Map.of(), + ErrorHandlers.tableErrorHandler()); + + Map tableConf = response.config(); + AuthSession tableSession = authManager.tableSession(ident, tableConf, contextualSession); + RESTClient tableClient = client.withAuthSession(tableSession); RESTTableOperations ops = new RESTTableOperations( - client, + tableClient, paths.table(ident), - session::headers, + Map::of, tableFileIO(context, response.config()), response.tableMetadata(), endpoints); @@ -880,7 +766,7 @@ public Table create() { trackFileIO(ops); return new BaseTable( - ops, fullTableName(ident), metricsReporter(paths.metrics(ident), session::headers)); + ops, fullTableName(ident), metricsReporter(paths.metrics(ident), tableClient)); } @Override @@ -889,14 +775,17 @@ public Transaction createTransaction() { LoadTableResponse response = stageCreate(); String fullName = fullTableName(ident); - AuthSession session = tableSession(response.config(), session(context)); + Map tableConf = response.config(); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + AuthSession tableSession = authManager.tableSession(ident, tableConf, contextualSession); TableMetadata meta = response.tableMetadata(); + RESTClient tableClient = client.withAuthSession(tableSession); RESTTableOperations ops = new RESTTableOperations( - client, + tableClient, paths.table(ident), - session::headers, + Map::of, tableFileIO(context, response.config()), RESTTableOperations.UpdateType.CREATE, createChanges(meta), @@ -906,7 +795,7 @@ public Transaction createTransaction() { trackFileIO(ops); return Transactions.createTableTransaction( - fullName, ops, meta, metricsReporter(paths.metrics(ident), session::headers)); + fullName, ops, meta, metricsReporter(paths.metrics(ident), tableClient)); } @Override @@ -919,7 +808,9 @@ public Transaction replaceTransaction() { LoadTableResponse response = loadInternal(context, ident, snapshotMode); String fullName = fullTableName(ident); - AuthSession session = tableSession(response.config(), session(context)); + Map tableConf = response.config(); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + AuthSession tableSession = authManager.tableSession(ident, tableConf, contextualSession); TableMetadata base = response.tableMetadata(); Map tableProperties = propertiesBuilder.buildKeepingLast(); @@ -951,11 +842,12 @@ public Transaction replaceTransaction() { changes.add(new MetadataUpdate.SetDefaultSortOrder(replacement.defaultSortOrderId())); } + RESTClient tableClient = client.withAuthSession(tableSession); RESTTableOperations ops = new RESTTableOperations( - client, + tableClient, paths.table(ident), - session::headers, + Map::of, tableFileIO(context, response.config()), RESTTableOperations.UpdateType.REPLACE, changes.build(), @@ -965,7 +857,7 @@ public Transaction replaceTransaction() { trackFileIO(ops); return Transactions.replaceTableTransaction( - fullName, ops, replacement, metricsReporter(paths.metrics(ident), session::headers)); + fullName, ops, replacement, metricsReporter(paths.metrics(ident), tableClient)); } @Override @@ -997,12 +889,15 @@ private LoadTableResponse stageCreate() { .setProperties(tableProperties) .build(); - return client.post( - paths.tables(ident.namespace()), - request, - LoadTableResponse.class, - headers(context), - ErrorHandlers.tableErrorHandler()); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + return client + .withAuthSession(contextualSession) + .post( + paths.tables(ident.namespace()), + request, + LoadTableResponse.class, + Map.of(), + ErrorHandlers.tableErrorHandler()); } } @@ -1068,26 +963,8 @@ private FileIO tableFileIO(SessionContext context, Map config) { return newFileIO(context, fullConf); } - private AuthSession tableSession(Map tableConf, AuthSession parent) { - Map credentials = Maps.newHashMapWithExpectedSize(tableConf.size()); - for (String prop : tableConf.keySet()) { - if (TABLE_SESSION_ALLOW_LIST.contains(prop)) { - credentials.put(prop, tableConf.get(prop)); - } - } - - Pair> newSession = newSession(credentials, tableConf, parent); - if (null == newSession) { - return parent; - } - - AuthSession session = tableSessions.get(newSession.first(), id -> newSession.second().get()); - - return session != null ? session : parent; - } - private static ConfigResponse fetchConfig( - RESTClient client, Map headers, Map properties) { + RESTClient client, AuthSession initialAuth, Map properties) { // send the client's warehouse location to the service to keep in sync // this is needed for cases where the warehouse is configured client side, but may be used on // the server side, @@ -1101,76 +978,18 @@ private static ConfigResponse fetchConfig( } ConfigResponse configResponse = - client.get( - ResourcePaths.config(), - queryParams.build(), - ConfigResponse.class, - headers, - ErrorHandlers.defaultErrorHandler()); + client + .withAuthSession(initialAuth) + .get( + ResourcePaths.config(), + queryParams.build(), + ConfigResponse.class, + configHeaders(properties), + ErrorHandlers.defaultErrorHandler()); configResponse.validate(); return configResponse; } - private Pair> newSession( - Map credentials, Map properties, AuthSession parent) { - if (credentials != null) { - // use the bearer token without exchanging - if (credentials.containsKey(OAuth2Properties.TOKEN)) { - return Pair.of( - credentials.get(OAuth2Properties.TOKEN), - () -> - AuthSession.fromAccessToken( - client, - tokenRefreshExecutor(name()), - credentials.get(OAuth2Properties.TOKEN), - expiresAtMillis(properties), - parent)); - } - - if (credentials.containsKey(OAuth2Properties.CREDENTIAL)) { - // fetch a token using the client credentials flow - return Pair.of( - credentials.get(OAuth2Properties.CREDENTIAL), - () -> - AuthSession.fromCredential( - client, - tokenRefreshExecutor(name()), - credentials.get(OAuth2Properties.CREDENTIAL), - parent)); - } - - for (String tokenType : TOKEN_PREFERENCE_ORDER) { - if (credentials.containsKey(tokenType)) { - // exchange the token for an access token using the token exchange flow - return Pair.of( - credentials.get(tokenType), - () -> - AuthSession.fromTokenExchange( - client, - tokenRefreshExecutor(name()), - credentials.get(tokenType), - tokenType, - parent)); - } - } - } - - return null; - } - - private Long expiresAtMillis(Map properties) { - if (properties.containsKey(OAuth2Properties.TOKEN_EXPIRES_IN_MS)) { - long expiresInMillis = - PropertyUtil.propertyAsLong( - properties, - OAuth2Properties.TOKEN_EXPIRES_IN_MS, - OAuth2Properties.TOKEN_EXPIRES_IN_MS_DEFAULT); - return System.currentTimeMillis() + expiresInMillis; - } else { - return null; - } - } - private void checkIdentifierIsValid(TableIdentifier tableIdentifier) { if (tableIdentifier.namespace().isEmpty()) { throw new NoSuchTableException("Invalid table identifier: %s", tableIdentifier); @@ -1193,25 +1012,6 @@ private static Map configHeaders(Map properties) return RESTUtil.extractPrefixMap(properties, "header."); } - private static Cache newSessionCache(Map properties) { - long expirationIntervalMs = - PropertyUtil.propertyAsLong( - properties, - CatalogProperties.AUTH_SESSION_TIMEOUT_MS, - CatalogProperties.AUTH_SESSION_TIMEOUT_MS_DEFAULT); - - return Caffeine.newBuilder() - .expireAfterAccess(Duration.ofMillis(expirationIntervalMs)) - .removalListener( - (RemovalListener) - (id, auth, cause) -> { - if (auth != null) { - auth.stopRefreshing(); - } - }) - .build(); - } - public void commitTransaction(SessionContext context, List commits) { Endpoint.check(endpoints, Endpoint.V1_COMMIT_TRANSACTION); List tableChanges = Lists.newArrayListWithCapacity(commits.size()); @@ -1221,12 +1021,15 @@ public void commitTransaction(SessionContext context, List commits) UpdateTableRequest.create(commit.identifier(), commit.requirements(), commit.updates())); } - client.post( - paths.commitTransaction(), - new CommitTransactionRequest(tableChanges), - null, - headers(context), - ErrorHandlers.tableCommitHandler()); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(contextualSession) + .post( + paths.commitTransaction(), + new CommitTransactionRequest(tableChanges), + null, + Map.of(), + ErrorHandlers.tableCommitHandler()); } @Override @@ -1245,13 +1048,16 @@ public List listViews(SessionContext context, Namespace namespa do { queryParams.put("pageToken", pageToken); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); ListTablesResponse response = - client.get( - paths.views(namespace), - queryParams, - ListTablesResponse.class, - headers(context), - ErrorHandlers.namespaceErrorHandler()); + client + .withAuthSession(contextualSession) + .get( + paths.views(namespace), + queryParams, + ListTablesResponse.class, + Map.of(), + ErrorHandlers.namespaceErrorHandler()); pageToken = response.nextPageToken(); views.addAll(response.identifiers()); } while (pageToken != null); @@ -1264,7 +1070,10 @@ public boolean viewExists(SessionContext context, TableIdentifier identifier) { try { checkViewIdentifierIsValid(identifier); if (endpoints.contains(Endpoint.V1_VIEW_EXISTS)) { - client.head(paths.view(identifier), headers(context), ErrorHandlers.viewErrorHandler()); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(contextualSession) + .head(paths.view(identifier), Map.of(), ErrorHandlers.viewErrorHandler()); return true; } else { // fallback in order to work with 1.7.x and older servers @@ -1287,19 +1096,27 @@ public View loadView(SessionContext context, TableIdentifier identifier) { checkViewIdentifierIsValid(identifier); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); LoadViewResponse response = - client.get( - paths.view(identifier), - LoadViewResponse.class, - headers(context), - ErrorHandlers.viewErrorHandler()); - - AuthSession session = tableSession(response.config(), session(context)); + client + .withAuthSession(contextualSession) + .get( + paths.view(identifier), + LoadViewResponse.class, + Map.of(), + ErrorHandlers.viewErrorHandler()); + + Map tableConf = response.config(); + AuthSession tableSession = authManager.tableSession(identifier, tableConf, contextualSession); ViewMetadata metadata = response.metadata(); RESTViewOperations ops = new RESTViewOperations( - client, paths.view(identifier), session::headers, metadata, endpoints); + client.withAuthSession(tableSession), + paths.view(identifier), + Map::of, + metadata, + endpoints); return new BaseView(ops, ViewUtil.fullViewName(name(), identifier)); } @@ -1315,8 +1132,10 @@ public boolean dropView(SessionContext context, TableIdentifier identifier) { checkViewIdentifierIsValid(identifier); try { - client.delete( - paths.view(identifier), null, headers(context), ErrorHandlers.viewErrorHandler()); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(contextualSession) + .delete(paths.view(identifier), null, Map.of(), ErrorHandlers.viewErrorHandler()); return true; } catch (NoSuchViewException e) { return false; @@ -1332,8 +1151,10 @@ public void renameView(SessionContext context, TableIdentifier from, TableIdenti RenameTableRequest request = RenameTableRequest.builder().withSource(from).withDestination(to).build(); - client.post( - paths.renameView(), request, null, headers(context), ErrorHandlers.viewErrorHandler()); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(contextualSession) + .post(paths.renameView(), request, null, Map.of(), ErrorHandlers.viewErrorHandler()); } private class RESTViewBuilder implements ViewBuilder { @@ -1439,18 +1260,26 @@ public View create() { .properties(properties) .build(); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); LoadViewResponse response = - client.post( - paths.views(identifier.namespace()), - request, - LoadViewResponse.class, - headers(context), - ErrorHandlers.viewErrorHandler()); - - AuthSession session = tableSession(response.config(), session(context)); + client + .withAuthSession(contextualSession) + .post( + paths.views(identifier.namespace()), + request, + LoadViewResponse.class, + Map.of(), + ErrorHandlers.viewErrorHandler()); + + Map tableConf = response.config(); + AuthSession tableSession = authManager.tableSession(identifier, tableConf, contextualSession); RESTViewOperations ops = new RESTViewOperations( - client, paths.view(identifier), session::headers, response.metadata(), endpoints); + client.withAuthSession(tableSession), + paths.view(identifier), + Map::of, + response.metadata(), + endpoints); return new BaseView(ops, ViewUtil.fullViewName(name(), identifier)); } @@ -1482,11 +1311,14 @@ private LoadViewResponse loadView() { "Unable to load view %s.%s: Server does not support endpoint %s", name(), identifier, Endpoint.V1_LOAD_VIEW)); - return client.get( - paths.view(identifier), - LoadViewResponse.class, - headers(context), - ErrorHandlers.viewErrorHandler()); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + return client + .withAuthSession(contextualSession) + .get( + paths.view(identifier), + LoadViewResponse.class, + Map.of(), + ErrorHandlers.viewErrorHandler()); } private View replace(LoadViewResponse response) { @@ -1527,10 +1359,16 @@ private View replace(LoadViewResponse response) { ViewMetadata replacement = builder.build(); - AuthSession session = tableSession(response.config(), session(context)); + Map tableConf = response.config(); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + AuthSession tableSession = authManager.tableSession(identifier, tableConf, contextualSession); RESTViewOperations ops = new RESTViewOperations( - client, paths.view(identifier), session::headers, metadata, endpoints); + client.withAuthSession(tableSession), + paths.view(identifier), + Map::of, + metadata, + endpoints); ops.commit(metadata, replacement); diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/AuthManager.java b/core/src/main/java/org/apache/iceberg/rest/auth/AuthManager.java index 8f6f16f925e3..ab79051fe14f 100644 --- a/core/src/main/java/org/apache/iceberg/rest/auth/AuthManager.java +++ b/core/src/main/java/org/apache/iceberg/rest/auth/AuthManager.java @@ -62,6 +62,16 @@ default AuthSession initSession(RESTClient initClient, Map prope */ AuthSession catalogSession(RESTClient sharedClient, Map properties); + /** + * Returns a new session targeting a table or view. This method is intended for components other + * that the catalog that need to access tables or views, such as request signer clients. + * + *

This method cannot return null. + */ + default AuthSession tableSession(RESTClient sharedClient, Map properties) { + return catalogSession(sharedClient, properties); + } + /** * Returns a session for a specific context. * diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Manager.java b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Manager.java index df9018787580..bd18d142dade 100644 --- a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Manager.java +++ b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Manager.java @@ -58,7 +58,7 @@ public class OAuth2Manager extends RefreshingAuthManager { private final String name; - private RESTClient client; + private RESTClient refreshClient; private long startTimeMillis; private OAuthTokenResponse authResponse; private AuthSessionCache sessionCache; @@ -71,7 +71,11 @@ public OAuth2Manager(String managerName) { @Override public OAuth2Util.AuthSession initSession(RESTClient initClient, Map properties) { warnIfDeprecatedTokenEndpointUsed(properties); - AuthConfig config = AuthConfig.fromProperties(properties); + AuthConfig config = + ImmutableAuthConfig.builder() + .from(AuthConfig.fromProperties(properties)) + .keepRefreshed(false) // no token refresh during init + .build(); Map headers = OAuth2Util.authHeaders(config.token()); OAuth2Util.AuthSession session = new OAuth2Util.AuthSession(headers, config); if (config.credential() != null && !config.credential().isEmpty()) { @@ -81,8 +85,8 @@ public OAuth2Util.AuthSession initSession(RESTClient initClient, Map properties) { - this.client = sharedClient; + // This client will be used for token refreshes; it should not have an auth session. + this.refreshClient = sharedClient.withAuthSession(AuthSession.EMPTY); this.sessionCache = newSessionCache(name, properties); AuthConfig config = AuthConfig.fromProperties(properties); Map headers = OAuth2Util.authHeaders(config.token()); @@ -109,21 +114,22 @@ public OAuth2Util.AuthSession catalogSession( // so reuse it now and turn token refresh on. if (authResponse != null) { return OAuth2Util.AuthSession.fromTokenResponse( - client, refreshExecutor(), authResponse, startTimeMillis, session); + refreshClient, refreshExecutor(), authResponse, startTimeMillis, session); + } else if (config.token() != null) { + // If both a token and a credential are provided, prefer the token. + return OAuth2Util.AuthSession.fromAccessToken( + refreshClient, refreshExecutor(), config.token(), config.expiresAtMillis(), session); } else if (config.credential() != null && !config.credential().isEmpty()) { OAuthTokenResponse response = OAuth2Util.fetchToken( - sharedClient, - headers, + sharedClient.withAuthSession(session), + Map.of(), config.credential(), config.scope(), config.oauth2ServerUri(), config.optionalOAuthParams()); return OAuth2Util.AuthSession.fromTokenResponse( - sharedClient, refreshExecutor(), response, System.currentTimeMillis(), session); - } else if (config.token() != null) { - return OAuth2Util.AuthSession.fromAccessToken( - client, refreshExecutor(), config.token(), config.expiresAtMillis(), session); + refreshClient, refreshExecutor(), response, System.currentTimeMillis(), session); } return session; } @@ -205,18 +211,19 @@ protected OAuth2Util.AuthSession newSessionFromAccessToken( String token, Map properties, OAuth2Util.AuthSession parent) { Long expiresAtMillis = AuthConfig.fromProperties(properties).expiresAtMillis(); return OAuth2Util.AuthSession.fromAccessToken( - client, refreshExecutor(), token, expiresAtMillis, parent); + refreshClient, refreshExecutor(), token, expiresAtMillis, parent); } protected OAuth2Util.AuthSession newSessionFromCredential( String credential, OAuth2Util.AuthSession parent) { - return OAuth2Util.AuthSession.fromCredential(client, refreshExecutor(), credential, parent); + return OAuth2Util.AuthSession.fromCredential( + refreshClient, refreshExecutor(), credential, parent); } protected OAuth2Util.AuthSession newSessionFromTokenExchange( String token, String tokenType, OAuth2Util.AuthSession parent) { return OAuth2Util.AuthSession.fromTokenExchange( - client, refreshExecutor(), token, tokenType, parent); + refreshClient, refreshExecutor(), token, tokenType, parent); } private static void warnIfDeprecatedTokenEndpointUsed(Map properties) { diff --git a/core/src/test/java/org/apache/iceberg/rest/TestHTTPClient.java b/core/src/test/java/org/apache/iceberg/rest/TestHTTPClient.java index 13d555598a3b..af9b4045f8c2 100644 --- a/core/src/test/java/org/apache/iceberg/rest/TestHTTPClient.java +++ b/core/src/test/java/org/apache/iceberg/rest/TestHTTPClient.java @@ -44,12 +44,8 @@ import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManager; import org.apache.hc.client5.http.io.HttpClientConnectionManager; -import org.apache.hc.core5.http.EntityDetails; -import org.apache.hc.core5.http.HttpException; import org.apache.hc.core5.http.HttpHost; -import org.apache.hc.core5.http.HttpRequestInterceptor; import org.apache.hc.core5.http.HttpStatus; -import org.apache.hc.core5.http.protocol.HttpContext; import org.apache.iceberg.IcebergBuild; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.rest.auth.AuthSession; @@ -228,18 +224,6 @@ public void accept(ErrorResponse errorResponse) { } } - @Test - public void testDynamicHttpRequestInterceptorLoading() { - Map properties = ImmutableMap.of("key", "val"); - - HttpRequestInterceptor interceptor = - HTTPClient.loadInterceptorDynamically( - TestHttpRequestInterceptor.class.getName(), properties); - - assertThat(interceptor).isInstanceOf(TestHttpRequestInterceptor.class); - assertThat(((TestHttpRequestInterceptor) interceptor).properties).isEqualTo(properties); - } - @Test public void testSocketAndConnectionTimeoutSet() { long connectionTimeoutMs = 10L; @@ -484,17 +468,4 @@ public boolean equals(Object o) { return Objects.equals(id, item.id) && Objects.equals(data, item.data); } } - - public static class TestHttpRequestInterceptor implements HttpRequestInterceptor { - private Map properties; - - public void initialize(Map props) { - this.properties = props; - } - - @Override - public void process( - org.apache.hc.core5.http.HttpRequest request, EntityDetails entity, HttpContext context) - throws HttpException, IOException {} - } } diff --git a/core/src/test/java/org/apache/iceberg/rest/auth/TestOAuth2Manager.java b/core/src/test/java/org/apache/iceberg/rest/auth/TestOAuth2Manager.java index c9b6e74b78f5..5d9b023aa684 100644 --- a/core/src/test/java/org/apache/iceberg/rest/auth/TestOAuth2Manager.java +++ b/core/src/test/java/org/apache/iceberg/rest/auth/TestOAuth2Manager.java @@ -45,6 +45,7 @@ class TestOAuth2Manager { @BeforeEach void before() { client = Mockito.mock(RESTClient.class); + when(client.withAuthSession(any())).thenReturn(client); when(client.postForm(any(), any(), eq(OAuthTokenResponse.class), anyMap(), any())) .thenReturn( OAuthTokenResponse.builder() @@ -107,6 +108,8 @@ void initSessionCredentialsProvided() { eq(OAuthTokenResponse.class), eq(Map.of()), any()); + Mockito.verify(client).withAuthSession(any()); + Mockito.verifyNoMoreInteractions(client); } @Test @@ -120,7 +123,8 @@ void catalogSessionNoOAuth2Properties() { .as("should not create refresh executor when no token and no credentials provided") .isNull(); } - Mockito.verifyNoInteractions(client); + Mockito.verify(client).withAuthSession(any()); + Mockito.verifyNoMoreInteractions(client); } @Test @@ -134,7 +138,8 @@ void catalogSessionTokenProvided() { .as("should create refresh executor when token provided") .isNotNull(); } - Mockito.verifyNoInteractions(client); + Mockito.verify(client).withAuthSession(any()); + Mockito.verifyNoMoreInteractions(client); } @Test @@ -149,7 +154,8 @@ void catalogSessionRefreshDisabled() { .as("should not create refresh executor when refresh disabled") .isNull(); } - Mockito.verifyNoInteractions(client); + Mockito.verify(client).withAuthSession(any()); + Mockito.verifyNoMoreInteractions(client); } @Test @@ -178,6 +184,8 @@ void catalogSessionCredentialsProvidedWithInitSession() { eq(OAuthTokenResponse.class), eq(Map.of()), any()); + Mockito.verify(client, times(2)).withAuthSession(any()); + Mockito.verifyNoMoreInteractions(client); } @Test @@ -226,7 +234,8 @@ void contextualSessionEmptyContext() { .as("should not create session cache for empty context") .satisfies(cache -> assertThat(cache.sessionCache().asMap()).isEmpty()); } - Mockito.verifyNoInteractions(client); + Mockito.verify(client).withAuthSession(any()); + Mockito.verifyNoMoreInteractions(client); } @Test @@ -252,7 +261,8 @@ void contextualSessionTokenProvided() { .as("should create session cache for context with token") .satisfies(cache -> assertThat(cache.sessionCache().asMap()).hasSize(1)); } - Mockito.verifyNoInteractions(client); + Mockito.verify(client).withAuthSession(any()); + Mockito.verifyNoMoreInteractions(client); } @Test @@ -289,6 +299,8 @@ void contextualSessionCredentialsProvided() { eq(OAuthTokenResponse.class), eq(Map.of()), any()); + Mockito.verify(client).withAuthSession(any()); + Mockito.verifyNoMoreInteractions(client); } @Test @@ -326,6 +338,8 @@ void contextualSessionTokenExchange() { eq(OAuthTokenResponse.class), eq(Map.of("Authorization", "Bearer catalog-token")), any()); + Mockito.verify(client).withAuthSession(any()); + Mockito.verifyNoMoreInteractions(client); } @Test @@ -351,7 +365,8 @@ void contextualSessionCacheHit() { Mockito.verify(manager, times(1)) .newSessionFromAccessToken("context-token", Map.of(), catalogSession); } - Mockito.verifyNoInteractions(client); + Mockito.verify(client).withAuthSession(any()); + Mockito.verifyNoMoreInteractions(client); } @Test @@ -373,7 +388,8 @@ void tableSessionNoTableCredentials() { .as("should not create session cache for empty table credentials") .satisfies(cache -> assertThat(cache.sessionCache().asMap()).isEmpty()); } - Mockito.verifyNoInteractions(client); + Mockito.verify(client).withAuthSession(any()); + Mockito.verifyNoMoreInteractions(client); } @Test @@ -397,7 +413,8 @@ void tableSessionTokenProvided() { .as("should create session cache for table with token") .satisfies(cache -> assertThat(cache.sessionCache().asMap()).hasSize(1)); } - Mockito.verifyNoInteractions(client); + Mockito.verify(client).withAuthSession(any()); + Mockito.verifyNoMoreInteractions(client); } @Test @@ -434,6 +451,8 @@ void tableSessionTokenExchange() { eq(OAuthTokenResponse.class), eq(Map.of("Authorization", "Bearer catalog-token")), any()); + Mockito.verify(client).withAuthSession(any()); + Mockito.verifyNoMoreInteractions(client); } @Test @@ -458,7 +477,8 @@ void tableSessionCacheHit() { Mockito.verify(manager, times(1)) .newSessionFromAccessToken("table-token", Map.of("token", "table-token"), catalogSession); } - Mockito.verifyNoInteractions(client); + Mockito.verify(client).withAuthSession(any()); + Mockito.verifyNoMoreInteractions(client); } @Test @@ -483,7 +503,8 @@ void tableSessionDisallowedTableProperties() { .as("should not create session cache for ignored table credentials") .satisfies(cache -> assertThat(cache.sessionCache().asMap()).isEmpty()); } - Mockito.verifyNoInteractions(client); + Mockito.verify(client).withAuthSession(any()); + Mockito.verifyNoMoreInteractions(client); } @Test @@ -524,6 +545,7 @@ protected OAuth2Util.AuthSession newSessionFromAccessToken( Mockito.verify(contextualSession).close(); Mockito.verify(tableSession).close(); } - Mockito.verifyNoInteractions(client); + Mockito.verify(client).withAuthSession(any()); + Mockito.verifyNoMoreInteractions(client); } } diff --git a/gcp/src/main/java/org/apache/iceberg/gcp/gcs/GCSFileIO.java b/gcp/src/main/java/org/apache/iceberg/gcp/gcs/GCSFileIO.java index 5737606aef5e..b41be9c8f419 100644 --- a/gcp/src/main/java/org/apache/iceberg/gcp/gcs/GCSFileIO.java +++ b/gcp/src/main/java/org/apache/iceberg/gcp/gcs/GCSFileIO.java @@ -67,6 +67,7 @@ public class GCSFileIO implements DelegateFileIO { private MetricsContext metrics = MetricsContext.nullMetrics(); private final AtomicBoolean isResourceClosed = new AtomicBoolean(false); private SerializableMap properties = null; + private OAuth2RefreshCredentialsHandler refreshHandler = null; /** * No-arg constructor to load the FileIO dynamically. @@ -159,10 +160,11 @@ public void initialize(Map props) { new AccessToken(token, gcpProperties.oauth2TokenExpiresAt().orElse(null)); if (gcpProperties.oauth2RefreshCredentialsEnabled() && gcpProperties.oauth2RefreshCredentialsEndpoint().isPresent()) { + refreshHandler = OAuth2RefreshCredentialsHandler.create(properties); builder.setCredentials( OAuth2CredentialsWithRefresh.newBuilder() .setAccessToken(accessToken) - .setRefreshHandler(OAuth2RefreshCredentialsHandler.create(properties)) + .setRefreshHandler(refreshHandler) .build()); } else { builder.setCredentials(OAuth2Credentials.create(accessToken)); @@ -196,6 +198,9 @@ private void initMetrics(Map props) { public void close() { // handles concurrent calls to close() if (isResourceClosed.compareAndSet(false, true)) { + if (refreshHandler != null) { + refreshHandler.close(); + } if (storage != null) { // GCS Storage does not appear to be closable, so release the reference storage = null; diff --git a/gcp/src/main/java/org/apache/iceberg/gcp/gcs/OAuth2RefreshCredentialsHandler.java b/gcp/src/main/java/org/apache/iceberg/gcp/gcs/OAuth2RefreshCredentialsHandler.java index 6f7807e8dd6e..e350cc5af8f4 100644 --- a/gcp/src/main/java/org/apache/iceberg/gcp/gcs/OAuth2RefreshCredentialsHandler.java +++ b/gcp/src/main/java/org/apache/iceberg/gcp/gcs/OAuth2RefreshCredentialsHandler.java @@ -21,53 +21,48 @@ import com.google.auth.oauth2.AccessToken; import com.google.auth.oauth2.OAuth2CredentialsWithRefresh; import java.io.IOException; +import java.io.UncheckedIOException; import java.util.Date; import java.util.List; import java.util.Map; import java.util.stream.Collectors; import org.apache.iceberg.gcp.GCPProperties; +import org.apache.iceberg.io.CloseableGroup; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.rest.ErrorHandlers; import org.apache.iceberg.rest.HTTPClient; -import org.apache.iceberg.rest.HTTPHeaders; import org.apache.iceberg.rest.RESTClient; +import org.apache.iceberg.rest.auth.AuthManager; +import org.apache.iceberg.rest.auth.AuthManagers; import org.apache.iceberg.rest.auth.AuthSession; -import org.apache.iceberg.rest.auth.DefaultAuthSession; -import org.apache.iceberg.rest.auth.OAuth2Properties; -import org.apache.iceberg.rest.auth.OAuth2Util; import org.apache.iceberg.rest.credentials.Credential; import org.apache.iceberg.rest.responses.LoadCredentialsResponse; public class OAuth2RefreshCredentialsHandler - implements OAuth2CredentialsWithRefresh.OAuth2RefreshHandler { + implements OAuth2CredentialsWithRefresh.OAuth2RefreshHandler, AutoCloseable { private final Map properties; - private final AuthSession authSession; + private volatile HTTPClient client; + private AuthManager authManager; + private AuthSession authSession; private OAuth2RefreshCredentialsHandler(Map properties) { Preconditions.checkArgument( null != properties.get(GCPProperties.GCS_OAUTH2_REFRESH_CREDENTIALS_ENDPOINT), "Invalid credentials endpoint: null"); this.properties = properties; - this.authSession = - DefaultAuthSession.of( - HTTPHeaders.of(OAuth2Util.authHeaders(properties.get(OAuth2Properties.TOKEN)))); } @SuppressWarnings("JavaUtilDate") // GCP API uses java.util.Date @Override public AccessToken refreshAccessToken() { - LoadCredentialsResponse response; - try (RESTClient client = httpClient()) { - response = - client.get( - properties.get(GCPProperties.GCS_OAUTH2_REFRESH_CREDENTIALS_ENDPOINT), - null, - LoadCredentialsResponse.class, - Map.of(), - ErrorHandlers.defaultErrorHandler()); - } catch (IOException e) { - throw new RuntimeException(e); - } + LoadCredentialsResponse response = + httpClient() + .get( + properties.get(GCPProperties.GCS_OAUTH2_REFRESH_CREDENTIALS_ENDPOINT), + null, + LoadCredentialsResponse.class, + Map.of(), + ErrorHandlers.defaultErrorHandler()); List gcsCredentials = response.credentials().stream() @@ -100,9 +95,34 @@ public static OAuth2RefreshCredentialsHandler create(Map propert } private RESTClient httpClient() { - return HTTPClient.builder(properties) - .uri(properties.get(GCPProperties.GCS_OAUTH2_REFRESH_CREDENTIALS_ENDPOINT)) - .withAuthSession(authSession) - .build(); + if (null == client) { + synchronized (this) { + if (null == client) { + authManager = AuthManagers.loadAuthManager("gcs-credentials-refresh", properties); + HTTPClient httpClient = + HTTPClient.builder(properties) + .uri(properties.get(GCPProperties.GCS_OAUTH2_REFRESH_CREDENTIALS_ENDPOINT)) + .build(); + authSession = authManager.catalogSession(httpClient, properties); + client = httpClient.withAuthSession(authSession); + } + } + } + + return client; + } + + @Override + public void close() { + CloseableGroup closeableGroup = new CloseableGroup(); + closeableGroup.addCloseable(authSession); + closeableGroup.addCloseable(authManager); + closeableGroup.addCloseable(client); + closeableGroup.setSuppressCloseFailure(true); + try { + closeableGroup.close(); + } catch (IOException e) { + throw new UncheckedIOException("Failed to close the OAuth2RefreshCredentialsHandler", e); + } } }