diff --git a/aws/src/main/java/org/apache/iceberg/aws/RESTSigV4AuthManager.java b/aws/src/main/java/org/apache/iceberg/aws/RESTSigV4AuthManager.java new file mode 100644 index 000000000000..cb238326f552 --- /dev/null +++ b/aws/src/main/java/org/apache/iceberg/aws/RESTSigV4AuthManager.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.aws; + +import java.util.Map; +import org.apache.iceberg.rest.RESTClient; +import org.apache.iceberg.rest.auth.OAuth2Manager; +import org.apache.iceberg.rest.auth.OAuth2Util; + +/** + * An AuthManager that authenticates requests with SigV4. + * + *

It extends {@link OAuth2Manager} to handle OAuth2 authentication as well. In case of + * conflicting headers, the OAuth2 Authorization header will be relocated, then included in the + * canonical headers to sign. + * + *

See Signing AWS + * API requests for details about the SigV4 protocol. + */ +@SuppressWarnings("unused") // loaded by reflection +public class RESTSigV4AuthManager extends OAuth2Manager { + + private RESTSigV4Signer signer; + + public RESTSigV4AuthManager(String name) { + super(name); + } + + @Override + public RESTSigv4AuthSession initSession(RESTClient initClient, Map properties) { + RESTSigV4Signer initSigner = new RESTSigV4Signer(properties); + return new RESTSigv4AuthSession(super.initSession(initClient, properties), initSigner); + } + + @Override + public RESTSigv4AuthSession catalogSession( + RESTClient sharedClient, Map properties) { + signer = new RESTSigV4Signer(properties); + return new RESTSigv4AuthSession(super.catalogSession(sharedClient, properties), signer); + } + + @Override + protected RESTSigv4AuthSession newSessionFromAccessToken( + String token, Map properties, OAuth2Util.AuthSession parent) { + return new RESTSigv4AuthSession( + super.newSessionFromAccessToken(token, properties, parent), signer); + } + + @Override + protected RESTSigv4AuthSession newSessionFromCredential( + String credential, OAuth2Util.AuthSession parent) { + return new RESTSigv4AuthSession(super.newSessionFromCredential(credential, parent), signer); + } + + @Override + protected RESTSigv4AuthSession newSessionFromTokenExchange( + String token, String tokenType, OAuth2Util.AuthSession parent) { + return new RESTSigv4AuthSession( + super.newSessionFromTokenExchange(token, tokenType, parent), signer); + } +} diff --git a/aws/src/main/java/org/apache/iceberg/aws/RESTSigV4Signer.java b/aws/src/main/java/org/apache/iceberg/aws/RESTSigV4Signer.java index c1c858416352..a4021ea7dbcb 100644 --- a/aws/src/main/java/org/apache/iceberg/aws/RESTSigV4Signer.java +++ b/aws/src/main/java/org/apache/iceberg/aws/RESTSigV4Signer.java @@ -18,22 +18,16 @@ */ package org.apache.iceberg.aws; -import java.io.IOException; -import java.io.UncheckedIOException; import java.net.URI; -import java.net.URISyntaxException; -import java.util.Arrays; +import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; -import org.apache.hc.core5.http.EntityDetails; -import org.apache.hc.core5.http.Header; +import org.apache.commons.io.IOUtils; import org.apache.hc.core5.http.HttpHeaders; -import org.apache.hc.core5.http.HttpRequest; -import org.apache.hc.core5.http.HttpRequestInterceptor; -import org.apache.hc.core5.http.io.entity.StringEntity; -import org.apache.hc.core5.http.protocol.HttpContext; -import org.apache.iceberg.exceptions.RESTException; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.rest.HTTPRequest; +import org.apache.iceberg.rest.ImmutableHTTPRequest; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.signer.Aws4Signer; import software.amazon.awssdk.auth.signer.internal.SignerConstant; @@ -45,25 +39,24 @@ import software.amazon.awssdk.regions.Region; /** - * Provides a request interceptor for use with the HTTPClient that calculates the required signature - * for the SigV4 protocol and adds the necessary headers for all requests created by the client. + * A SigV4 signer that calculates the required signature for the SigV4 protocol and adds the + * necessary headers for all requests created by the client. * - *

See Signing AWS + *

See Signing AWS * API requests for details about the protocol. */ -public class RESTSigV4Signer implements HttpRequestInterceptor { +public class RESTSigV4Signer { static final String EMPTY_BODY_SHA256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; static final String RELOCATED_HEADER_PREFIX = "Original-"; private final Aws4Signer signer = Aws4Signer.create(); - private AwsCredentialsProvider credentialsProvider; + private final AwsCredentialsProvider credentialsProvider; - private String signingName; - private Region signingRegion; + private final String signingName; + private final Region signingRegion; - public void initialize(Map properties) { + public RESTSigV4Signer(Map properties) { AwsProperties awsProperties = new AwsProperties(properties); this.signingRegion = awsProperties.restSigningRegion(); @@ -71,16 +64,7 @@ public void initialize(Map properties) { this.credentialsProvider = awsProperties.restCredentialsProvider(); } - @Override - public void process(HttpRequest request, EntityDetails entity, HttpContext context) { - URI requestUri; - - try { - requestUri = request.getUri(); - } catch (URISyntaxException e) { - throw new RESTException(e, "Invalid uri for request: %s", request); - } - + public HTTPRequest sign(HTTPRequest request) { Aws4SignerParams params = Aws4SignerParams.builder() .signingName(signingName) @@ -96,62 +80,79 @@ public void process(HttpRequest request, EntityDetails entity, HttpContext conte SdkHttpFullRequest.Builder sdkRequestBuilder = SdkHttpFullRequest.builder(); + URI uri = request.requestUri(); sdkRequestBuilder - .method(SdkHttpMethod.fromValue(request.getMethod())) - .protocol(request.getScheme()) - .uri(requestUri) - .headers(convertHeaders(request.getHeaders())); + .method(SdkHttpMethod.fromValue(request.method().name())) + .protocol(uri.getScheme()) + .uri(uri) + .headers(convertHeaders(request.headers())); - if (entity == null) { + String body = request.encodedBody(); + if (body == null) { // This is a workaround for the signer implementation incorrectly producing // an invalid content checksum for empty body requests. sdkRequestBuilder.putHeader(SignerConstant.X_AMZ_CONTENT_SHA256, EMPTY_BODY_SHA256); - } else if (entity instanceof StringEntity) { - sdkRequestBuilder.contentStreamProvider( - () -> { - try { - return ((StringEntity) entity).getContent(); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - }); } else { - throw new UnsupportedOperationException("Unsupported entity type: " + entity.getClass()); + sdkRequestBuilder.contentStreamProvider( + () -> IOUtils.toInputStream(body, StandardCharsets.UTF_8)); } SdkHttpFullRequest signedSdkRequest = signer.sign(sdkRequestBuilder.build(), params); - updateRequestHeaders(request, signedSdkRequest.headers()); + Map> newHeaders = + updateRequestHeaders(request, signedSdkRequest.headers()); + return ImmutableHTTPRequest.builder().from(request).headers(newHeaders).build(); } - private Map> convertHeaders(Header[] headers) { - return Arrays.stream(headers) - .collect( - Collectors.groupingBy( - // Relocate Authorization header as SigV4 takes precedence - header -> - HttpHeaders.AUTHORIZATION.equals(header.getName()) - ? RELOCATED_HEADER_PREFIX + header.getName() - : header.getName(), - Collectors.mapping(Header::getValue, Collectors.toList()))); - } - - private void updateRequestHeaders(HttpRequest request, Map> headers) { + private Map> convertHeaders(Map> headers) { + Map> converted = Maps.newHashMap(); headers.forEach( (name, values) -> { + if (name.equals(HttpHeaders.AUTHORIZATION)) { + converted.merge( + RELOCATED_HEADER_PREFIX + name, + values, + (v1, v2) -> { + List merged = Lists.newArrayList(v1); + merged.addAll(v2); + return List.copyOf(merged); + }); + } else { + converted.put(name, values); + } + }); + return converted; + } + + private Map> updateRequestHeaders( + HTTPRequest request, Map> signedHeaders) { + Map> newHeaders = Maps.newLinkedHashMap(); + newHeaders.putAll(request.headers()); + signedHeaders.forEach( + (name, signedValues) -> { if (request.containsHeader(name)) { - Header[] original = request.getHeaders(name); - request.removeHeaders(name); - Arrays.asList(original) - .forEach( - header -> { - // Relocate headers if there is a conflict with signed headers - if (!values.contains(header.getValue())) { - request.addHeader(RELOCATED_HEADER_PREFIX + name, header.getValue()); - } - }); + List originalValues = request.headers(name); + newHeaders.remove(name); + originalValues.forEach( + originalValue -> { + // Relocate headers if there is a conflict with signed headers + if (!signedValues.contains(originalValue)) { + newHeaders.compute( + RELOCATED_HEADER_PREFIX + name, + (k, v) -> { + if (v == null) { + return List.of(originalValue); + } else { + List merged = Lists.newArrayList(v); + merged.add(originalValue); + return List.copyOf(merged); + } + }); + } + }); } - values.forEach(value -> request.setHeader(name, value)); + newHeaders.put(name, signedValues); }); + return newHeaders; } } diff --git a/aws/src/main/java/org/apache/iceberg/aws/RESTSigv4AuthSession.java b/aws/src/main/java/org/apache/iceberg/aws/RESTSigv4AuthSession.java new file mode 100644 index 000000000000..1947854a6930 --- /dev/null +++ b/aws/src/main/java/org/apache/iceberg/aws/RESTSigv4AuthSession.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.aws; + +import org.apache.iceberg.rest.HTTPRequest; +import org.apache.iceberg.rest.auth.OAuth2Util; + +/** + * An AuthSession that signs requests with SigV4. + * + *

It extends {@link OAuth2Util.AuthSession} to handle OAuth2 authentication as well. + */ +public class RESTSigv4AuthSession extends OAuth2Util.AuthSession { + + private final RESTSigV4Signer signer; + + public RESTSigv4AuthSession(OAuth2Util.AuthSession authSession, RESTSigV4Signer signer) { + super(authSession.headers(), authSession.config()); + this.signer = signer; + } + + @Override + public HTTPRequest authenticate(HTTPRequest request) { + return signer.sign(super.authenticate(request)); + } +} diff --git a/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java b/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java index 806c52420f89..ae3154cc2637 100644 --- a/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java +++ b/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java @@ -20,7 +20,6 @@ import com.github.benmanes.caffeine.cache.Cache; import com.github.benmanes.caffeine.cache.Caffeine; -import com.github.benmanes.caffeine.cache.RemovalListener; import java.io.IOException; import java.io.InputStream; import java.net.URI; @@ -28,13 +27,11 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.function.Supplier; import javax.annotation.Nullable; import org.apache.iceberg.CatalogProperties; -import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Maps; @@ -42,13 +39,13 @@ import org.apache.iceberg.rest.HTTPClient; import org.apache.iceberg.rest.RESTClient; import org.apache.iceberg.rest.ResourcePaths; -import org.apache.iceberg.rest.auth.AuthConfig; +import org.apache.iceberg.rest.auth.AuthManager; +import org.apache.iceberg.rest.auth.AuthManagers; +import org.apache.iceberg.rest.auth.AuthSession; +import org.apache.iceberg.rest.auth.AuthSessionCache; import org.apache.iceberg.rest.auth.OAuth2Properties; import org.apache.iceberg.rest.auth.OAuth2Util; -import org.apache.iceberg.rest.auth.OAuth2Util.AuthSession; -import org.apache.iceberg.rest.responses.OAuthTokenResponse; import org.apache.iceberg.util.PropertyUtil; -import org.apache.iceberg.util.ThreadPools; import org.immutables.value.Value; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -81,13 +78,13 @@ public abstract class S3V4RestSignerClient private static final String SCOPE = "sign"; @SuppressWarnings("immutables:incompat") - private static volatile ScheduledExecutorService tokenRefreshExecutor; + private static volatile AuthManager authManager; @SuppressWarnings("immutables:incompat") private static volatile RESTClient httpClient; @SuppressWarnings("immutables:incompat") - private static volatile Cache authSessionCache; + private static volatile AuthSessionCache authSessionCache; public abstract Map properties(); @@ -138,24 +135,19 @@ boolean keepTokenRefreshed() { OAuth2Properties.TOKEN_REFRESH_ENABLED_DEFAULT); } - @VisibleForTesting - ScheduledExecutorService tokenRefreshExecutor() { - if (!keepTokenRefreshed()) { - return null; - } - - if (null == tokenRefreshExecutor) { + private AuthManager authManager() { + if (null == authManager) { synchronized (S3V4RestSignerClient.class) { - if (null == tokenRefreshExecutor) { - tokenRefreshExecutor = ThreadPools.newScheduledPool("s3-signer-token-refresh", 1); + if (null == authManager) { + authManager = AuthManagers.loadAuthManager("s3-signer", properties()); } } } - return tokenRefreshExecutor; + return authManager; } - private Cache authSessionCache() { + private AuthSessionCache authSessionCache() { if (null == authSessionCache) { synchronized (S3V4RestSignerClient.class) { if (null == authSessionCache) { @@ -165,18 +157,7 @@ private Cache authSessionCache() { CatalogProperties.AUTH_SESSION_TIMEOUT_MS, CatalogProperties.AUTH_SESSION_TIMEOUT_MS_DEFAULT); - authSessionCache = - Caffeine.newBuilder() - .expireAfterAccess(Duration.ofMillis(expirationIntervalMs)) - .removalListener( - (RemovalListener) - (id, auth, cause) -> { - if (null != auth) { - LOG.trace("Stopping refresh for AuthSession"); - auth.stopRefreshing(); - } - }) - .build(); + authSessionCache = new AuthSessionCache(Duration.ofMillis(expirationIntervalMs)); } } } @@ -204,73 +185,51 @@ private AuthSession authSession() { String token = token().get(); if (null != token) { return authSessionCache() - .get( + .cachedSession( token, - id -> - AuthSession.fromAccessToken( - httpClient(), - tokenRefreshExecutor(), - token, - expiresAtMillis(properties()), - new AuthSession( - ImmutableMap.of(), - AuthConfig.builder() - .token(token) - .credential(credential()) - .scope(SCOPE) - .oauth2ServerUri(oauth2ServerUri()) - .optionalOAuthParams(optionalOAuthParams()) - .build()))); + () -> { + Map properties = + ImmutableMap.builder() + .putAll(properties()) + .putAll(optionalOAuthParams()) + .put(OAuth2Properties.OAUTH2_SERVER_URI, oauth2ServerUri()) + .put( + OAuth2Properties.TOKEN_REFRESH_ENABLED, + String.valueOf(keepTokenRefreshed())) + .put(OAuth2Properties.TOKEN, token) + .put(OAuth2Properties.SCOPE, SCOPE) + .buildKeepingLast(); + return authManager().catalogSession(httpClient(), properties); + }); } if (credentialProvided()) { return authSessionCache() - .get( + .cachedSession( credential(), - id -> { - AuthSession session = - new AuthSession( - ImmutableMap.of(), - AuthConfig.builder() - .credential(credential()) - .scope(SCOPE) - .oauth2ServerUri(oauth2ServerUri()) - .optionalOAuthParams(optionalOAuthParams()) - .build()); - long startTimeMillis = System.currentTimeMillis(); - OAuthTokenResponse authResponse = - OAuth2Util.fetchToken( - httpClient(), - session.headers(), - credential(), - SCOPE, - oauth2ServerUri(), - optionalOAuthParams()); - return AuthSession.fromTokenResponse( - httpClient(), tokenRefreshExecutor(), authResponse, startTimeMillis, session); + () -> { + Map properties = + ImmutableMap.builder() + .putAll(properties()) + .putAll(optionalOAuthParams()) + .put(OAuth2Properties.OAUTH2_SERVER_URI, oauth2ServerUri()) + .put( + OAuth2Properties.TOKEN_REFRESH_ENABLED, + String.valueOf(keepTokenRefreshed())) + .put(OAuth2Properties.CREDENTIAL, credential()) + .put(OAuth2Properties.SCOPE, SCOPE) + .buildKeepingLast(); + return authManager().catalogSession(httpClient(), properties); }); } - return AuthSession.empty(); + return AuthSession.EMPTY; } private boolean credentialProvided() { return null != credential() && !credential().isEmpty(); } - private Long expiresAtMillis(Map properties) { - if (properties.containsKey(OAuth2Properties.TOKEN_EXPIRES_IN_MS)) { - long expiresInMillis = - PropertyUtil.propertyAsLong( - properties, - OAuth2Properties.TOKEN_EXPIRES_IN_MS, - OAuth2Properties.TOKEN_EXPIRES_IN_MS_DEFAULT); - return System.currentTimeMillis() + expiresInMillis; - } else { - return null; - } - } - @Value.Check protected void check() { Preconditions.checkArgument( @@ -338,11 +297,12 @@ public SdkHttpFullRequest sign( Consumer> responseHeadersConsumer = responseHeaders::putAll; S3SignResponse s3SignResponse = httpClient() + .withAuthSession(authSession()) .post( endpoint(), remoteSigningRequest, S3SignResponse.class, - () -> authSession().headers(), + ImmutableMap.of(), ErrorHandlers.defaultErrorHandler(), responseHeadersConsumer); diff --git a/aws/src/test/java/org/apache/iceberg/aws/TestRESTSigV4Signer.java b/aws/src/test/java/org/apache/iceberg/aws/TestRESTSigV4Signer.java index 88623edd9334..75d05d6b4d85 100644 --- a/aws/src/test/java/org/apache/iceberg/aws/TestRESTSigV4Signer.java +++ b/aws/src/test/java/org/apache/iceberg/aws/TestRESTSigV4Signer.java @@ -27,6 +27,11 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.rest.HTTPClient; +import org.apache.iceberg.rest.RESTClient; +import org.apache.iceberg.rest.auth.AuthManager; +import org.apache.iceberg.rest.auth.AuthManagers; +import org.apache.iceberg.rest.auth.AuthProperties; +import org.apache.iceberg.rest.auth.AuthSession; import org.apache.iceberg.rest.auth.OAuth2Util; import org.apache.iceberg.rest.responses.ConfigResponse; import org.apache.iceberg.rest.responses.OAuthTokenResponse; @@ -38,12 +43,15 @@ import org.mockserver.model.Header; import org.mockserver.model.HttpRequest; import org.mockserver.model.HttpResponse; +import org.mockserver.model.Parameter; +import org.mockserver.model.ParameterBody; import org.mockserver.verify.VerificationTimes; import software.amazon.awssdk.auth.signer.internal.SignerConstant; public class TestRESTSigV4Signer { private static ClientAndServer mockServer; - private static HTTPClient client; + private static RESTClient client; + private static AuthManager authManager; @BeforeAll public static void beforeClass() { @@ -51,25 +59,36 @@ public static void beforeClass() { Map properties = ImmutableMap.of( - "rest.sigv4-enabled", - "true", + AuthProperties.AUTH_TYPE, + AuthProperties.AUTH_TYPE_SIGV4, // CI environment doesn't have credentials, but a value must be set for signing AwsProperties.REST_SIGNER_REGION, "us-west-2", AwsProperties.REST_ACCESS_KEY_ID, "id", AwsProperties.REST_SECRET_ACCESS_KEY, - "secret"); + "secret", + // OAuth2 token to test relocation of conflicting auth header + "token", + "existing_token"); + + HTTPClient httpClient = + HTTPClient.builder(properties).uri("http://localhost:" + mockServer.getLocalPort()).build(); + + authManager = AuthManagers.loadAuthManager("test", properties); + AuthSession authSession = authManager.catalogSession(httpClient, properties); + client = HTTPClient.builder(properties) .uri("http://localhost:" + mockServer.getLocalPort()) - .withHeader(HttpHeaders.AUTHORIZATION, "Bearer existing_token") - .build(); + .build() + .withAuthSession(authSession); } @AfterAll public static void afterClass() throws IOException { mockServer.stop(); + authManager.close(); client.close(); } @@ -88,7 +107,9 @@ public void signRequestWithoutBody() { .withHeader(Header.header(HttpHeaders.AUTHORIZATION, "AWS4-HMAC-SHA256.*")) // Require that conflicting auth header is relocated .withHeader( - Header.header(RESTSigV4Signer.RELOCATED_HEADER_PREFIX + HttpHeaders.AUTHORIZATION)) + Header.header( + RESTSigV4Signer.RELOCATED_HEADER_PREFIX + HttpHeaders.AUTHORIZATION, + "Bearer existing_token")) // Require the empty body checksum .withHeader( Header.header( @@ -111,11 +132,18 @@ public void signRequestWithBody() { HttpRequest.request() .withMethod("POST") .withPath("/v1/oauth/token") + .withBody( + ParameterBody.params( + Parameter.param("client_id", "asdfasd"), + Parameter.param("client_secret", "asdfasdf"), + Parameter.param("scope", "catalog"))) // Require SigV4 Authorization .withHeader(Header.header(HttpHeaders.AUTHORIZATION, "AWS4-HMAC-SHA256.*")) // Require that conflicting auth header is relocated .withHeader( - Header.header(RESTSigV4Signer.RELOCATED_HEADER_PREFIX + HttpHeaders.AUTHORIZATION)) + Header.header( + RESTSigV4Signer.RELOCATED_HEADER_PREFIX + HttpHeaders.AUTHORIZATION, + "Bearer existing_token")) // Require a body checksum is set .withHeader(Header.header(SignerConstant.X_AMZ_CONTENT_SHA256)); diff --git a/aws/src/test/java/org/apache/iceberg/aws/s3/TestVendedCredentialsProvider.java b/aws/src/test/java/org/apache/iceberg/aws/s3/TestVendedCredentialsProvider.java index 67cd1cb55241..73f4d9397334 100644 --- a/aws/src/test/java/org/apache/iceberg/aws/s3/TestVendedCredentialsProvider.java +++ b/aws/src/test/java/org/apache/iceberg/aws/s3/TestVendedCredentialsProvider.java @@ -26,7 +26,6 @@ import java.time.Instant; import java.time.temporal.ChronoUnit; -import org.apache.iceberg.exceptions.RESTException; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.rest.HttpMethod; import org.apache.iceberg.rest.credentials.Credential; @@ -79,8 +78,8 @@ public void invalidOrMissingUri() { VendedCredentialsProvider.create( ImmutableMap.of(VendedCredentialsProvider.URI, "invalid uri"))) { assertThatThrownBy(provider::resolveCredentials) - .isInstanceOf(RESTException.class) - .hasMessageStartingWith("Failed to create request URI from base invalid uri"); + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("Illegal character in path at index 7: invalid uri"); } } diff --git a/aws/src/test/java/org/apache/iceberg/aws/s3/signer/TestS3RestSigner.java b/aws/src/test/java/org/apache/iceberg/aws/s3/signer/TestS3RestSigner.java index 313214c4e98f..362e0557f229 100644 --- a/aws/src/test/java/org/apache/iceberg/aws/s3/signer/TestS3RestSigner.java +++ b/aws/src/test/java/org/apache/iceberg/aws/s3/signer/TestS3RestSigner.java @@ -26,7 +26,6 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.stream.Collectors; import org.apache.iceberg.aws.s3.MinioUtil; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; @@ -112,29 +111,6 @@ public static void beforeClass() throws Exception { @AfterAll public static void afterClass() throws Exception { - assertThat(validatingSigner.icebergSigner.tokenRefreshExecutor()) - .isInstanceOf(ScheduledThreadPoolExecutor.class); - - ScheduledThreadPoolExecutor executor = - ((ScheduledThreadPoolExecutor) validatingSigner.icebergSigner.tokenRefreshExecutor()); - // token expiration is set to 10000s by the S3SignerServlet so there should be exactly one token - // scheduled for refresh. Such a high token expiration value is explicitly selected to be much - // larger than TestS3RestSigner would need to execute all tests. - // The reason why this check is done here with a high token expiration is to make sure that - // there aren't other token refreshes being scheduled after every sign request and after - // TestS3RestSigner completes all tests, there should be only this single token in the queue - // that is scheduled for refresh - assertThat(executor.getPoolSize()).isEqualTo(1); - assertThat(executor.getQueue()) - .as("should only have a single token scheduled for refresh") - .hasSize(1); - assertThat(executor.getActiveCount()) - .as("should not have any token being refreshed") - .isEqualTo(0); - assertThat(executor.getCompletedTaskCount()) - .as("should not have any expired token that required a refresh") - .isEqualTo(0); - if (null != httpServer) { httpServer.stop(); } diff --git a/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java b/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java index e83ee650cf99..477aa2fbbe4b 100644 --- a/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java +++ b/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java @@ -23,14 +23,13 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.net.URI; -import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; +import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; -import java.util.stream.Collectors; +import java.util.function.Supplier; import org.apache.hc.client5.http.auth.CredentialsProvider; -import org.apache.hc.client5.http.classic.methods.HttpUriRequest; import org.apache.hc.client5.http.classic.methods.HttpUriRequestBase; import org.apache.hc.client5.http.config.ConnectionConfig; import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; @@ -43,23 +42,19 @@ import org.apache.hc.core5.http.Header; import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.HttpHost; -import org.apache.hc.core5.http.HttpRequestInterceptor; import org.apache.hc.core5.http.HttpStatus; -import org.apache.hc.core5.http.Method; import org.apache.hc.core5.http.ParseException; import org.apache.hc.core5.http.impl.EnglishReasonPhraseCatalog; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.io.entity.StringEntity; -import org.apache.hc.core5.http.message.BasicHeader; import org.apache.hc.core5.io.CloseMode; -import org.apache.hc.core5.net.URIBuilder; import org.apache.iceberg.IcebergBuild; -import org.apache.iceberg.common.DynConstructors; -import org.apache.iceberg.common.DynMethods; import org.apache.iceberg.exceptions.RESTException; import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; +import org.apache.iceberg.rest.auth.AuthSession; import org.apache.iceberg.rest.responses.ErrorResponse; import org.apache.iceberg.util.PropertyUtil; import org.slf4j.Logger; @@ -69,9 +64,6 @@ public class HTTPClient implements RESTClient { private static final Logger LOG = LoggerFactory.getLogger(HTTPClient.class); - private static final String SIGV4_ENABLED = "rest.sigv4-enabled"; - private static final String SIGV4_REQUEST_INTERCEPTOR_IMPL = - "org.apache.iceberg.aws.RESTSigV4Signer"; @VisibleForTesting static final String CLIENT_VERSION_HEADER = "X-Client-Version"; @VisibleForTesting @@ -88,37 +80,31 @@ public class HTTPClient implements RESTClient { @VisibleForTesting static final String REST_SOCKET_TIMEOUT_MS = "rest.client.socket-timeout-ms"; - private final String uri; + private final URI baseUri; private final CloseableHttpClient httpClient; + private final Map baseHeaders; private final ObjectMapper mapper; + private final AuthSession authSession; + private final boolean closeClient; private HTTPClient( - String uri, + URI baseUri, HttpHost proxy, CredentialsProvider proxyCredsProvider, Map baseHeaders, ObjectMapper objectMapper, - HttpRequestInterceptor requestInterceptor, Map properties, HttpClientConnectionManager connectionManager) { - this.uri = uri; + this.baseUri = baseUri; + this.baseHeaders = baseHeaders; this.mapper = objectMapper; + this.authSession = AuthSession.EMPTY; + this.closeClient = true; HttpClientBuilder clientBuilder = HttpClients.custom(); clientBuilder.setConnectionManager(connectionManager); - if (baseHeaders != null) { - clientBuilder.setDefaultHeaders( - baseHeaders.entrySet().stream() - .map(e -> new BasicHeader(e.getKey(), e.getValue())) - .collect(Collectors.toList())); - } - - if (requestInterceptor != null) { - clientBuilder.addRequestInterceptorLast(requestInterceptor); - } - int maxRetries = PropertyUtil.propertyAsInt(properties, REST_MAX_RETRIES, 5); clientBuilder.setRetryStrategy(new ExponentialHttpRequestRetryStrategy(maxRetries)); @@ -133,6 +119,24 @@ private HTTPClient( this.httpClient = clientBuilder.build(); } + /** + * Constructor for creating a child HTTPClient associated with an AuthSession. The returned child + * shares the same base uri, mapper, and HTTP client as the parent. + */ + private HTTPClient(HTTPClient parent, AuthSession authSession) { + this.baseUri = parent.baseUri; + this.httpClient = parent.httpClient; + this.mapper = parent.mapper; + this.baseHeaders = parent.baseHeaders; + this.authSession = authSession; + this.closeClient = false; + } + + @Override + public HTTPClient withAuthSession(AuthSession session) { + return new HTTPClient(this, session); + } + private static String extractResponseBodyAsString(CloseableHttpResponse response) { try { if (response.getEntity() == null) { @@ -214,92 +218,62 @@ private static void throwFailure( throw new RESTException("Unhandled error: %s", errorResponse); } - private URI buildUri(String path, Map params) { - // if full path is provided, use the input path as path - if (path.startsWith("/")) { - throw new RESTException( - "Received a malformed path for a REST request: %s. Paths should not start with /", path); - } - String fullPath = - (path.startsWith("https://") || path.startsWith("http://")) - ? path - : String.format("%s/%s", uri, path); - try { - URIBuilder builder = new URIBuilder(fullPath); - if (params != null) { - params.forEach(builder::addParameter); - } - return builder.build(); - } catch (URISyntaxException e) { - throw new RESTException( - "Failed to create request URI from base %s, params %s", fullPath, params); - } - } - - /** - * Method to execute an HTTP request and process the corresponding response. - * - * @param method - HTTP method, such as GET, POST, HEAD, etc. - * @param queryParams - A map of query parameters - * @param path - URI to send the request to - * @param requestBody - Content to place in the request body - * @param responseType - Class of the Response type. Needs to have serializer registered with - * ObjectMapper - * @param errorHandler - Error handler delegated for HTTP responses which handles server error - * responses - * @param - Class type of the response for deserialization. Must be registered with the - * ObjectMapper. - * @return The response entity, parsed and converted to its type T - */ - private T execute( - Method method, + private HTTPRequest buildRequest( + HTTPMethod method, String path, Map queryParams, - Object requestBody, - Class responseType, Map headers, - Consumer errorHandler) { - return execute( - method, path, queryParams, requestBody, responseType, headers, errorHandler, h -> {}); + Object body) { + + ImmutableHTTPRequest.Builder builder = + ImmutableHTTPRequest.builder() + .baseUri(baseUri) + .mapper(mapper) + .method(method) + .path(path) + .body(body) + .queryParameters(queryParams == null ? Map.of() : queryParams); + + Map> allHeaders = Maps.newLinkedHashMap(); + if (headers != null) { + headers.forEach((name, value) -> allHeaders.put(name, List.of(value))); + } + + allHeaders.putIfAbsent(HttpHeaders.ACCEPT, List.of(ContentType.APPLICATION_JSON.getMimeType())); + + // Many systems require that content type is set regardless and will fail, + // even on an empty bodied request. + // Encode maps as form data (application/x-www-form-urlencoded), + // and other requests are assumed to contain JSON bodies (application/json). + ContentType mimeType = + body instanceof Map + ? ContentType.APPLICATION_FORM_URLENCODED + : ContentType.APPLICATION_JSON; + allHeaders.putIfAbsent(HttpHeaders.CONTENT_TYPE, List.of(mimeType.getMimeType())); + + // Apply base headers now to mimic the behavior of + // org.apache.hc.client5.http.protocol.RequestDefaultHeaders + // We want these headers applied *before* the AuthSession authenticates the request. + if (baseHeaders != null) { + baseHeaders.forEach((name, value) -> allHeaders.putIfAbsent(name, List.of(value))); + } + + return authSession.authenticate(builder.headers(allHeaders).build()); } - /** - * Method to execute an HTTP request and process the corresponding response. - * - * @param method - HTTP method, such as GET, POST, HEAD, etc. - * @param queryParams - A map of query parameters - * @param path - URL to send the request to - * @param requestBody - Content to place in the request body - * @param responseType - Class of the Response type. Needs to have serializer registered with - * ObjectMapper - * @param errorHandler - Error handler delegated for HTTP responses which handles server error - * responses - * @param responseHeaders The consumer of the response headers - * @param - Class type of the response for deserialization. Must be registered with the - * ObjectMapper. - * @return The response entity, parsed and converted to its type T - */ - private T execute( - Method method, - String path, - Map queryParams, - Object requestBody, + private T execute( + HTTPRequest req, Class responseType, - Map headers, Consumer errorHandler, Consumer> responseHeaders) { - HttpUriRequestBase request = new HttpUriRequestBase(method.name(), buildUri(path, queryParams)); - - if (requestBody instanceof Map) { - // encode maps as form data, application/x-www-form-urlencoded - addRequestHeaders(request, headers, ContentType.APPLICATION_FORM_URLENCODED.getMimeType()); - request.setEntity(toFormEncoding((Map) requestBody)); - } else if (requestBody != null) { - // other request bodies are serialized as JSON, application/json - addRequestHeaders(request, headers, ContentType.APPLICATION_JSON.getMimeType()); - request.setEntity(toJson(requestBody)); - } else { - addRequestHeaders(request, headers, ContentType.APPLICATION_JSON.getMimeType()); + HttpUriRequestBase request = new HttpUriRequestBase(req.method().name(), req.requestUri()); + + req.headers() + .forEach((name, values) -> values.forEach(value -> request.addHeader(name, value))); + + String encodedBody = req.encodedBody(); + if (encodedBody != null) { + request.setEntity(new StringEntity(encodedBody)); } try (CloseableHttpResponse response = httpClient.execute(request)) { @@ -326,7 +300,7 @@ private T execute( if (responseBody == null) { throw new RESTException( "Invalid (null) response body for request (expected %s): method=%s, path=%s, status=%d", - responseType.getSimpleName(), method.name(), path, response.getCode()); + responseType.getSimpleName(), req.method(), req.path(), response.getCode()); } try { @@ -339,13 +313,94 @@ private T execute( responseType.getSimpleName()); } } catch (IOException e) { - throw new RESTException(e, "Error occurred while processing %s request", method); + throw new RESTException(e, "Error occurred while processing %s request", req.method()); } } + @Override + public void head( + String path, Supplier> headers, Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.HEAD, path, null, headers.get(), null); + execute(request, null, errorHandler, h -> {}); + } + @Override public void head(String path, Map headers, Consumer errorHandler) { - execute(Method.HEAD, path, null, null, null, headers, errorHandler); + HTTPRequest request = buildRequest(HTTPMethod.HEAD, path, null, headers, null); + execute(request, null, errorHandler, h -> {}); + } + + @Override + public T delete( + String path, + Map queryParams, + Class responseType, + Supplier> headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.DELETE, path, queryParams, headers.get(), null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T delete( + String path, + Class responseType, + Supplier> headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.DELETE, path, null, headers.get(), null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T delete( + String path, + Class responseType, + Map headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.DELETE, path, null, headers, null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T delete( + String path, + Map queryParams, + Class responseType, + Map headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.DELETE, path, queryParams, headers, null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T get( + String path, + Class responseType, + Supplier> headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.GET, path, null, headers.get(), null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T get( + String path, + Class responseType, + Map headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.GET, path, null, headers, null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T get( + String path, + Map queryParams, + Class responseType, + Supplier> headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.GET, path, queryParams, headers.get(), null); + return execute(request, responseType, errorHandler, h -> {}); } @Override @@ -355,7 +410,8 @@ public T get( Class responseType, Map headers, Consumer errorHandler) { - return execute(Method.GET, path, queryParams, null, responseType, headers, errorHandler); + HTTPRequest request = buildRequest(HTTPMethod.GET, path, queryParams, headers, null); + return execute(request, responseType, errorHandler, h -> {}); } @Override @@ -363,9 +419,10 @@ public T post( String path, RESTRequest body, Class responseType, - Map headers, + Supplier> headers, Consumer errorHandler) { - return execute(Method.POST, path, null, body, responseType, headers, errorHandler); + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers.get(), body); + return execute(request, responseType, errorHandler, h -> {}); } @Override @@ -373,30 +430,34 @@ public T post( String path, RESTRequest body, Class responseType, - Map headers, + Supplier> headers, Consumer errorHandler, Consumer> responseHeaders) { - return execute( - Method.POST, path, null, body, responseType, headers, errorHandler, responseHeaders); + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers.get(), body); + return execute(request, responseType, errorHandler, responseHeaders); } @Override - public T delete( + public T post( String path, + RESTRequest body, Class responseType, Map headers, - Consumer errorHandler) { - return execute(Method.DELETE, path, null, null, responseType, headers, errorHandler); + Consumer errorHandler, + Consumer> responseHeaders) { + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers, body); + return execute(request, responseType, errorHandler, responseHeaders); } @Override - public T delete( + public T post( String path, - Map queryParams, + RESTRequest body, Class responseType, Map headers, Consumer errorHandler) { - return execute(Method.DELETE, path, queryParams, null, responseType, headers, errorHandler); + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers, body); + return execute(request, responseType, errorHandler, h -> {}); } @Override @@ -404,58 +465,32 @@ public T postForm( String path, Map formData, Class responseType, - Map headers, + Supplier> headers, Consumer errorHandler) { - return execute(Method.POST, path, null, formData, responseType, headers, errorHandler); + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers.get(), formData); + return execute(request, responseType, errorHandler, h -> {}); } - private void addRequestHeaders( - HttpUriRequest request, Map requestHeaders, String bodyMimeType) { - request.setHeader(HttpHeaders.ACCEPT, ContentType.APPLICATION_JSON.getMimeType()); - // Many systems require that content type is set regardless and will fail, even on an empty - // bodied request. - request.setHeader(HttpHeaders.CONTENT_TYPE, bodyMimeType); - requestHeaders.forEach(request::setHeader); + @Override + public T postForm( + String path, + Map formData, + Class responseType, + Map headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers, formData); + return execute(request, responseType, errorHandler, h -> {}); } @Override public void close() throws IOException { - httpClient.close(CloseMode.GRACEFUL); - } - - @VisibleForTesting - static HttpRequestInterceptor loadInterceptorDynamically( - String impl, Map properties) { - HttpRequestInterceptor instance; - - DynConstructors.Ctor ctor; try { - ctor = - DynConstructors.builder(HttpRequestInterceptor.class) - .loader(HTTPClient.class.getClassLoader()) - .impl(impl) - .buildChecked(); - } catch (NoSuchMethodException e) { - throw new IllegalArgumentException( - String.format( - "Cannot initialize RequestInterceptor, missing no-arg constructor: %s", impl), - e); - } - - try { - instance = ctor.newInstance(); - } catch (ClassCastException e) { - throw new IllegalArgumentException( - String.format("Cannot initialize, %s does not implement RequestInterceptor", impl), e); + authSession.close(); + } finally { + if (closeClient) { + httpClient.close(CloseMode.GRACEFUL); + } } - - DynMethods.builder("initialize") - .hiddenImpl(impl, Map.class) - .orNoop() - .build(instance) - .invoke(properties); - - return instance; } static HttpClientConnectionManager configureConnectionManager(Map properties) { @@ -506,7 +541,7 @@ public static Builder builder(Map properties) { public static class Builder { private final Map properties; private final Map baseHeaders = Maps.newHashMap(); - private String uri; + private URI uri; private ObjectMapper mapper = RESTObjectMapper.mapper(); private HttpHost proxy; private CredentialsProvider proxyCredentialsProvider; @@ -515,9 +550,15 @@ private Builder(Map properties) { this.properties = properties; } - public Builder uri(String path) { - Preconditions.checkNotNull(path, "Invalid uri for http client: null"); - this.uri = RESTUtil.stripTrailingSlash(path); + public Builder uri(String baseUri) { + Preconditions.checkNotNull(baseUri, "Invalid uri for http client: null"); + this.uri = URI.create(RESTUtil.stripTrailingSlash(baseUri)); + return this; + } + + public Builder uri(URI baseUri) { + Preconditions.checkNotNull(baseUri, "Invalid uri for http client: null"); + this.uri = baseUri; return this; } @@ -553,12 +594,6 @@ public HTTPClient build() { withHeader(CLIENT_VERSION_HEADER, IcebergBuild.fullVersion()); withHeader(CLIENT_GIT_COMMIT_SHORT_HEADER, IcebergBuild.gitCommitShortId()); - HttpRequestInterceptor interceptor = null; - - if (PropertyUtil.propertyAsBoolean(properties, SIGV4_ENABLED, false)) { - interceptor = loadInterceptorDynamically(SIGV4_REQUEST_INTERCEPTOR_IMPL, properties); - } - if (this.proxyCredentialsProvider != null) { Preconditions.checkNotNull( proxy, "Invalid http client proxy for proxy credentials provider: null"); @@ -570,21 +605,8 @@ public HTTPClient build() { proxyCredentialsProvider, baseHeaders, mapper, - interceptor, properties, configureConnectionManager(properties)); } } - - private StringEntity toJson(Object requestBody) { - try { - return new StringEntity(mapper.writeValueAsString(requestBody), StandardCharsets.UTF_8); - } catch (JsonProcessingException e) { - throw new RESTException(e, "Failed to write request body: %s", requestBody); - } - } - - private StringEntity toFormEncoding(Map formData) { - return new StringEntity(RESTUtil.encodeFormData(formData), StandardCharsets.UTF_8); - } } diff --git a/core/src/main/java/org/apache/iceberg/rest/HTTPRequest.java b/core/src/main/java/org/apache/iceberg/rest/HTTPRequest.java new file mode 100644 index 000000000000..1adfae87ec1b --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/HTTPRequest.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest; + +import com.fasterxml.jackson.databind.ObjectMapper; +import java.net.URI; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.immutables.value.Value; + +/** Represents an HTTP request. */ +@Value.Style(redactedMask = "****", depluralize = true) +@Value.Immutable +@SuppressWarnings({"ImmutablesStyle", "SafeLoggingPropagation"}) +public interface HTTPRequest { + + enum HTTPMethod { + GET, + HEAD, + POST, + DELETE + } + + /** + * Returns the base URI configured at the REST client level. The base URI is used to construct the + * full {@link #requestUri()}. + */ + URI baseUri(); + + /** + * Returns the full URI of this request. The URI is constructed from the base URI, path, and query + * parameters. It cannot be modified directly. + */ + @Value.Lazy + default URI requestUri() { + return RESTUtil.buildRequestUri(this); + } + + /** Returns the HTTP method of this request. */ + HTTPMethod method(); + + /** Returns the path of this request. */ + String path(); + + /** Returns the query parameters of this request. */ + Map queryParameters(); + + /** Returns all the headers of this request. The map is case-sensitive! */ + @Value.Redacted + Map> headers(); + + /** Returns the header values of the given name. */ + default List headers(String name) { + return headers().getOrDefault(name, List.of()); + } + + /** Returns whether the request contains a header with the given name. */ + default boolean containsHeader(String name) { + return !headers(name).isEmpty(); + } + + /** Returns the raw, unencoded request body. */ + @Nullable + @Value.Redacted + Object body(); + + /** Returns the encoded request body as a string. */ + @Value.Lazy + @Nullable + @Value.Redacted + default String encodedBody() { + return RESTUtil.encodeRequestBody(this); + } + + /** + * Returns the {@link ObjectMapper} to use for encoding the request body. The default is {@link + * RESTObjectMapper#mapper()}. + */ + @Value.Default + default ObjectMapper mapper() { + return RESTObjectMapper.mapper(); + } + + default HTTPRequest putHeadersIfAbsent(Map headers) { + Map> newHeaders = Maps.newLinkedHashMap(headers()); + headers.forEach((name, value) -> newHeaders.putIfAbsent(name, List.of(value))); + return ImmutableHTTPRequest.builder().from(this).headers(newHeaders).build(); + } +} diff --git a/core/src/main/java/org/apache/iceberg/rest/RESTClient.java b/core/src/main/java/org/apache/iceberg/rest/RESTClient.java index 0f17d9a127e2..2843972fee45 100644 --- a/core/src/main/java/org/apache/iceberg/rest/RESTClient.java +++ b/core/src/main/java/org/apache/iceberg/rest/RESTClient.java @@ -23,6 +23,7 @@ import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.rest.auth.AuthSession; import org.apache.iceberg.rest.responses.ErrorResponse; /** Interface for a basic HTTP Client for interfacing with the REST catalog. */ @@ -158,4 +159,9 @@ T postForm( Class responseType, Map headers, Consumer errorHandler); + + /** Returns a REST client that authenticates requests using the given session. */ + default RESTClient withAuthSession(AuthSession session) { + return this; + } } diff --git a/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java b/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java index 5c6fc49984a5..8d7282c21de5 100644 --- a/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java +++ b/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java @@ -18,23 +18,15 @@ */ package org.apache.iceberg.rest; -import com.github.benmanes.caffeine.cache.Cache; -import com.github.benmanes.caffeine.cache.Caffeine; -import com.github.benmanes.caffeine.cache.RemovalListener; import java.io.Closeable; import java.io.IOException; import java.io.UncheckedIOException; -import java.time.Duration; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Set; -import java.util.concurrent.Future; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; import java.util.function.BiFunction; import java.util.function.Function; -import java.util.function.Supplier; import org.apache.iceberg.BaseTable; import org.apache.iceberg.CatalogProperties; import org.apache.iceberg.CatalogUtil; @@ -70,10 +62,9 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Maps; -import org.apache.iceberg.rest.auth.AuthConfig; -import org.apache.iceberg.rest.auth.OAuth2Properties; -import org.apache.iceberg.rest.auth.OAuth2Util; -import org.apache.iceberg.rest.auth.OAuth2Util.AuthSession; +import org.apache.iceberg.rest.auth.AuthManager; +import org.apache.iceberg.rest.auth.AuthManagers; +import org.apache.iceberg.rest.auth.AuthSession; import org.apache.iceberg.rest.requests.CommitTransactionRequest; import org.apache.iceberg.rest.requests.CreateNamespaceRequest; import org.apache.iceberg.rest.requests.CreateTableRequest; @@ -91,12 +82,9 @@ import org.apache.iceberg.rest.responses.ListTablesResponse; import org.apache.iceberg.rest.responses.LoadTableResponse; import org.apache.iceberg.rest.responses.LoadViewResponse; -import org.apache.iceberg.rest.responses.OAuthTokenResponse; import org.apache.iceberg.rest.responses.UpdateNamespacePropertiesResponse; import org.apache.iceberg.util.EnvironmentUtil; -import org.apache.iceberg.util.Pair; import org.apache.iceberg.util.PropertyUtil; -import org.apache.iceberg.util.ThreadPools; import org.apache.iceberg.view.BaseView; import org.apache.iceberg.view.ImmutableSQLViewRepresentation; import org.apache.iceberg.view.ImmutableViewVersion; @@ -106,12 +94,9 @@ import org.apache.iceberg.view.ViewRepresentation; import org.apache.iceberg.view.ViewUtil; import org.apache.iceberg.view.ViewVersion; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class RESTSessionCatalog extends BaseViewSessionCatalog implements Configurable, Closeable { - private static final Logger LOG = LoggerFactory.getLogger(RESTSessionCatalog.class); private static final String DEFAULT_FILE_IO_IMPL = "org.apache.iceberg.io.ResolvingFileIO"; private static final String REST_METRICS_REPORTING_ENABLED = "rest-metrics-reporting-enabled"; private static final String REST_SNAPSHOT_LOADING_MODE = "snapshot-loading-mode"; @@ -119,20 +104,6 @@ public class RESTSessionCatalog extends BaseViewSessionCatalog // server supports view endpoints but doesn't send the "endpoints" field in the ConfigResponse static final String VIEW_ENDPOINTS_SUPPORTED = "view-endpoints-supported"; public static final String REST_PAGE_SIZE = "rest-page-size"; - private static final List TOKEN_PREFERENCE_ORDER = - ImmutableList.of( - OAuth2Properties.ID_TOKEN_TYPE, - OAuth2Properties.ACCESS_TOKEN_TYPE, - OAuth2Properties.JWT_TOKEN_TYPE, - OAuth2Properties.SAML2_TOKEN_TYPE, - OAuth2Properties.SAML1_TOKEN_TYPE); - - // Auth-related properties that are allowed to be passed to the table session - private static final Set TABLE_SESSION_ALLOW_LIST = - ImmutableSet.builder() - .add(OAuth2Properties.TOKEN) - .addAll(TOKEN_PREFERENCE_ORDER) - .build(); private static final Set DEFAULT_ENDPOINTS = ImmutableSet.builder() @@ -163,11 +134,9 @@ public class RESTSessionCatalog extends BaseViewSessionCatalog private final Function, RESTClient> clientBuilder; private final BiFunction, FileIO> ioBuilder; - private Cache sessions = null; - private Cache tableSessions = null; private FileIOTracker fileIOTracker = null; private AuthSession catalogAuth = null; - private boolean keepTokenRefreshed = true; + private AuthManager authManager; private RESTClient client = null; private ResourcePaths paths = null; private SnapshotMode snapshotMode = null; @@ -179,9 +148,6 @@ public class RESTSessionCatalog extends BaseViewSessionCatalog private CloseableGroup closeables = null; private Set endpoints; - // a lazy thread pool for token refresh - private volatile ScheduledExecutorService refreshExecutor = null; - enum SnapshotMode { ALL, REFS; @@ -192,7 +158,14 @@ Map params() { } public RESTSessionCatalog() { - this(config -> HTTPClient.builder(config).uri(config.get(CatalogProperties.URI)).build(), null); + this( + config -> { + HTTPClient.Builder builder = + HTTPClient.builder(config).uri(config.get(CatalogProperties.URI)); + configHeaders(config).forEach(builder::withHeader); + return builder.build(); + }, + null); } public RESTSessionCatalog( @@ -203,7 +176,6 @@ public RESTSessionCatalog( this.ioBuilder = ioBuilder; } - @SuppressWarnings("checkstyle:CyclomaticComplexity") @Override public void initialize(String name, Map unresolved) { Preconditions.checkArgument(unresolved != null, "Invalid configuration: null"); @@ -212,54 +184,18 @@ public void initialize(String name, Map unresolved) { // catalog service Map props = EnvironmentUtil.resolveAll(unresolved); - long startTimeMillis = - System.currentTimeMillis(); // keep track of the init start time for token refresh - String initToken = props.get(OAuth2Properties.TOKEN); - boolean hasInitToken = initToken != null; + this.authManager = AuthManagers.loadAuthManager(name, props); - // fetch auth and config to complete initialization ConfigResponse config; - OAuthTokenResponse authResponse; - String credential = props.get(OAuth2Properties.CREDENTIAL); - boolean hasCredential = credential != null && !credential.isEmpty(); - String scope = props.getOrDefault(OAuth2Properties.SCOPE, OAuth2Properties.CATALOG_SCOPE); - Map optionalOAuthParams = OAuth2Util.buildOptionalParam(props); - if (!props.containsKey(OAuth2Properties.OAUTH2_SERVER_URI) - && (hasInitToken || hasCredential) - && !PropertyUtil.propertyAsBoolean(props, "rest.sigv4-enabled", false)) { - LOG.warn( - "Iceberg REST client is missing the OAuth2 server URI configuration and defaults to {}/{}. " - + "This automatic fallback will be removed in a future Iceberg release." - + "It is recommended to configure the OAuth2 endpoint using the '{}' property to be prepared. " - + "This warning will disappear if the OAuth2 endpoint is explicitly configured. " - + "See https://github.com/apache/iceberg/issues/10537", - RESTUtil.stripTrailingSlash(props.get(CatalogProperties.URI)), - ResourcePaths.tokens(), - OAuth2Properties.OAUTH2_SERVER_URI); - } - String oauth2ServerUri = - props.getOrDefault(OAuth2Properties.OAUTH2_SERVER_URI, ResourcePaths.tokens()); - try (RESTClient initClient = clientBuilder.apply(props)) { - Map initHeaders = - RESTUtil.merge(configHeaders(props), OAuth2Util.authHeaders(initToken)); - if (hasCredential) { - authResponse = - OAuth2Util.fetchToken( - initClient, initHeaders, credential, scope, oauth2ServerUri, optionalOAuthParams); - Map authHeaders = - RESTUtil.merge(initHeaders, OAuth2Util.authHeaders(authResponse.token())); - config = fetchConfig(initClient, authHeaders, props); - } else { - authResponse = null; - config = fetchConfig(initClient, initHeaders, props); - } + try (RESTClient initClient = clientBuilder.apply(props); + AuthSession initSession = authManager.initSession(initClient, props)) { + config = fetchConfig(initClient, initSession, props); } catch (IOException e) { throw new UncheckedIOException("Failed to close HTTP client", e); } // build the final configuration and set up the catalog's auth Map mergedProps = config.merge(props); - Map baseHeaders = configHeaders(mergedProps); if (config.endpoints().isEmpty()) { this.endpoints = @@ -273,35 +209,10 @@ public void initialize(String name, Map unresolved) { this.endpoints = ImmutableSet.copyOf(config.endpoints()); } - this.sessions = newSessionCache(mergedProps); - this.tableSessions = newSessionCache(mergedProps); - this.keepTokenRefreshed = - PropertyUtil.propertyAsBoolean( - mergedProps, - OAuth2Properties.TOKEN_REFRESH_ENABLED, - OAuth2Properties.TOKEN_REFRESH_ENABLED_DEFAULT); this.client = clientBuilder.apply(mergedProps); this.paths = ResourcePaths.forCatalogProperties(mergedProps); - String token = mergedProps.get(OAuth2Properties.TOKEN); - this.catalogAuth = - new AuthSession( - baseHeaders, - AuthConfig.builder() - .credential(credential) - .scope(scope) - .oauth2ServerUri(oauth2ServerUri) - .optionalOAuthParams(optionalOAuthParams) - .build()); - if (authResponse != null) { - this.catalogAuth = - AuthSession.fromTokenResponse( - client, tokenRefreshExecutor(name), authResponse, startTimeMillis, catalogAuth); - } else if (token != null) { - this.catalogAuth = - AuthSession.fromAccessToken( - client, tokenRefreshExecutor(name), token, expiresAtMillis(mergedProps), catalogAuth); - } + this.catalogAuth = authManager.catalogSession(client, mergedProps); this.pageSize = PropertyUtil.propertyAsNullableInt(mergedProps, REST_PAGE_SIZE); if (pageSize != null) { @@ -313,6 +224,8 @@ public void initialize(String name, Map unresolved) { this.fileIOTracker = new FileIOTracker(); this.closeables = new CloseableGroup(); + this.closeables.addCloseable(this.catalogAuth); + this.closeables.addCloseable(this.authManager); this.closeables.addCloseable(this.io); this.closeables.addCloseable(this.client); this.closeables.addCloseable(fileIOTracker); @@ -331,27 +244,6 @@ public void initialize(String name, Map unresolved) { super.initialize(name, mergedProps); } - private AuthSession session(SessionContext context) { - AuthSession session = - sessions.get( - context.sessionId(), - id -> { - Pair> newSession = - newSession(context.credentials(), context.properties(), catalogAuth); - if (null != newSession) { - return newSession.second().get(); - } - - return null; - }); - - return session != null ? session : catalogAuth; - } - - private Supplier> headers(SessionContext context) { - return session(context)::headers; - } - @Override public void setConf(Object newConf) { this.conf = newConf; @@ -373,13 +265,16 @@ public List listTables(SessionContext context, Namespace ns) { do { queryParams.put("pageToken", pageToken); + AuthSession authSession = authManager.contextualSession(context, catalogAuth); ListTablesResponse response = - client.get( - paths.tables(ns), - queryParams, - ListTablesResponse.class, - headers(context), - ErrorHandlers.namespaceErrorHandler()); + client + .withAuthSession(authSession) + .get( + paths.tables(ns), + queryParams, + ListTablesResponse.class, + Map.of(), + ErrorHandlers.namespaceErrorHandler()); pageToken = response.nextPageToken(); tables.addAll(response.identifiers()); } while (pageToken != null); @@ -393,8 +288,10 @@ public boolean dropTable(SessionContext context, TableIdentifier identifier) { checkIdentifierIsValid(identifier); try { - client.delete( - paths.table(identifier), null, headers(context), ErrorHandlers.tableErrorHandler()); + AuthSession authSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(authSession) + .delete(paths.table(identifier), null, Map.of(), ErrorHandlers.tableErrorHandler()); return true; } catch (NoSuchTableException e) { return false; @@ -407,12 +304,15 @@ public boolean purgeTable(SessionContext context, TableIdentifier identifier) { checkIdentifierIsValid(identifier); try { - client.delete( - paths.table(identifier), - ImmutableMap.of("purgeRequested", "true"), - null, - headers(context), - ErrorHandlers.tableErrorHandler()); + AuthSession authSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(authSession) + .delete( + paths.table(identifier), + ImmutableMap.of("purgeRequested", "true"), + null, + Map.of(), + ErrorHandlers.tableErrorHandler()); return true; } catch (NoSuchTableException e) { return false; @@ -429,7 +329,10 @@ public void renameTable(SessionContext context, TableIdentifier from, TableIdent RenameTableRequest.builder().withSource(from).withDestination(to).build(); // for now, ignore the response because there is no way to return it - client.post(paths.rename(), request, null, headers(context), ErrorHandlers.tableErrorHandler()); + AuthSession authSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(authSession) + .post(paths.rename(), request, null, Map.of(), ErrorHandlers.tableErrorHandler()); } @Override @@ -437,7 +340,10 @@ public boolean tableExists(SessionContext context, TableIdentifier identifier) { checkIdentifierIsValid(identifier); try { - client.head(paths.table(identifier), headers(context), ErrorHandlers.tableErrorHandler()); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(contextualSession) + .head(paths.table(identifier), Map.of(), ErrorHandlers.tableErrorHandler()); return true; } catch (NoSuchTableException e) { return false; @@ -447,12 +353,15 @@ public boolean tableExists(SessionContext context, TableIdentifier identifier) { private LoadTableResponse loadInternal( SessionContext context, TableIdentifier identifier, SnapshotMode mode) { Endpoint.check(endpoints, Endpoint.V1_LOAD_TABLE); - return client.get( - paths.table(identifier), - mode.params(), - LoadTableResponse.class, - headers(context), - ErrorHandlers.tableErrorHandler()); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + return client + .withAuthSession(contextualSession) + .get( + paths.table(identifier), + mode.params(), + LoadTableResponse.class, + Map.of(), + ErrorHandlers.tableErrorHandler()); } @Override @@ -494,7 +403,10 @@ public Table loadTable(SessionContext context, TableIdentifier identifier) { } TableIdentifier finalIdentifier = loadedIdent; - AuthSession session = tableSession(response.config(), session(context)); + Map tableConf = response.config(); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); + AuthSession tableSession = + authManager.tableSession(finalIdentifier, tableConf, contextualSession); TableMetadata tableMetadata; if (snapshotMode == SnapshotMode.REFS) { @@ -513,11 +425,12 @@ public Table loadTable(SessionContext context, TableIdentifier identifier) { tableMetadata = response.tableMetadata(); } + RESTClient tableClient = client.withAuthSession(tableSession); RESTTableOperations ops = new RESTTableOperations( - client, + tableClient, paths.table(finalIdentifier), - session::headers, + Map::of, tableFileIO(context, response.config()), tableMetadata, endpoints); @@ -528,7 +441,7 @@ public Table loadTable(SessionContext context, TableIdentifier identifier) { new BaseTable( ops, fullTableName(finalIdentifier), - metricsReporter(paths.metrics(finalIdentifier), session::headers)); + metricsReporter(paths.metrics(finalIdentifier), tableClient)); if (metadataType != null) { return MetadataTableUtils.createMetadataTableInstance(table, metadataType); } @@ -542,11 +455,10 @@ private void trackFileIO(RESTTableOperations ops) { } } - private MetricsReporter metricsReporter( - String metricsEndpoint, Supplier> headers) { + private MetricsReporter metricsReporter(String metricsEndpoint, RESTClient restClient) { if (reportingViaRestEnabled && endpoints.contains(Endpoint.V1_REPORT_METRICS)) { RESTMetricsReporter restMetricsReporter = - new RESTMetricsReporter(client, metricsEndpoint, headers); + new RESTMetricsReporter(restClient, metricsEndpoint, Map::of); return MetricsReporters.combine(reporter, restMetricsReporter); } else { return this.reporter; @@ -579,20 +491,25 @@ public Table registerTable( .metadataLocation(metadataFileLocation) .build(); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); LoadTableResponse response = - client.post( - paths.register(ident.namespace()), - request, - LoadTableResponse.class, - headers(context), - ErrorHandlers.tableErrorHandler()); - - AuthSession session = tableSession(response.config(), session(context)); + client + .withAuthSession(contextualSession) + .post( + paths.register(ident.namespace()), + request, + LoadTableResponse.class, + Map.of(), + ErrorHandlers.tableErrorHandler()); + + Map tableConf = response.config(); + AuthSession tableSession = authManager.tableSession(ident, tableConf, contextualSession); + RESTClient tableClient = client.withAuthSession(tableSession); RESTTableOperations ops = new RESTTableOperations( - client, + tableClient, paths.table(ident), - session::headers, + Map::of, tableFileIO(context, response.config()), response.tableMetadata(), endpoints); @@ -600,7 +517,7 @@ public Table registerTable( trackFileIO(ops); return new BaseTable( - ops, fullTableName(ident), metricsReporter(paths.metrics(ident), session::headers)); + ops, fullTableName(ident), metricsReporter(paths.metrics(ident), tableClient)); } @Override @@ -611,12 +528,15 @@ public void createNamespace( CreateNamespaceRequest.builder().withNamespace(namespace).setProperties(metadata).build(); // for now, ignore the response because there is no way to return it - client.post( - paths.namespaces(), - request, - CreateNamespaceResponse.class, - headers(context), - ErrorHandlers.namespaceErrorHandler()); + AuthSession authSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(authSession) + .post( + paths.namespaces(), + request, + CreateNamespaceResponse.class, + Map.of(), + ErrorHandlers.namespaceErrorHandler()); } @Override @@ -638,13 +558,16 @@ public List listNamespaces(SessionContext context, Namespace namespac do { queryParams.put("pageToken", pageToken); + AuthSession authSession = authManager.contextualSession(context, catalogAuth); ListNamespacesResponse response = - client.get( - paths.namespaces(), - queryParams, - ListNamespacesResponse.class, - headers(context), - ErrorHandlers.namespaceErrorHandler()); + client + .withAuthSession(authSession) + .get( + paths.namespaces(), + queryParams, + ListNamespacesResponse.class, + Map.of(), + ErrorHandlers.namespaceErrorHandler()); pageToken = response.nextPageToken(); namespaces.addAll(response.namespaces()); } while (pageToken != null); @@ -658,12 +581,15 @@ public Map loadNamespaceMetadata(SessionContext context, Namespa checkNamespaceIsValid(ns); // TODO: rename to LoadNamespaceResponse? + AuthSession authSession = authManager.contextualSession(context, catalogAuth); GetNamespaceResponse response = - client.get( - paths.namespace(ns), - GetNamespaceResponse.class, - headers(context), - ErrorHandlers.namespaceErrorHandler()); + client + .withAuthSession(authSession) + .get( + paths.namespace(ns), + GetNamespaceResponse.class, + Map.of(), + ErrorHandlers.namespaceErrorHandler()); return response.properties(); } @@ -673,8 +599,10 @@ public boolean dropNamespace(SessionContext context, Namespace ns) { checkNamespaceIsValid(ns); try { - client.delete( - paths.namespace(ns), null, headers(context), ErrorHandlers.namespaceErrorHandler()); + AuthSession authSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(authSession) + .delete(paths.namespace(ns), null, Map.of(), ErrorHandlers.namespaceErrorHandler()); return true; } catch (NoSuchNamespaceException e) { return false; @@ -690,66 +618,27 @@ public boolean updateNamespaceMetadata( UpdateNamespacePropertiesRequest request = UpdateNamespacePropertiesRequest.builder().updateAll(updates).removeAll(removals).build(); + AuthSession authSession = authManager.contextualSession(context, catalogAuth); UpdateNamespacePropertiesResponse response = - client.post( - paths.namespaceProperties(ns), - request, - UpdateNamespacePropertiesResponse.class, - headers(context), - ErrorHandlers.namespaceErrorHandler()); + client + .withAuthSession(authSession) + .post( + paths.namespaceProperties(ns), + request, + UpdateNamespacePropertiesResponse.class, + Map.of(), + ErrorHandlers.namespaceErrorHandler()); return !response.updated().isEmpty(); } - private ScheduledExecutorService tokenRefreshExecutor(String catalogName) { - if (!keepTokenRefreshed) { - return null; - } - - if (refreshExecutor == null) { - synchronized (this) { - if (refreshExecutor == null) { - this.refreshExecutor = ThreadPools.newScheduledPool(catalogName + "-token-refresh", 1); - } - } - } - - return refreshExecutor; - } - @Override public void close() throws IOException { - shutdownRefreshExecutor(); - if (closeables != null) { closeables.close(); } } - private void shutdownRefreshExecutor() { - if (refreshExecutor != null) { - ScheduledExecutorService service = refreshExecutor; - this.refreshExecutor = null; - - List tasks = service.shutdownNow(); - tasks.forEach( - task -> { - if (task instanceof Future) { - ((Future) task).cancel(true); - } - }); - - try { - if (!service.awaitTermination(1, TimeUnit.MINUTES)) { - LOG.warn("Timed out waiting for refresh executor to terminate"); - } - } catch (InterruptedException e) { - LOG.warn("Interrupted while waiting for refresh executor to terminate", e); - Thread.currentThread().interrupt(); - } - } - } - private class Builder implements Catalog.TableBuilder { private final TableIdentifier ident; private final Schema schema; @@ -812,20 +701,25 @@ public Table create() { .setProperties(propertiesBuilder.build()) .build(); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); LoadTableResponse response = - client.post( - paths.tables(ident.namespace()), - request, - LoadTableResponse.class, - headers(context), - ErrorHandlers.tableErrorHandler()); - - AuthSession session = tableSession(response.config(), session(context)); + client + .withAuthSession(contextualSession) + .post( + paths.tables(ident.namespace()), + request, + LoadTableResponse.class, + Map.of(), + ErrorHandlers.tableErrorHandler()); + + Map tableConf = response.config(); + AuthSession tableSession = authManager.tableSession(ident, tableConf, contextualSession); + RESTClient tableClient = client.withAuthSession(tableSession); RESTTableOperations ops = new RESTTableOperations( - client, + tableClient, paths.table(ident), - session::headers, + Map::of, tableFileIO(context, response.config()), response.tableMetadata(), endpoints); @@ -833,7 +727,7 @@ public Table create() { trackFileIO(ops); return new BaseTable( - ops, fullTableName(ident), metricsReporter(paths.metrics(ident), session::headers)); + ops, fullTableName(ident), metricsReporter(paths.metrics(ident), tableClient)); } @Override @@ -842,14 +736,17 @@ public Transaction createTransaction() { LoadTableResponse response = stageCreate(); String fullName = fullTableName(ident); - AuthSession session = tableSession(response.config(), session(context)); + Map tableConf = response.config(); + AuthSession parent = authManager.contextualSession(context, catalogAuth); + AuthSession tableSession = authManager.tableSession(ident, tableConf, parent); TableMetadata meta = response.tableMetadata(); + RESTClient tableClient = client.withAuthSession(tableSession); RESTTableOperations ops = new RESTTableOperations( - client, + tableClient, paths.table(ident), - session::headers, + Map::of, tableFileIO(context, response.config()), RESTTableOperations.UpdateType.CREATE, createChanges(meta), @@ -859,7 +756,7 @@ public Transaction createTransaction() { trackFileIO(ops); return Transactions.createTableTransaction( - fullName, ops, meta, metricsReporter(paths.metrics(ident), session::headers)); + fullName, ops, meta, metricsReporter(paths.metrics(ident), tableClient)); } @Override @@ -872,7 +769,9 @@ public Transaction replaceTransaction() { LoadTableResponse response = loadInternal(context, ident, snapshotMode); String fullName = fullTableName(ident); - AuthSession session = tableSession(response.config(), session(context)); + Map tableConf = response.config(); + AuthSession parent = authManager.contextualSession(context, catalogAuth); + AuthSession tableSession = authManager.tableSession(ident, tableConf, parent); TableMetadata base = response.tableMetadata(); Map tableProperties = propertiesBuilder.build(); @@ -904,11 +803,12 @@ public Transaction replaceTransaction() { changes.add(new MetadataUpdate.SetDefaultSortOrder(replacement.defaultSortOrderId())); } + RESTClient tableClient = client.withAuthSession(tableSession); RESTTableOperations ops = new RESTTableOperations( - client, + tableClient, paths.table(ident), - session::headers, + Map::of, tableFileIO(context, response.config()), RESTTableOperations.UpdateType.REPLACE, changes.build(), @@ -918,7 +818,7 @@ public Transaction replaceTransaction() { trackFileIO(ops); return Transactions.replaceTableTransaction( - fullName, ops, replacement, metricsReporter(paths.metrics(ident), session::headers)); + fullName, ops, replacement, metricsReporter(paths.metrics(ident), tableClient)); } @Override @@ -950,12 +850,15 @@ private LoadTableResponse stageCreate() { .setProperties(tableProperties) .build(); - return client.post( - paths.tables(ident.namespace()), - request, - LoadTableResponse.class, - headers(context), - ErrorHandlers.tableErrorHandler()); + AuthSession authSession = authManager.contextualSession(context, catalogAuth); + return client + .withAuthSession(authSession) + .post( + paths.tables(ident.namespace()), + request, + LoadTableResponse.class, + Map.of(), + ErrorHandlers.tableErrorHandler()); } } @@ -1021,26 +924,8 @@ private FileIO tableFileIO(SessionContext context, Map config) { return newFileIO(context, fullConf); } - private AuthSession tableSession(Map tableConf, AuthSession parent) { - Map credentials = Maps.newHashMapWithExpectedSize(tableConf.size()); - for (String prop : tableConf.keySet()) { - if (TABLE_SESSION_ALLOW_LIST.contains(prop)) { - credentials.put(prop, tableConf.get(prop)); - } - } - - Pair> newSession = newSession(credentials, tableConf, parent); - if (null == newSession) { - return parent; - } - - AuthSession session = tableSessions.get(newSession.first(), id -> newSession.second().get()); - - return session != null ? session : parent; - } - private static ConfigResponse fetchConfig( - RESTClient client, Map headers, Map properties) { + RESTClient client, AuthSession initialAuth, Map properties) { // send the client's warehouse location to the service to keep in sync // this is needed for cases where the warehouse is configured client side, but may be used on // the server side, @@ -1054,76 +939,18 @@ private static ConfigResponse fetchConfig( } ConfigResponse configResponse = - client.get( - ResourcePaths.config(), - queryParams.build(), - ConfigResponse.class, - headers, - ErrorHandlers.defaultErrorHandler()); + client + .withAuthSession(initialAuth) + .get( + ResourcePaths.config(), + queryParams.build(), + ConfigResponse.class, + configHeaders(properties), + ErrorHandlers.defaultErrorHandler()); configResponse.validate(); return configResponse; } - private Pair> newSession( - Map credentials, Map properties, AuthSession parent) { - if (credentials != null) { - // use the bearer token without exchanging - if (credentials.containsKey(OAuth2Properties.TOKEN)) { - return Pair.of( - credentials.get(OAuth2Properties.TOKEN), - () -> - AuthSession.fromAccessToken( - client, - tokenRefreshExecutor(name()), - credentials.get(OAuth2Properties.TOKEN), - expiresAtMillis(properties), - parent)); - } - - if (credentials.containsKey(OAuth2Properties.CREDENTIAL)) { - // fetch a token using the client credentials flow - return Pair.of( - credentials.get(OAuth2Properties.CREDENTIAL), - () -> - AuthSession.fromCredential( - client, - tokenRefreshExecutor(name()), - credentials.get(OAuth2Properties.CREDENTIAL), - parent)); - } - - for (String tokenType : TOKEN_PREFERENCE_ORDER) { - if (credentials.containsKey(tokenType)) { - // exchange the token for an access token using the token exchange flow - return Pair.of( - credentials.get(tokenType), - () -> - AuthSession.fromTokenExchange( - client, - tokenRefreshExecutor(name()), - credentials.get(tokenType), - tokenType, - parent)); - } - } - } - - return null; - } - - private Long expiresAtMillis(Map properties) { - if (properties.containsKey(OAuth2Properties.TOKEN_EXPIRES_IN_MS)) { - long expiresInMillis = - PropertyUtil.propertyAsLong( - properties, - OAuth2Properties.TOKEN_EXPIRES_IN_MS, - OAuth2Properties.TOKEN_EXPIRES_IN_MS_DEFAULT); - return System.currentTimeMillis() + expiresInMillis; - } else { - return null; - } - } - private void checkIdentifierIsValid(TableIdentifier tableIdentifier) { if (tableIdentifier.namespace().isEmpty()) { throw new NoSuchTableException("Invalid table identifier: %s", tableIdentifier); @@ -1146,25 +973,6 @@ private static Map configHeaders(Map properties) return RESTUtil.extractPrefixMap(properties, "header."); } - private static Cache newSessionCache(Map properties) { - long expirationIntervalMs = - PropertyUtil.propertyAsLong( - properties, - CatalogProperties.AUTH_SESSION_TIMEOUT_MS, - CatalogProperties.AUTH_SESSION_TIMEOUT_MS_DEFAULT); - - return Caffeine.newBuilder() - .expireAfterAccess(Duration.ofMillis(expirationIntervalMs)) - .removalListener( - (RemovalListener) - (id, auth, cause) -> { - if (auth != null) { - auth.stopRefreshing(); - } - }) - .build(); - } - public void commitTransaction(SessionContext context, List commits) { Endpoint.check(endpoints, Endpoint.V1_COMMIT_TRANSACTION); List tableChanges = Lists.newArrayListWithCapacity(commits.size()); @@ -1174,12 +982,15 @@ public void commitTransaction(SessionContext context, List commits) UpdateTableRequest.create(commit.identifier(), commit.requirements(), commit.updates())); } - client.post( - paths.commitTransaction(), - new CommitTransactionRequest(tableChanges), - null, - headers(context), - ErrorHandlers.tableCommitHandler()); + AuthSession authSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(authSession) + .post( + paths.commitTransaction(), + new CommitTransactionRequest(tableChanges), + null, + Map.of(), + ErrorHandlers.tableCommitHandler()); } @Override @@ -1198,13 +1009,16 @@ public List listViews(SessionContext context, Namespace namespa do { queryParams.put("pageToken", pageToken); + AuthSession authSession = authManager.contextualSession(context, catalogAuth); ListTablesResponse response = - client.get( - paths.views(namespace), - queryParams, - ListTablesResponse.class, - headers(context), - ErrorHandlers.namespaceErrorHandler()); + client + .withAuthSession(authSession) + .get( + paths.views(namespace), + queryParams, + ListTablesResponse.class, + Map.of(), + ErrorHandlers.namespaceErrorHandler()); pageToken = response.nextPageToken(); views.addAll(response.identifiers()); } while (pageToken != null); @@ -1224,19 +1038,23 @@ public View loadView(SessionContext context, TableIdentifier identifier) { checkViewIdentifierIsValid(identifier); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); LoadViewResponse response = - client.get( - paths.view(identifier), - LoadViewResponse.class, - headers(context), - ErrorHandlers.viewErrorHandler()); - - AuthSession session = tableSession(response.config(), session(context)); + client + .withAuthSession(contextualSession) + .get( + paths.view(identifier), + LoadViewResponse.class, + Map.of(), + ErrorHandlers.viewErrorHandler()); + + Map tableConf = response.config(); + AuthSession session = authManager.tableSession(identifier, tableConf, contextualSession); ViewMetadata metadata = response.metadata(); RESTViewOperations ops = new RESTViewOperations( - client, paths.view(identifier), session::headers, metadata, endpoints); + client.withAuthSession(session), paths.view(identifier), Map::of, metadata, endpoints); return new BaseView(ops, ViewUtil.fullViewName(name(), identifier)); } @@ -1252,8 +1070,10 @@ public boolean dropView(SessionContext context, TableIdentifier identifier) { checkViewIdentifierIsValid(identifier); try { - client.delete( - paths.view(identifier), null, headers(context), ErrorHandlers.viewErrorHandler()); + AuthSession authSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(authSession) + .delete(paths.view(identifier), null, Map.of(), ErrorHandlers.viewErrorHandler()); return true; } catch (NoSuchViewException e) { return false; @@ -1269,8 +1089,10 @@ public void renameView(SessionContext context, TableIdentifier from, TableIdenti RenameTableRequest request = RenameTableRequest.builder().withSource(from).withDestination(to).build(); - client.post( - paths.renameView(), request, null, headers(context), ErrorHandlers.viewErrorHandler()); + AuthSession authSession = authManager.contextualSession(context, catalogAuth); + client + .withAuthSession(authSession) + .post(paths.renameView(), request, null, Map.of(), ErrorHandlers.viewErrorHandler()); } private class RESTViewBuilder implements ViewBuilder { @@ -1361,18 +1183,26 @@ public View create() { .properties(properties) .build(); + AuthSession contextualSession = authManager.contextualSession(context, catalogAuth); LoadViewResponse response = - client.post( - paths.views(identifier.namespace()), - request, - LoadViewResponse.class, - headers(context), - ErrorHandlers.viewErrorHandler()); - - AuthSession session = tableSession(response.config(), session(context)); + client + .withAuthSession(contextualSession) + .post( + paths.views(identifier.namespace()), + request, + LoadViewResponse.class, + Map.of(), + ErrorHandlers.viewErrorHandler()); + + Map tableConf = response.config(); + AuthSession session = authManager.tableSession(identifier, tableConf, contextualSession); RESTViewOperations ops = new RESTViewOperations( - client, paths.view(identifier), session::headers, response.metadata(), endpoints); + client.withAuthSession(session), + paths.view(identifier), + Map::of, + response.metadata(), + endpoints); return new BaseView(ops, ViewUtil.fullViewName(name(), identifier)); } @@ -1404,11 +1234,14 @@ private LoadViewResponse loadView() { "Unable to load view %s.%s: Server does not support endpoint %s", name(), identifier, Endpoint.V1_LOAD_VIEW)); - return client.get( - paths.view(identifier), - LoadViewResponse.class, - headers(context), - ErrorHandlers.viewErrorHandler()); + AuthSession authSession = authManager.contextualSession(context, catalogAuth); + return client + .withAuthSession(authSession) + .get( + paths.view(identifier), + LoadViewResponse.class, + Map.of(), + ErrorHandlers.viewErrorHandler()); } private View replace(LoadViewResponse response) { @@ -1449,10 +1282,16 @@ private View replace(LoadViewResponse response) { ViewMetadata replacement = builder.build(); - AuthSession session = tableSession(response.config(), session(context)); + Map tableConf = response.config(); + AuthSession parent = authManager.contextualSession(context, catalogAuth); + AuthSession session = authManager.tableSession(identifier, tableConf, parent); RESTViewOperations ops = new RESTViewOperations( - client, paths.view(identifier), session::headers, metadata, endpoints); + client.withAuthSession(session), + paths.view(identifier), + Map::of, + metadata, + endpoints); ops.commit(metadata, replacement); diff --git a/core/src/main/java/org/apache/iceberg/rest/RESTUtil.java b/core/src/main/java/org/apache/iceberg/rest/RESTUtil.java index fab01162cad7..4074fbdba32b 100644 --- a/core/src/main/java/org/apache/iceberg/rest/RESTUtil.java +++ b/core/src/main/java/org/apache/iceberg/rest/RESTUtil.java @@ -18,13 +18,18 @@ */ package org.apache.iceberg.rest; +import com.fasterxml.jackson.core.JsonProcessingException; import java.io.UncheckedIOException; import java.io.UnsupportedEncodingException; +import java.net.URI; +import java.net.URISyntaxException; import java.net.URLDecoder; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.util.Map; +import org.apache.hc.core5.net.URIBuilder; import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.exceptions.RESTException; import org.apache.iceberg.relocated.com.google.common.base.Joiner; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.base.Splitter; @@ -215,4 +220,45 @@ public static Namespace decodeNamespace(String encodedNs) { return Namespace.of(levels); } + + /** Builds a request URI from a base URI and an {@link HTTPRequest}. */ + public static URI buildRequestUri(HTTPRequest request) { + // if full path is provided, use the input path as path + String path = request.path(); + if (path.startsWith("/")) { + throw new RESTException( + "Received a malformed path for a REST request: %s. Paths should not start with /", path); + } + String fullPath = + (path.startsWith("https://") || path.startsWith("http://")) + ? path + : String.format("%s/%s", request.baseUri(), path); + try { + URIBuilder builder = new URIBuilder(stripTrailingSlash(fullPath)); + request.queryParameters().forEach(builder::addParameter); + return builder.build(); + } catch (URISyntaxException e) { + throw new RESTException( + "Failed to create request URI from base %s, params %s", + fullPath, request.queryParameters()); + } + } + + /** + * Encodes the body of an HTTP request as a String. By convention, maps are encoded as form data + * and other objects are encoded as JSON. + */ + public static String encodeRequestBody(HTTPRequest request) { + Object body = request.body(); + if (body instanceof Map) { + return encodeFormData((Map) body); + } else if (body != null) { + try { + return request.mapper().writeValueAsString(body); + } catch (JsonProcessingException e) { + throw new RESTException(e, "Failed to encode request body: %s", body); + } + } + return null; + } } diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/AuthConfig.java b/core/src/main/java/org/apache/iceberg/rest/auth/AuthConfig.java index 275884e1184a..d619056858c4 100644 --- a/core/src/main/java/org/apache/iceberg/rest/auth/AuthConfig.java +++ b/core/src/main/java/org/apache/iceberg/rest/auth/AuthConfig.java @@ -21,6 +21,7 @@ import java.util.Map; import javax.annotation.Nullable; import org.apache.iceberg.rest.ResourcePaths; +import org.apache.iceberg.util.PropertyUtil; import org.immutables.value.Value; /** @@ -28,8 +29,8 @@ * org.apache.iceberg.rest.auth.OAuth2Util.AuthSession}. */ @Value.Style(redactedMask = "****") -@SuppressWarnings("ImmutablesStyle") @Value.Immutable +@SuppressWarnings({"ImmutablesStyle", "SafeLoggingPropagation"}) public interface AuthConfig { @Nullable @Value.Redacted @@ -47,7 +48,7 @@ default String scope() { return OAuth2Properties.CATALOG_SCOPE; } - @Value.Lazy + @Value.Default @Nullable default Long expiresAtMillis() { return OAuth2Util.expiresAtMillis(token()); @@ -69,4 +70,42 @@ default String oauth2ServerUri() { static ImmutableAuthConfig.Builder builder() { return ImmutableAuthConfig.builder(); } + + static AuthConfig fromProperties(Map properties) { + return builder() + .credential(properties.get(OAuth2Properties.CREDENTIAL)) + .token(properties.get(OAuth2Properties.TOKEN)) + .scope(properties.getOrDefault(OAuth2Properties.SCOPE, OAuth2Properties.CATALOG_SCOPE)) + .oauth2ServerUri( + properties.getOrDefault(OAuth2Properties.OAUTH2_SERVER_URI, ResourcePaths.tokens())) + .optionalOAuthParams(OAuth2Util.buildOptionalParam(properties)) + .keepRefreshed( + PropertyUtil.propertyAsBoolean( + properties, + OAuth2Properties.TOKEN_REFRESH_ENABLED, + OAuth2Properties.TOKEN_REFRESH_ENABLED_DEFAULT)) + .expiresAtMillis(expiresAtMillis(properties)) + .build(); + } + + private static Long expiresAtMillis(Map props) { + Long expiresAtMillis = null; + + if (props.containsKey(OAuth2Properties.TOKEN)) { + expiresAtMillis = OAuth2Util.expiresAtMillis(props.get(OAuth2Properties.TOKEN)); + } + + if (expiresAtMillis == null) { + if (props.containsKey(OAuth2Properties.TOKEN_EXPIRES_IN_MS)) { + long millis = + PropertyUtil.propertyAsLong( + props, + OAuth2Properties.TOKEN_EXPIRES_IN_MS, + OAuth2Properties.TOKEN_EXPIRES_IN_MS_DEFAULT); + expiresAtMillis = System.currentTimeMillis() + millis; + } + } + + return expiresAtMillis; + } } diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/AuthManager.java b/core/src/main/java/org/apache/iceberg/rest/auth/AuthManager.java new file mode 100644 index 000000000000..b99556471e38 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/auth/AuthManager.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import java.util.Map; +import org.apache.iceberg.catalog.SessionCatalog; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.rest.RESTClient; + +/** + * Manager for authentication sessions. This interface is used to create sessions for the catalog, + * the tables/views, and any other context that requires authentication. + * + *

Managers are usually stateful and may require initialization and cleanup. The manager is + * created by the catalog and is closed when the catalog is closed. + */ +public interface AuthManager extends AutoCloseable { + + /** + * Returns a temporary session to use for contacting the configuration endpoint only. Note that + * the returned session will be closed after the configuration endpoint is contacted, and should + * not be cached. + * + *

The provided REST client is a short-lived client; it should only be used to fetch initial + * credentials, if required, and must be discarded after that. + * + *

This method cannot return null. By default, it returns the catalog session. + */ + default AuthSession initSession(RESTClient initClient, Map properties) { + return catalogSession(initClient, properties); + } + + /** + * Returns a long-lived session whose lifetime is tied to the owning catalog. This session serves + * as the parent session for all other sessions (contextual and table-specific). It is closed when + * the owning catalog is closed. + * + *

The provided REST client is a long-lived, shared client; if required, implementors may store + * it and reuse it for all subsequent requests to the authorization server, e.g. for renewing or + * refreshing credentials. It is not necessary to close it when {@link #close()} is called. + * + *

This method cannot return null. + * + *

It is not required to cache the returned session internally, as the catalog will keep it + * alive for the lifetime of the catalog. + */ + AuthSession catalogSession(RESTClient sharedClient, Map properties); + + /** + * Returns a session for a specific context. + * + *

If the context requires a specific {@link AuthSession}, this method should return a new + * {@link AuthSession} instance, otherwise it should return the parent session. + * + *

This method cannot return null. By default, it returns the parent session. + * + *

Implementors should cache contextual sessions internally, as the catalog will not cache + * them. Also, the owning catalog never closes contextual sessions; implementations should manage + * their lifecycle themselves and close them when they are no longer needed. + */ + default AuthSession contextualSession(SessionCatalog.SessionContext context, AuthSession parent) { + return parent; + } + + /** + * Returns a new session targeting a specific table or view. The properties are the ones returned + * by the table/view endpoint. + * + *

If the table or view requires a specific {@link AuthSession}, this method should return a + * new {@link AuthSession} instance, otherwise it should return the parent session. + * + *

This method cannot return null. By default, it returns the parent session. + * + *

Implementors should cache table sessions internally, as the catalog will not cache them. + * Also, the owning catalog never closes table sessions; implementations should manage their + * lifecycle themselves and close them when they are no longer needed. + */ + default AuthSession tableSession( + TableIdentifier table, Map properties, AuthSession parent) { + return parent; + } + + /** + * Closes the manager and releases any resources. + * + *

This method is called when the owning catalog is closed. + */ + @Override + default void close() { + // Do nothing + } +} diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/AuthManagers.java b/core/src/main/java/org/apache/iceberg/rest/auth/AuthManagers.java new file mode 100644 index 000000000000..b34b8e6e9d66 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/auth/AuthManagers.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import java.util.Locale; +import java.util.Map; +import org.apache.iceberg.common.DynConstructors; +import org.apache.iceberg.util.PropertyUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class AuthManagers { + + private static final Logger LOG = LoggerFactory.getLogger(AuthManagers.class); + + /** Old property name for enabling SigV4 authentication. */ + private static final String SIGV4_ENABLED_LEGACY = "rest.sigv4-enabled"; + + private AuthManagers() {} + + public static AuthManager loadAuthManager(String name, Map properties) { + + if (properties.containsKey(SIGV4_ENABLED_LEGACY)) { + LOG.warn( + "The property {} is deprecated and will be removed in a future release. " + + "Please use the property {}={} instead.", + SIGV4_ENABLED_LEGACY, + AuthProperties.AUTH_TYPE, + AuthProperties.AUTH_TYPE_SIGV4); + } + + String authType; + if (PropertyUtil.propertyAsBoolean(properties, SIGV4_ENABLED_LEGACY, false)) { + authType = AuthProperties.AUTH_TYPE_SIGV4; + } else { + authType = properties.get(AuthProperties.AUTH_TYPE); + if (authType == null) { + boolean hasCredential = properties.containsKey(OAuth2Properties.CREDENTIAL); + boolean hasToken = properties.containsKey(OAuth2Properties.TOKEN); + if (hasCredential || hasToken) { + LOG.warn( + "Inferring {}={} since property {} was provided. " + + "Please explicitly set {} to avoid this warning.", + AuthProperties.AUTH_TYPE, + AuthProperties.AUTH_TYPE_OAUTH2, + hasCredential ? OAuth2Properties.CREDENTIAL : OAuth2Properties.TOKEN, + AuthProperties.AUTH_TYPE); + authType = AuthProperties.AUTH_TYPE_OAUTH2; + } else { + authType = AuthProperties.AUTH_TYPE_NONE; + } + } + } + + String impl; + switch (authType.toLowerCase(Locale.ROOT)) { + case AuthProperties.AUTH_TYPE_NONE: + impl = AuthProperties.AUTH_MANAGER_IMPL_NONE; + break; + case AuthProperties.AUTH_TYPE_BASIC: + impl = AuthProperties.AUTH_MANAGER_IMPL_BASIC; + break; + case AuthProperties.AUTH_TYPE_SIGV4: + impl = AuthProperties.AUTH_MANAGER_IMPL_SIGV4; + break; + case AuthProperties.AUTH_TYPE_OAUTH2: + impl = AuthProperties.AUTH_MANAGER_IMPL_OAUTH2; + break; + default: + impl = authType; + } + + LOG.info("Loading AuthManager implementation: {}", impl); + DynConstructors.Ctor ctor; + try { + ctor = + DynConstructors.builder(AuthManager.class) + .loader(AuthManagers.class.getClassLoader()) + .impl(impl, String.class) // with name + .impl(impl) // without name + .buildChecked(); + } catch (NoSuchMethodException e) { + throw new IllegalArgumentException( + String.format( + "Cannot initialize AuthManager implementation %s: %s", impl, e.getMessage()), + e); + } + + AuthManager authManager; + try { + authManager = ctor.newInstance(name); + } catch (ClassCastException e) { + throw new IllegalArgumentException( + String.format("Cannot initialize AuthManager, %s does not implement AuthManager", impl), + e); + } + + return authManager; + } +} diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/AuthProperties.java b/core/src/main/java/org/apache/iceberg/rest/auth/AuthProperties.java new file mode 100644 index 000000000000..61b509e01fe7 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/auth/AuthProperties.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +public final class AuthProperties { + + private AuthProperties() {} + + public static final String AUTH_TYPE = "rest.auth.type"; + + public static final String AUTH_TYPE_NONE = "none"; + public static final String AUTH_TYPE_BASIC = "basic"; + public static final String AUTH_TYPE_OAUTH2 = "oauth2"; + public static final String AUTH_TYPE_SIGV4 = "sigv4"; + + public static final String AUTH_MANAGER_IMPL_NONE = + "org.apache.iceberg.rest.auth.NoopAuthManager"; + public static final String AUTH_MANAGER_IMPL_BASIC = + "org.apache.iceberg.rest.auth.BasicAuthManager"; + public static final String AUTH_MANAGER_IMPL_OAUTH2 = + "org.apache.iceberg.rest.auth.OAuth2Manager"; + public static final String AUTH_MANAGER_IMPL_SIGV4 = + "org.apache.iceberg.aws.RESTSigV4AuthManager"; + + public static final String BASIC_USERNAME = "rest.auth.basic.username"; + public static final String BASIC_PASSWORD = "rest.auth.basic.password"; +} diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/AuthSession.java b/core/src/main/java/org/apache/iceberg/rest/auth/AuthSession.java new file mode 100644 index 000000000000..169a53d7a8f2 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/auth/AuthSession.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import org.apache.iceberg.rest.HTTPRequest; + +/** + * An authentication session that can be used to authenticate outgoing HTTP requests. + * + *

Authentication sessions are usually immutable, but may hold resources that need to be released + * when the session is no longer needed. Implementations should override {@link #close()} to release + * any resources. + */ +public interface AuthSession extends AutoCloseable { + + /** An empty session that does nothing. */ + AuthSession EMPTY = request -> request; + + /** + * Authenticates the given request and returns a new request with the necessary authentication. + */ + HTTPRequest authenticate(HTTPRequest request); + + /** + * Closes the session and releases any resources. This method is called when the session is no + * longer needed. Note that since sessions may be cached, this method may not be called + * immediately after the session is no longer needed, but rather when the session is evicted from + * the cache, or the cache itself is closed. + */ + @Override + default void close() { + // Do nothing + } +} diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/AuthSessionCache.java b/core/src/main/java/org/apache/iceberg/rest/auth/AuthSessionCache.java new file mode 100644 index 000000000000..d984668a4c83 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/auth/AuthSessionCache.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import com.github.benmanes.caffeine.cache.Cache; +import com.github.benmanes.caffeine.cache.Caffeine; +import com.github.benmanes.caffeine.cache.RemovalListener; +import java.time.Duration; +import java.util.function.Supplier; + +/** A cache for {@link AuthSession} instances. */ +public class AuthSessionCache implements AutoCloseable { + + private final Duration sessionTimeout; + private volatile Cache sessionCache; + + public AuthSessionCache(Duration sessionTimeout) { + this.sessionTimeout = sessionTimeout; + } + + @SuppressWarnings("unchecked") + public T cachedSession(String key, Supplier loader) { + return (T) sessionCache().get(key, k -> loader.get()); + } + + @Override + public void close() { + Cache cache = sessionCache; + this.sessionCache = null; + if (cache != null) { + cache.invalidateAll(); + cache.cleanUp(); + } + } + + private Cache sessionCache() { + if (sessionCache == null) { + synchronized (this) { + if (sessionCache == null) { + this.sessionCache = newSessionCache(sessionTimeout); + } + } + } + return sessionCache; + } + + private static Cache newSessionCache(Duration sessionTimeout) { + return Caffeine.newBuilder() + .expireAfterAccess(sessionTimeout) + .removalListener( + (RemovalListener) + (id, auth, cause) -> { + if (auth != null) { + auth.close(); + } + }) + .build(); + } +} diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/BasicAuthManager.java b/core/src/main/java/org/apache/iceberg/rest/auth/BasicAuthManager.java new file mode 100644 index 000000000000..fb970fcc5713 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/auth/BasicAuthManager.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.rest.RESTClient; + +/** An auth manager that adds static BASIC authentication data to outgoing HTTP requests. */ +public final class BasicAuthManager implements AuthManager { + + @Override + public AuthSession catalogSession(RESTClient sharedClient, Map properties) { + Preconditions.checkArgument( + properties.containsKey(AuthProperties.BASIC_USERNAME), + "Invalid username: missing required property %s", + AuthProperties.BASIC_USERNAME); + Preconditions.checkArgument( + properties.containsKey(AuthProperties.BASIC_PASSWORD), + "Invalid password: missing required property %s", + AuthProperties.BASIC_PASSWORD); + String username = properties.get(AuthProperties.BASIC_USERNAME); + String password = properties.get(AuthProperties.BASIC_PASSWORD); + String credentials = username + ":" + password; + return DefaultAuthSession.of(OAuth2Util.basicAuthHeaders(credentials)); + } +} diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/DefaultAuthSession.java b/core/src/main/java/org/apache/iceberg/rest/auth/DefaultAuthSession.java new file mode 100644 index 000000000000..70d12edc0215 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/auth/DefaultAuthSession.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import java.util.Map; +import org.apache.iceberg.rest.HTTPRequest; +import org.immutables.value.Value; + +/** + * Default implementation of {@link AuthSession}. It authenticates requests by setting the provided + * headers on the request. + * + *

Most {@link AuthManager} implementations should make use of this class, unless they need to + * retain state when creating sessions, or if they need to modify the request in a different way. + */ +@Value.Style(redactedMask = "****") +@Value.Immutable +@SuppressWarnings({"ImmutablesStyle", "SafeLoggingPropagation"}) +public interface DefaultAuthSession extends AuthSession { + + /** Headers containing authentication data to set on the request. */ + @Value.Redacted + Map headers(); + + @Override + default HTTPRequest authenticate(HTTPRequest request) { + return request.putHeadersIfAbsent(headers()); + } + + static DefaultAuthSession of(String name, String value) { + return ImmutableDefaultAuthSession.builder().putHeaders(name, value).build(); + } + + static DefaultAuthSession of(Map authHeaders) { + return ImmutableDefaultAuthSession.builder().putAllHeaders(authHeaders).build(); + } +} diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/NoopAuthManager.java b/core/src/main/java/org/apache/iceberg/rest/auth/NoopAuthManager.java new file mode 100644 index 000000000000..2712f982c262 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/auth/NoopAuthManager.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import java.util.Map; +import org.apache.iceberg.rest.RESTClient; + +/** An auth manager that does not add any authentication data to outgoing HTTP requests. */ +public class NoopAuthManager implements AuthManager { + + @Override + public AuthSession catalogSession(RESTClient sharedClient, Map properties) { + return AuthSession.EMPTY; + } +} diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Manager.java b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Manager.java new file mode 100644 index 000000000000..4622f09d3cb9 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Manager.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.catalog.SessionCatalog; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.rest.RESTClient; +import org.apache.iceberg.rest.RESTUtil; +import org.apache.iceberg.rest.ResourcePaths; +import org.apache.iceberg.rest.responses.OAuthTokenResponse; +import org.apache.iceberg.util.PropertyUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@SuppressWarnings("unused") // loaded by reflection +public class OAuth2Manager extends RefreshingAuthManager { + + private static final Logger LOG = LoggerFactory.getLogger(OAuth2Manager.class); + + private static final List TOKEN_PREFERENCE_ORDER = + ImmutableList.of( + OAuth2Properties.ID_TOKEN_TYPE, + OAuth2Properties.ACCESS_TOKEN_TYPE, + OAuth2Properties.JWT_TOKEN_TYPE, + OAuth2Properties.SAML2_TOKEN_TYPE, + OAuth2Properties.SAML1_TOKEN_TYPE); + + // Auth-related properties that are allowed to be passed to the table session + private static final Set TABLE_SESSION_ALLOW_LIST = + ImmutableSet.builder() + .add(OAuth2Properties.TOKEN) + .addAll(TOKEN_PREFERENCE_ORDER) + .build(); + + private RESTClient client; + private long startTimeMillis; + private OAuthTokenResponse authResponse; + private AuthSessionCache sessionCache; + + public OAuth2Manager(String name) { + super(name + "-token-refresh"); + } + + @Override + public OAuth2Util.AuthSession initSession(RESTClient initClient, Map properties) { + warnIfDeprecatedTokenEndpointUsed(properties); + AuthConfig config = AuthConfig.fromProperties(properties); + Map headers = OAuth2Util.authHeaders(config.token()); + OAuth2Util.AuthSession session = new OAuth2Util.AuthSession(headers, config); + if (config.credential() != null && !config.credential().isEmpty()) { + // keep track of the start time for token refresh + this.startTimeMillis = System.currentTimeMillis(); + this.authResponse = + OAuth2Util.fetchToken( + initClient, + headers, + config.credential(), + config.scope(), + config.oauth2ServerUri(), + config.optionalOAuthParams()); + return OAuth2Util.AuthSession.fromTokenResponse( + initClient, null, authResponse, startTimeMillis, session); + } else if (config.token() != null) { + return OAuth2Util.AuthSession.fromAccessToken( + initClient, null, config.token(), null, session); + } + return session; + } + + @Override + public OAuth2Util.AuthSession catalogSession( + RESTClient sharedClient, Map properties) { + this.client = sharedClient; + this.sessionCache = new AuthSessionCache(sessionTimeout(properties)); + AuthConfig config = AuthConfig.fromProperties(properties); + Map headers = OAuth2Util.authHeaders(config.token()); + OAuth2Util.AuthSession session = new OAuth2Util.AuthSession(headers, config); + keepRefreshed(config.keepRefreshed()); + // authResponse comes from the init phase + if (authResponse != null) { + return OAuth2Util.AuthSession.fromTokenResponse( + client, refreshExecutor(), authResponse, startTimeMillis, session); + } else if (config.token() != null) { + return OAuth2Util.AuthSession.fromAccessToken( + client, refreshExecutor(), config.token(), config.expiresAtMillis(), session); + } + return session; + } + + @Override + public OAuth2Util.AuthSession contextualSession( + SessionCatalog.SessionContext context, AuthSession parent) { + return maybeCreateChildSession( + context.credentials(), + context.properties(), + ignored -> context.sessionId(), + (OAuth2Util.AuthSession) parent); + } + + @Override + public OAuth2Util.AuthSession tableSession( + TableIdentifier table, Map properties, AuthSession parent) { + return maybeCreateChildSession( + Maps.filterKeys(properties, TABLE_SESSION_ALLOW_LIST::contains), + properties, + properties::get, + (OAuth2Util.AuthSession) parent); + } + + @Override + public void close() { + try { + super.close(); + } finally { + AuthSessionCache cache = sessionCache; + this.sessionCache = null; + if (cache != null) { + cache.close(); + } + } + } + + protected OAuth2Util.AuthSession maybeCreateChildSession( + Map credentials, + Map properties, + Function cacheKeyFunc, + OAuth2Util.AuthSession parent) { + if (credentials != null) { + // use the bearer token without exchanging + if (credentials.containsKey(OAuth2Properties.TOKEN)) { + String token = credentials.get(OAuth2Properties.TOKEN); + return sessionCache.cachedSession( + cacheKeyFunc.apply(OAuth2Properties.TOKEN), + () -> newSessionFromAccessToken(token, properties, parent)); + } + + if (credentials.containsKey(OAuth2Properties.CREDENTIAL)) { + // fetch a token using the client credentials flow + String credential = credentials.get(OAuth2Properties.CREDENTIAL); + return sessionCache.cachedSession( + cacheKeyFunc.apply(OAuth2Properties.CREDENTIAL), + () -> newSessionFromCredential(credential, parent)); + } + + for (String tokenType : TOKEN_PREFERENCE_ORDER) { + if (credentials.containsKey(tokenType)) { + // exchange the token for an access token using the token exchange flow + String token = credentials.get(tokenType); + return sessionCache.cachedSession( + cacheKeyFunc.apply(tokenType), + () -> newSessionFromTokenExchange(token, tokenType, parent)); + } + } + } + + return parent; + } + + protected OAuth2Util.AuthSession newSessionFromAccessToken( + String token, Map properties, OAuth2Util.AuthSession parent) { + Long expiresAtMillis = AuthConfig.fromProperties(properties).expiresAtMillis(); + return OAuth2Util.AuthSession.fromAccessToken( + client, refreshExecutor(), token, expiresAtMillis, parent); + } + + protected OAuth2Util.AuthSession newSessionFromCredential( + String credential, OAuth2Util.AuthSession parent) { + return OAuth2Util.AuthSession.fromCredential(client, refreshExecutor(), credential, parent); + } + + protected OAuth2Util.AuthSession newSessionFromTokenExchange( + String token, String tokenType, OAuth2Util.AuthSession parent) { + return OAuth2Util.AuthSession.fromTokenExchange( + client, refreshExecutor(), token, tokenType, parent); + } + + private static void warnIfDeprecatedTokenEndpointUsed(Map properties) { + if (usesDeprecatedTokenEndpoint(properties)) { + String credential = properties.get(OAuth2Properties.CREDENTIAL); + String initToken = properties.get(OAuth2Properties.TOKEN); + boolean hasCredential = credential != null && !credential.isEmpty(); + boolean hasInitToken = initToken != null; + if (hasInitToken || hasCredential) { + LOG.warn( + "Iceberg REST client is missing the OAuth2 server URI configuration and defaults to {}/{}. " + + "This automatic fallback will be removed in a future Iceberg release." + + "It is recommended to configure the OAuth2 endpoint using the '{}' property to be prepared. " + + "This warning will disappear if the OAuth2 endpoint is explicitly configured. " + + "See https://github.com/apache/iceberg/issues/10537", + RESTUtil.stripTrailingSlash(properties.get(CatalogProperties.URI)), + ResourcePaths.tokens(), + OAuth2Properties.OAUTH2_SERVER_URI); + } + } + } + + private static boolean usesDeprecatedTokenEndpoint(Map properties) { + if (properties.containsKey(OAuth2Properties.OAUTH2_SERVER_URI)) { + String oauth2ServerUri = properties.get(OAuth2Properties.OAUTH2_SERVER_URI); + boolean relativePath = !oauth2ServerUri.startsWith("http"); + boolean sameHost = oauth2ServerUri.startsWith(properties.get(CatalogProperties.URI)); + return relativePath || sameHost; + } + return true; + } + + private static Duration sessionTimeout(Map props) { + return Duration.ofMillis( + PropertyUtil.propertyAsLong( + props, + CatalogProperties.AUTH_SESSION_TIMEOUT_MS, + CatalogProperties.AUTH_SESSION_TIMEOUT_MS_DEFAULT)); + } +} diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java index 1757ae653cc9..e360c336b6c6 100644 --- a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java +++ b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java @@ -43,6 +43,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.rest.ErrorHandlers; +import org.apache.iceberg.rest.HTTPRequest; import org.apache.iceberg.rest.RESTClient; import org.apache.iceberg.rest.RESTUtil; import org.apache.iceberg.rest.ResourcePaths; @@ -451,18 +452,23 @@ static Long expiresAtMillis(String token) { } /** Class to handle authorization headers and token refresh. */ - public static class AuthSession { + public static class AuthSession implements org.apache.iceberg.rest.auth.AuthSession { private static int tokenRefreshNumRetries = 5; private static final long MAX_REFRESH_WINDOW_MILLIS = 300_000; // 5 minutes private static final long MIN_REFRESH_WAIT_MILLIS = 10; private volatile Map headers; private volatile AuthConfig config; - public AuthSession(Map baseHeaders, AuthConfig config) { - this.headers = RESTUtil.merge(baseHeaders, authHeaders(config.token())); + public AuthSession(Map headers, AuthConfig config) { + this.headers = ImmutableMap.copyOf(headers); this.config = config; } + @Override + public HTTPRequest authenticate(HTTPRequest request) { + return request.putHeadersIfAbsent(headers()); + } + public Map headers() { return headers; } @@ -487,6 +493,11 @@ public synchronized void stopRefreshing() { this.config = ImmutableAuthConfig.copyOf(config).withKeepRefreshed(false); } + @Override + public void close() { + stopRefreshing(); + } + public String credential() { return config.credential(); } @@ -647,7 +658,7 @@ public static AuthSession fromAccessToken( AuthSession parent) { AuthSession session = new AuthSession( - parent.headers(), + RESTUtil.merge(parent.headers(), authHeaders(token)), AuthConfig.builder() .from(parent.config()) .token(token) @@ -727,7 +738,7 @@ private static AuthSession fromTokenResponse( } AuthSession session = new AuthSession( - parent.headers(), + RESTUtil.merge(parent.headers(), authHeaders(response.token())), AuthConfig.builder() .from(parent.config()) .token(response.token()) diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/RefreshingAuthManager.java b/core/src/main/java/org/apache/iceberg/rest/auth/RefreshingAuthManager.java new file mode 100644 index 000000000000..2b443e0ea5c1 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/auth/RefreshingAuthManager.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import java.util.List; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.iceberg.util.ThreadPools; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An {@link AuthManager} that provides machinery for refreshing authentication data asynchronously, + * using a background thread pool. + */ +public abstract class RefreshingAuthManager implements AuthManager { + + private static final Logger LOG = LoggerFactory.getLogger(RefreshingAuthManager.class); + + private final String executorNamePrefix; + private boolean keepRefreshed = true; + private volatile ScheduledExecutorService refreshExecutor; + + protected RefreshingAuthManager(String executorNamePrefix) { + this.executorNamePrefix = executorNamePrefix; + } + + public void keepRefreshed(boolean keep) { + this.keepRefreshed = keep; + } + + @Override + public void close() { + ScheduledExecutorService service = refreshExecutor; + this.refreshExecutor = null; + if (service != null) { + List tasks = service.shutdownNow(); + tasks.forEach( + task -> { + if (task instanceof Future) { + ((Future) task).cancel(true); + } + }); + + try { + if (!service.awaitTermination(1, TimeUnit.MINUTES)) { + LOG.warn("Timed out waiting for refresh executor to terminate"); + } + } catch (InterruptedException e) { + LOG.warn("Interrupted while waiting for refresh executor to terminate", e); + Thread.currentThread().interrupt(); + } + } + } + + @Nullable + protected ScheduledExecutorService refreshExecutor() { + if (!keepRefreshed) { + return null; + } + if (refreshExecutor == null) { + synchronized (this) { + if (refreshExecutor == null) { + this.refreshExecutor = ThreadPools.newScheduledPool(executorNamePrefix, 1); + } + } + } + return refreshExecutor; + } +} diff --git a/core/src/test/java/org/apache/iceberg/rest/RESTCatalogAdapter.java b/core/src/test/java/org/apache/iceberg/rest/RESTCatalogAdapter.java index 87b693e206ae..64317e55bc07 100644 --- a/core/src/test/java/org/apache/iceberg/rest/RESTCatalogAdapter.java +++ b/core/src/test/java/org/apache/iceberg/rest/RESTCatalogAdapter.java @@ -19,10 +19,12 @@ package org.apache.iceberg.rest; import java.io.IOException; +import java.net.URI; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.function.Consumer; +import java.util.function.Supplier; import java.util.stream.Collectors; import org.apache.iceberg.BaseTable; import org.apache.iceberg.BaseTransaction; @@ -50,6 +52,8 @@ import org.apache.iceberg.relocated.com.google.common.base.Splitter; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; +import org.apache.iceberg.rest.auth.AuthSession; import org.apache.iceberg.rest.requests.CommitTransactionRequest; import org.apache.iceberg.rest.requests.CreateNamespaceRequest; import org.apache.iceberg.rest.requests.CreateTableRequest; @@ -98,6 +102,8 @@ public class RESTCatalogAdapter implements RESTClient { private final SupportsNamespaces asNamespaceCatalog; private final ViewCatalog asViewCatalog; + private AuthSession authSession = AuthSession.EMPTY; + public RESTCatalogAdapter(Catalog catalog) { this.catalog = catalog; this.asNamespaceCatalog = @@ -105,13 +111,6 @@ public RESTCatalogAdapter(Catalog catalog) { this.asViewCatalog = catalog instanceof ViewCatalog ? (ViewCatalog) catalog : null; } - enum HTTPMethod { - GET, - HEAD, - POST, - DELETE - } - enum Route { TOKENS(HTTPMethod.POST, "v1/oauth/tokens", null, OAuthTokenResponse.class), SEPARATE_AUTH_TOKENS_URI( @@ -278,6 +277,12 @@ private static OAuthTokenResponse handleOAuthRequest(Object body) { } } + @Override + public RESTClient withAuthSession(AuthSession session) { + this.authSession = session; + return this; + } + @SuppressWarnings({"MethodLength", "checkstyle:CyclomaticComplexity"}) public T handleRequest( Route route, Map vars, Object body, Class responseType) { @@ -549,25 +554,41 @@ private static void commitTransaction(Catalog catalog, CommitTransactionRequest transactions.forEach(Transaction::commitTransaction); } - public T execute( + HTTPRequest buildRequest( HTTPMethod method, String path, Map queryParams, - Object body, - Class responseType, Map headers, - Consumer errorHandler) { + Object body) { + URI baseUri = URI.create("https://localhost:8080"); + ImmutableHTTPRequest.Builder builder = + ImmutableHTTPRequest.builder().baseUri(baseUri).method(method).path(path).body(body); + + if (queryParams != null) { + builder.queryParameters(queryParams); + } + + if (headers != null) { + headers.forEach((name, value) -> builder.putHeader(name, List.of(value))); + } + + return authSession.authenticate(builder.build()); + } + + T execute( + HTTPRequest request, + Class responseType, + Consumer errorHandler, + Consumer> responseHeaders) { ErrorResponse.Builder errorBuilder = ErrorResponse.builder(); - Pair> routeAndVars = Route.from(method, path); + Pair> routeAndVars = Route.from(request.method(), request.path()); if (routeAndVars != null) { try { ImmutableMap.Builder vars = ImmutableMap.builder(); - if (queryParams != null) { - vars.putAll(queryParams); - } + vars.putAll(request.queryParameters()); vars.putAll(routeAndVars.second()); - return handleRequest(routeAndVars.first(), vars.build(), body, responseType); + return handleRequest(routeAndVars.first(), vars.build(), request.body(), responseType); } catch (RuntimeException e) { configureResponseFromException(e, errorBuilder); @@ -577,7 +598,8 @@ public T execute( errorBuilder .responseCode(400) .withType("BadRequestException") - .withMessage(String.format("No route for request: %s %s", method, path)); + .withMessage( + String.format("No route for request: %s %s", request.method(), request.path())); } ErrorResponse error = errorBuilder.build(); @@ -587,13 +609,48 @@ public T execute( throw new RESTException("Unhandled error: %s", error); } + @Override + public void head( + String path, Supplier> headers, Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.HEAD, path, null, headers.get(), null); + execute(request, null, errorHandler, h -> {}); + } + + @Override + public void head(String path, Map headers, Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.HEAD, path, null, headers, null); + execute(request, null, errorHandler, h -> {}); + } + + @Override + public T delete( + String path, + Map queryParams, + Class responseType, + Supplier> headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.DELETE, path, queryParams, headers.get(), null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T delete( + String path, + Class responseType, + Supplier> headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.DELETE, path, null, headers.get(), null); + return execute(request, responseType, errorHandler, h -> {}); + } + @Override public T delete( String path, Class responseType, Map headers, Consumer errorHandler) { - return execute(HTTPMethod.DELETE, path, null, null, responseType, headers, errorHandler); + HTTPRequest request = buildRequest(HTTPMethod.DELETE, path, null, headers, null); + return execute(request, responseType, errorHandler, h -> {}); } @Override @@ -603,17 +660,39 @@ public T delete( Class responseType, Map headers, Consumer errorHandler) { - return execute(HTTPMethod.DELETE, path, queryParams, null, responseType, headers, errorHandler); + HTTPRequest request = buildRequest(HTTPMethod.DELETE, path, queryParams, headers, null); + return execute(request, responseType, errorHandler, h -> {}); } @Override - public T post( + public T get( + String path, + Class responseType, + Supplier> headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.GET, path, null, headers.get(), null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T get( String path, - RESTRequest body, Class responseType, Map headers, Consumer errorHandler) { - return execute(HTTPMethod.POST, path, null, body, responseType, headers, errorHandler); + HTTPRequest request = buildRequest(HTTPMethod.GET, path, null, headers, null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T get( + String path, + Map queryParams, + Class responseType, + Supplier> headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.GET, path, queryParams, headers.get(), null); + return execute(request, responseType, errorHandler, h -> {}); } @Override @@ -623,12 +702,65 @@ public T get( Class responseType, Map headers, Consumer errorHandler) { - return execute(HTTPMethod.GET, path, queryParams, null, responseType, headers, errorHandler); + HTTPRequest request = buildRequest(HTTPMethod.GET, path, queryParams, headers, null); + return execute(request, responseType, errorHandler, h -> {}); } @Override - public void head(String path, Map headers, Consumer errorHandler) { - execute(HTTPMethod.HEAD, path, null, null, null, headers, errorHandler); + public T post( + String path, + RESTRequest body, + Class responseType, + Supplier> headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers.get(), body); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T post( + String path, + RESTRequest body, + Class responseType, + Supplier> headers, + Consumer errorHandler, + Consumer> responseHeaders) { + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers.get(), body); + return execute(request, responseType, errorHandler, responseHeaders); + } + + @Override + public T post( + String path, + RESTRequest body, + Class responseType, + Map headers, + Consumer errorHandler, + Consumer> responseHeaders) { + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers, body); + return execute(request, responseType, errorHandler, responseHeaders); + } + + @Override + public T post( + String path, + RESTRequest body, + Class responseType, + Map headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers, body); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T postForm( + String path, + Map formData, + Class responseType, + Supplier> headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers.get(), formData); + return execute(request, responseType, errorHandler, h -> {}); } @Override @@ -638,7 +770,8 @@ public T postForm( Class responseType, Map headers, Consumer errorHandler) { - return execute(HTTPMethod.POST, path, null, formData, responseType, headers, errorHandler); + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers, formData); + return execute(request, responseType, errorHandler, h -> {}); } @Override diff --git a/core/src/test/java/org/apache/iceberg/rest/RESTCatalogServlet.java b/core/src/test/java/org/apache/iceberg/rest/RESTCatalogServlet.java index f456bb4d354d..8997ce57ffdd 100644 --- a/core/src/test/java/org/apache/iceberg/rest/RESTCatalogServlet.java +++ b/core/src/test/java/org/apache/iceberg/rest/RESTCatalogServlet.java @@ -38,7 +38,6 @@ import org.apache.iceberg.exceptions.RESTException; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.io.CharStreams; -import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod; import org.apache.iceberg.rest.RESTCatalogAdapter.Route; import org.apache.iceberg.rest.responses.ErrorResponse; import org.apache.iceberg.util.Pair; @@ -96,15 +95,17 @@ protected void execute(ServletRequestContext context, HttpServletResponse respon } try { - Object responseBody = - restCatalogAdapter.execute( + + HTTPRequest request = + restCatalogAdapter.buildRequest( context.method(), context.path(), context.queryParams(), - context.body(), - context.route().responseClass(), context.headers(), - handle(response)); + context.body()); + Object responseBody = + restCatalogAdapter.execute( + request, context.route().responseClass(), handle(response), h -> {}); if (responseBody != null) { RESTObjectMapper.mapper().writeValue(response.getWriter(), responseBody); @@ -130,7 +131,7 @@ protected Consumer handle(HttpServletResponse response) { } public static class ServletRequestContext { - private HTTPMethod method; + private HTTPRequest.HTTPMethod method; private Route route; private String path; private Map headers; @@ -144,7 +145,7 @@ private ServletRequestContext(ErrorResponse errorResponse) { } private ServletRequestContext( - HTTPMethod method, + HTTPRequest.HTTPMethod method, Route route, String path, Map headers, @@ -159,7 +160,7 @@ private ServletRequestContext( } static ServletRequestContext from(HttpServletRequest request) throws IOException { - HTTPMethod method = HTTPMethod.valueOf(request.getMethod()); + HTTPRequest.HTTPMethod method = HTTPRequest.HTTPMethod.valueOf(request.getMethod()); String path = request.getRequestURI().substring(1); Pair> routeContext = Route.from(method, path); @@ -193,7 +194,7 @@ static ServletRequestContext from(HttpServletRequest request) throws IOException return new ServletRequestContext(method, route, path, headers, queryParams, requestBody); } - public HTTPMethod method() { + public HTTPRequest.HTTPMethod method() { return method; } diff --git a/core/src/test/java/org/apache/iceberg/rest/TestHTTPClient.java b/core/src/test/java/org/apache/iceberg/rest/TestHTTPClient.java index 1229639aba03..8b8d8ce7733a 100644 --- a/core/src/test/java/org/apache/iceberg/rest/TestHTTPClient.java +++ b/core/src/test/java/org/apache/iceberg/rest/TestHTTPClient.java @@ -42,12 +42,8 @@ import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; import org.apache.hc.client5.http.config.ConnectionConfig; import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; -import org.apache.hc.core5.http.EntityDetails; -import org.apache.hc.core5.http.HttpException; import org.apache.hc.core5.http.HttpHost; -import org.apache.hc.core5.http.HttpRequestInterceptor; import org.apache.hc.core5.http.HttpStatus; -import org.apache.hc.core5.http.protocol.HttpContext; import org.apache.iceberg.IcebergBuild; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.rest.responses.ErrorResponse; @@ -222,18 +218,6 @@ public void accept(ErrorResponse errorResponse) { } } - @Test - public void testDynamicHttpRequestInterceptorLoading() { - Map properties = ImmutableMap.of("key", "val"); - - HttpRequestInterceptor interceptor = - HTTPClient.loadInterceptorDynamically( - TestHttpRequestInterceptor.class.getName(), properties); - - assertThat(interceptor).isInstanceOf(TestHttpRequestInterceptor.class); - assertThat(((TestHttpRequestInterceptor) interceptor).properties).isEqualTo(properties); - } - @Test public void testSocketAndConnectionTimeoutSet() { long connectionTimeoutMs = 10L; @@ -444,17 +428,4 @@ public boolean equals(Object o) { return Objects.equals(id, item.id) && Objects.equals(data, item.data); } } - - public static class TestHttpRequestInterceptor implements HttpRequestInterceptor { - private Map properties; - - public void initialize(Map props) { - this.properties = props; - } - - @Override - public void process( - org.apache.hc.core5.http.HttpRequest request, EntityDetails entity, HttpContext context) - throws HttpException, IOException {} - } } diff --git a/core/src/test/java/org/apache/iceberg/rest/TestHTTPRequest.java b/core/src/test/java/org/apache/iceberg/rest/TestHTTPRequest.java new file mode 100644 index 000000000000..4028a1aa518a --- /dev/null +++ b/core/src/test/java/org/apache/iceberg/rest/TestHTTPRequest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.net.URI; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; + +class TestHTTPRequest { + + @Test + void headers() { + HTTPRequest request = + ImmutableHTTPRequest.builder() + .baseUri(URI.create("http://localhost")) + .method(HTTPRequest.HTTPMethod.GET) + .path("path") + .putHeader("name", List.of("value")) + .build(); + assertThat(request.headers("name")).containsExactly("value"); + assertThat(request.headers("nonexistent")).isEmpty(); + } + + @Test + void containsHeader() { + HTTPRequest request = + ImmutableHTTPRequest.builder() + .baseUri(URI.create("http://localhost")) + .method(HTTPRequest.HTTPMethod.GET) + .path("path") + .headers(Map.of("k1", List.of("v1"), "k2", List.of())) + .build(); + assertThat(request.containsHeader("k1")).isTrue(); + assertThat(request.containsHeader("k2")).isFalse(); + assertThat(request.containsHeader("k3")).isFalse(); + } + + @Test + void putHeadersIfAbsent() { + HTTPRequest request = + ImmutableHTTPRequest.builder() + .baseUri(URI.create("http://localhost")) + .method(HTTPRequest.HTTPMethod.GET) + .path("path") + .headers(Map.of("k1", List.of("v1"), "k2", List.of("v2"))) + .build(); + request = request.putHeadersIfAbsent(Map.of("k1", "v1 update", "k3", "v3")); + assertThat(request.headers("k1")).containsExactly("v1"); + assertThat(request.headers("k2")).containsExactly("v2"); + assertThat(request.headers("k3")).containsExactly("v3"); + } +} diff --git a/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java b/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java index 973e394b30c7..0bd738ad653d 100644 --- a/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java +++ b/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java @@ -22,6 +22,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.times; @@ -32,11 +34,14 @@ import java.io.File; import java.io.IOException; import java.nio.file.Path; +import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.UUID; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; +import java.util.stream.Collectors; import org.apache.hadoop.conf.Configuration; import org.apache.iceberg.BaseTransaction; import org.apache.iceberg.CatalogProperties; @@ -65,7 +70,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Maps; -import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; import org.apache.iceberg.rest.RESTSessionCatalog.SnapshotMode; import org.apache.iceberg.rest.auth.AuthSessionUtil; import org.apache.iceberg.rest.auth.OAuth2Properties; @@ -115,35 +120,31 @@ public void createCatalog() throws Exception { "in-memory", ImmutableMap.of(CatalogProperties.WAREHOUSE_LOCATION, warehouse.getAbsolutePath())); - Map catalogHeaders = - ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); - Map contextHeaders = - ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=user"); + Map> catalogHeaders = + ImmutableMap.of("Authorization", List.of("Bearer client-credentials-token:sub=catalog")); + Map> contextHeaders = + ImmutableMap.of("Authorization", List.of("Bearer client-credentials-token:sub=user")); RESTCatalogAdapter adaptor = new RESTCatalogAdapter(backendCatalog) { @Override public T execute( - RESTCatalogAdapter.HTTPMethod method, - String path, - Map queryParams, - Object body, + HTTPRequest request, Class responseType, - Map headers, - Consumer errorHandler) { + Consumer errorHandler, + Consumer> responseHeaders) { // this doesn't use a Mockito spy because this is used for catalog tests, which have // different method calls - if (!"v1/oauth/tokens".equals(path)) { - if ("v1/config".equals(path)) { - assertThat(headers).containsAllEntriesOf(catalogHeaders); + if (!"v1/oauth/tokens".equals(request.path())) { + if ("v1/config".equals(request.path())) { + assertThat(request.headers()).containsAllEntriesOf(catalogHeaders); } else { - assertThat(headers).containsAllEntriesOf(contextHeaders); + assertThat(request.headers()).containsAllEntriesOf(contextHeaders); } } - Object request = roundTripSerialize(body, "request"); - T response = - super.execute( - method, path, queryParams, request, responseType, headers, errorHandler); + Object body = roundTripSerialize(request.body(), "request"); + HTTPRequest req = ImmutableHTTPRequest.builder().from(request).body(body).build(); + T response = super.execute(req, responseType, errorHandler, responseHeaders); T responseAfterSerialization = roundTripSerialize(response, "response"); return responseAfterSerialization; } @@ -259,13 +260,12 @@ public void testConfigRoute() throws IOException { RESTClient testClient = new RESTCatalogAdapter(backendCatalog) { @Override - public T get( - String path, - Map queryParams, + public T execute( + HTTPRequest request, Class responseType, - Map headers, - Consumer errorHandler) { - if ("v1/config".equals(path)) { + Consumer errorHandler, + Consumer> responseHeaders) { + if ("v1/config".equals(request.path())) { return castResponse( responseType, ConfigResponse.builder() @@ -275,10 +275,11 @@ public T get( CatalogProperties.CACHE_ENABLED, "false", CatalogProperties.WAREHOUSE_LOCATION, - queryParams.get(CatalogProperties.WAREHOUSE_LOCATION) + "warehouse")) + request.queryParameters().get(CatalogProperties.WAREHOUSE_LOCATION) + + "warehouse")) .build()); } - return super.get(path, queryParams, responseType, headers, errorHandler); + return super.execute(request, responseType, errorHandler, responseHeaders); } }; @@ -337,27 +338,20 @@ public void testCatalogBasicBearerToken() { // the bearer token should be used for all interactions Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", catalogHeaders), any(), any(), - any(), - eq(catalogHeaders), any()); } @Test public void testCatalogCredentialNoOauth2ServerUri() { - Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -373,39 +367,29 @@ public void testCatalogCredentialNoOauth2ServerUri() { // no token or credential for catalog token exchange Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq("v1/oauth/tokens"), - any(), - any(), + reqMatcher(HTTPMethod.POST, "v1/oauth/tokens", Map.of()), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // no token or credential for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // use the catalog token for all interactions Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", catalogHeaders), any(), any(), - any(), - eq(catalogHeaders), any()); } @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogCredential(String oauth2ServerUri) { - Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -428,32 +412,23 @@ public void testCatalogCredential(String oauth2ServerUri) { // no token or credential for catalog token exchange Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, Map.of()), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // no token or credential for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // use the catalog token for all interactions Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), - any(), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", catalogHeaders), any(), any(), - eq(catalogHeaders), any()); } @@ -489,39 +464,29 @@ public void testCatalogBearerTokenWithClientCredential(String oauth2ServerUri) { // use the bearer token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // use the bearer token to fetch the context token Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, catalogHeaders), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // use the context token for table existence check Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), - any(), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", contextHeaders), any(), any(), - eq(contextHeaders), any()); } @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogCredentialWithClientCredential(String oauth2ServerUri) { - Map emptyHeaders = ImmutableMap.of(); Map contextHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=user"); Map catalogHeaders = @@ -552,42 +517,30 @@ public void testCatalogCredentialWithClientCredential(String oauth2ServerUri) { // call client credentials with no initial auth Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, Map.of()), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // use the client credential to fetch the context token Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, catalogHeaders), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // use the context token for table existence check Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", contextHeaders), any(), any(), - any(), - eq(contextHeaders), any()); } @@ -627,42 +580,30 @@ public void testCatalogBearerTokenAndCredentialWithClientCredential(String oauth // use the bearer token for client credentials Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, initHeaders), eq(OAuthTokenResponse.class), - eq(initHeaders), + any(), any()); // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // use the client credential to fetch the context token Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, catalogHeaders), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // use the context token for table existence check Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), - any(), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", contextHeaders), any(), any(), - eq(contextHeaders), any()); } @@ -822,12 +763,9 @@ private void testClientAuth( Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // token passes a static token. otherwise, validate a client credentials or token exchange @@ -835,34 +773,22 @@ private void testClientAuth( if (!credentials.containsKey("token")) { Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, catalogHeaders), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); } Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", expectedHeaders), any(), any(), - any(), - eq(expectedHeaders), any()); if (!optionalOAuthParams.isEmpty()) { Mockito.verify(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat( - body -> - ((Map) body) - .keySet() - .containsAll(optionalOAuthParams.keySet())), + Mockito.argThat(body -> body.keySet().containsAll(optionalOAuthParams.keySet())), eq(OAuthTokenResponse.class), eq(catalogHeaders), any()); @@ -984,10 +910,7 @@ public void testTableSnapshotLoading() { Mockito.doAnswer(refsAnswer) .when(adapter) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "refs")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "refs")), eq(LoadTableResponse.class), any(), any()); @@ -998,10 +921,7 @@ public void testTableSnapshotLoading() { // verify that the table was loaded with the refs argument verify(adapter, times(1)) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "refs")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "refs")), eq(LoadTableResponse.class), any(), any()); @@ -1010,10 +930,7 @@ public void testTableSnapshotLoading() { assertThat(refsTables.snapshots()).containsExactlyInAnyOrderElementsOf(table.snapshots()); verify(adapter, times(1)) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "all")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "all")), eq(LoadTableResponse.class), any(), any()); @@ -1110,10 +1027,7 @@ public void testTableSnapshotLoadingWithDivergedBranches(String formatVersion) { Mockito.doAnswer(refsAnswer) .when(adapter) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "refs")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "refs")), eq(LoadTableResponse.class), any(), any()); @@ -1124,10 +1038,7 @@ public void testTableSnapshotLoadingWithDivergedBranches(String formatVersion) { // verify that the table was loaded with the refs argument verify(adapter, times(1)) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "refs")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "refs")), eq(LoadTableResponse.class), any(), any()); @@ -1137,10 +1048,7 @@ public void testTableSnapshotLoadingWithDivergedBranches(String formatVersion) { .containsExactlyInAnyOrderElementsOf(table.snapshots()); verify(adapter, times(1)) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "all")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "all")), eq(LoadTableResponse.class), any(), any()); @@ -1226,10 +1134,7 @@ public void lazySnapshotLoadingWithDivergedHistory() { Mockito.doAnswer(refsAnswer) .when(adapter) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "refs")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "refs")), eq(LoadTableResponse.class), any(), any()); @@ -1267,23 +1172,17 @@ public void testTableAuth( Mockito.doAnswer(addTableConfig) .when(adapter) .execute( - eq(HTTPMethod.POST), - eq("v1/namespaces/ns/tables"), - any(), - any(), + reqMatcher(HTTPMethod.POST, "v1/namespaces/ns/tables", expectedContextHeaders), eq(LoadTableResponse.class), - eq(expectedContextHeaders), + any(), any()); Mockito.doAnswer(addTableConfig) .when(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/namespaces/ns/tables/table"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/namespaces/ns/tables/table", expectedContextHeaders), eq(LoadTableResponse.class), - eq(expectedContextHeaders), + any(), any()); SessionCatalog.SessionContext context = @@ -1324,33 +1223,24 @@ public void testTableAuth( Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // session client credentials flow Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, catalogHeaders), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // create table request Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq("v1/namespaces/ns/tables"), - any(), - any(), + reqMatcher(HTTPMethod.POST, "v1/namespaces/ns/tables", expectedContextHeaders), eq(LoadTableResponse.class), - eq(expectedContextHeaders), + any(), any()); // if the table returned a bearer token or a credential, there will be no token request @@ -1358,12 +1248,9 @@ public void testTableAuth( // token exchange to get a table token Mockito.verify(adapter, times(1)) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, expectedContextHeaders), eq(OAuthTokenResponse.class), - eq(expectedContextHeaders), + any(), any()); } @@ -1371,34 +1258,25 @@ public void testTableAuth( // load table from catalog + refresh loaded table Mockito.verify(adapter, times(2)) .execute( - eq(HTTPMethod.GET), - eq("v1/namespaces/ns/tables/table"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/namespaces/ns/tables/table", expectedTableHeaders), eq(LoadTableResponse.class), - eq(expectedTableHeaders), + any(), any()); } else { // load table from catalog Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/namespaces/ns/tables/table"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/namespaces/ns/tables/table", expectedContextHeaders), eq(LoadTableResponse.class), - eq(expectedContextHeaders), + any(), any()); // refresh loaded table Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/namespaces/ns/tables/table"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/namespaces/ns/tables/table", expectedTableHeaders), eq(LoadTableResponse.class), - eq(expectedTableHeaders), + any(), any()); } } @@ -1406,7 +1284,6 @@ public void testTableAuth( @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogTokenRefresh(String oauth2ServerUri) { - Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -1426,14 +1303,7 @@ public void testCatalogTokenRefresh(String oauth2ServerUri) { Mockito.doAnswer(addOneSecondExpiration) .when(adapter) - .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), - eq(OAuthTokenResponse.class), - any(), - any()); + .postForm(eq(oauth2ServerUri), anyMap(), eq(OAuthTokenResponse.class), anyMap(), any()); Map contextCredentials = ImmutableMap.of(); SessionCatalog.SessionContext context = @@ -1458,23 +1328,17 @@ public void testCatalogTokenRefresh(String oauth2ServerUri) { // call client credentials with no initial auth Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, Map.of()), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // verify the first token exchange @@ -1486,12 +1350,14 @@ public void testCatalogTokenRefresh(String oauth2ServerUri) { "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + catalogHeaders, + Map.of(), + firstRefreshRequest), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // verify that a second exchange occurs @@ -1508,12 +1374,14 @@ public void testCatalogTokenRefresh(String oauth2ServerUri) { "Bearer token-exchange-token:sub=client-credentials-token:sub=catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - Mockito.argThat(secondRefreshRequest::equals), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + secondRefreshHeaders, + Map.of(), + secondRefreshRequest), eq(OAuthTokenResponse.class), - eq(secondRefreshHeaders), + any(), any()); }); } @@ -1521,7 +1389,6 @@ public void testCatalogTokenRefresh(String oauth2ServerUri) { @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogRefreshedTokenIsUsed(String oauth2ServerUri) { - Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -1541,14 +1408,7 @@ public void testCatalogRefreshedTokenIsUsed(String oauth2ServerUri) { Mockito.doAnswer(addOneSecondExpiration) .when(adapter) - .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), - eq(OAuthTokenResponse.class), - any(), - any()); + .postForm(eq(oauth2ServerUri), anyMap(), eq(OAuthTokenResponse.class), anyMap(), any()); Map contextCredentials = ImmutableMap.of(); SessionCatalog.SessionContext context = @@ -1574,25 +1434,30 @@ public void testCatalogRefreshedTokenIsUsed(String oauth2ServerUri) { assertThat(catalog.tableExists(TableIdentifier.of("ns", "table"))).isFalse(); // call client credentials with no initial auth + Map clientCredentialsRequest = + ImmutableMap.of( + "grant_type", "client_credentials", + "client_id", "catalog", + "client_secret", "secret", + "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + Map.of(), + Map.of(), + clientCredentialsRequest), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // verify the first token exchange @@ -1604,12 +1469,14 @@ public void testCatalogRefreshedTokenIsUsed(String oauth2ServerUri) { "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + catalogHeaders, + Map.of(), + firstRefreshRequest), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // use the refreshed context token for table existence check @@ -1619,12 +1486,10 @@ public void testCatalogRefreshedTokenIsUsed(String oauth2ServerUri) { "Bearer token-exchange-token:sub=client-credentials-token:sub=catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher( + HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", refreshedCatalogHeader), any(), any(), - any(), - eq(refreshedCatalogHeader), any()); }); } @@ -1653,7 +1518,6 @@ public void testCatalogExpiredBearerTokenIsRefreshedWithCredential(String oauth2 // expires at epoch second = 1 String token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJleHAiOjF9.gQADTbdEv-rpDWKSkGLbmafyB5UUjTdm9B_1izpuZ6E"; - Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -1680,24 +1544,25 @@ public void testCatalogExpiredBearerTokenIsRefreshedWithCredential(String oauth2 assertThat(catalog.tableExists(TableIdentifier.of("ns", "table"))).isFalse(); // call client credentials with no initial auth + Map clientCredentialsRequest = + ImmutableMap.of( + "grant_type", "client_credentials", + "client_id", "catalog", + "client_secret", "12345", + "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher( + HTTPMethod.POST, oauth2ServerUri, Map.of(), Map.of(), clientCredentialsRequest), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); Map firstRefreshRequest = @@ -1708,12 +1573,14 @@ public void testCatalogExpiredBearerTokenIsRefreshedWithCredential(String oauth2 "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + OAuth2Util.basicAuthHeaders(credential), + Map.of(), + firstRefreshRequest), eq(OAuthTokenResponse.class), - eq(OAuth2Util.basicAuthHeaders(credential)), + any(), any()); // verify that a second exchange occurs @@ -1725,22 +1592,24 @@ public void testCatalogExpiredBearerTokenIsRefreshedWithCredential(String oauth2 "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - Mockito.argThat(secondRefreshRequest::equals), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + OAuth2Util.basicAuthHeaders(credential), + Map.of(), + secondRefreshRequest), eq(OAuthTokenResponse.class), - eq(OAuth2Util.basicAuthHeaders(credential)), + any(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher( + HTTPMethod.HEAD, + "v1/namespaces/ns/tables/table", + Map.of("Authorization", "Bearer token-exchange-token:sub=" + token)), any(), any(), - any(), - eq(ImmutableMap.of("Authorization", "Bearer token-exchange-token:sub=" + token)), any()); } @@ -1767,29 +1636,23 @@ public void testCatalogValidBearerTokenIsNotRefreshed() { Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), - any(), + reqMatcher( + HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", OAuth2Util.authHeaders(token)), any(), any(), - eq(OAuth2Util.authHeaders(token)), any()); } @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth2ServerUri) { - Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -1812,14 +1675,7 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth Mockito.doAnswer(addOneSecondExpiration) .when(adapter) - .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), - eq(OAuthTokenResponse.class), - any(), - any()); + .postForm(eq(oauth2ServerUri), anyMap(), eq(OAuthTokenResponse.class), anyMap(), any()); Map firstRefreshRequest = ImmutableMap.of( @@ -1831,11 +1687,9 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth // simulate that the token expired when it was about to be refreshed Mockito.doThrow(new RuntimeException("token expired")) .when(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + argThat(firstRefreshRequest::equals), eq(OAuthTokenResponse.class), eq(catalogHeaders), any()); @@ -1867,47 +1721,47 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth assertThat(catalog.tableExists(TableIdentifier.of("ns", "table"))).isFalse(); // call client credentials with no initial auth + Map clientCredentialsRequest = + ImmutableMap.of( + "grant_type", "client_credentials", + "client_id", "catalog", + "client_secret", "secret", + "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + Map.of(), + Map.of(), + clientCredentialsRequest), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // verify the first token exchange - since an exception is thrown, we're performing // retries Mockito.verify(adapter, times(2)) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + argThat(firstRefreshRequest::equals), eq(OAuthTokenResponse.class), eq(catalogHeaders), any()); - // here we make sure that the basic auth header is used after token refresh retries // failed Mockito.verify(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + argThat(firstRefreshRequest::equals), eq(OAuthTokenResponse.class), eq(basicHeaders), any()); @@ -1919,12 +1773,10 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth "Bearer token-exchange-token:sub=client-credentials-token:sub=catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher( + HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", refreshedCatalogHeader), any(), any(), - any(), - eq(refreshedCatalogHeader), any()); }); } @@ -1932,7 +1784,6 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogWithCustomTokenScope(String oauth2ServerUri) { - Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -1952,14 +1803,7 @@ public void testCatalogWithCustomTokenScope(String oauth2ServerUri) { Mockito.doAnswer(addOneSecondExpiration) .when(adapter) - .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), - eq(OAuthTokenResponse.class), - any(), - any()); + .postForm(eq(oauth2ServerUri), anyMap(), eq(OAuthTokenResponse.class), anyMap(), any()); Map contextCredentials = ImmutableMap.of(); SessionCatalog.SessionContext context = @@ -1986,24 +1830,19 @@ public void testCatalogWithCustomTokenScope(String oauth2ServerUri) { () -> { // call client credentials with no initial auth Mockito.verify(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - any(), + anyMap(), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + eq(Map.of()), any()); // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // verify the token exchange uses the right scope @@ -2014,11 +1853,9 @@ public void testCatalogWithCustomTokenScope(String oauth2ServerUri) { "subject_token_type", "urn:ietf:params:oauth:token-type:access_token", "scope", scope); Mockito.verify(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + argThat(firstRefreshRequest::equals), eq(OAuthTokenResponse.class), eq(catalogHeaders), any()); @@ -2047,14 +1884,7 @@ public void testCatalogTokenRefreshDisabledWithToken(String oauth2ServerUri) { Mockito.doAnswer(addOneSecondExpiration) .when(adapter) - .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), - eq(OAuthTokenResponse.class), - any(), - any()); + .postForm(eq(oauth2ServerUri), anyMap(), eq(OAuthTokenResponse.class), anyMap(), any()); Map contextCredentials = ImmutableMap.of(); SessionCatalog.SessionContext context = @@ -2076,12 +1906,9 @@ public void testCatalogTokenRefreshDisabledWithToken(String oauth2ServerUri) { Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); } @@ -2122,23 +1949,18 @@ public void testCatalogTokenRefreshDisabledWithCredential(String oauth2ServerUri "scope", "catalog"); Mockito.verify(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat(fetchTokenFromCredential::equals), + argThat(fetchTokenFromCredential::equals), eq(OAuthTokenResponse.class), - eq(ImmutableMap.of()), + anyMap(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); } @@ -2311,20 +2133,14 @@ public void testPaginationForListNamespaces(int numberOfItems) { Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", Map.of(), Map.of()), eq(ConfigResponse.class), any(), any()); Mockito.verify(adapter, times(numberOfItems)) .execute( - eq(HTTPMethod.POST), - eq("v1/namespaces"), - any(), - any(), + reqMatcher(HTTPMethod.POST, "v1/namespaces", Map.of(), Map.of()), eq(CreateNamespaceResponse.class), any(), any()); @@ -2375,20 +2191,18 @@ public void testPaginationForListTables(int numberOfItems) { Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", Map.of(), Map.of()), eq(ConfigResponse.class), any(), any()); Mockito.verify(adapter, times(numberOfItems)) .execute( - eq(HTTPMethod.POST), - eq(String.format("v1/namespaces/%s/tables", namespaceName)), - any(), - any(), + reqMatcher( + HTTPMethod.POST, + String.format("v1/namespaces/%s/tables", namespaceName), + Map.of(), + Map.of()), eq(LoadTableResponse.class), any(), any()); @@ -2436,21 +2250,27 @@ public void testCleanupUncommitedFilesForCleanableFailures() { .build(); Table table = catalog.loadTable(TABLE); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); Mockito.doThrow(new NotAuthorizedException("not authorized")) .when(adapter) - .post(any(), any(), any(), any(Map.class), any()); + .execute(reqMatcher(HTTPMethod.POST), any(), any(), any()); assertThatThrownBy(() -> catalog.loadTable(TABLE).newFastAppend().appendFile(file).commit()) .isInstanceOf(NotAuthorizedException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(TABLE)), captor.capture(), any(), any(Map.class), any()); // Extract the UpdateTableRequest to determine the path of the manifest list that should be // cleaned up - UpdateTableRequest request = captor.getValue(); - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) request.updates().get(0); - assertThatThrownBy(() -> table.io().newInputFile(addSnapshot.snapshot().manifestListLocation())) - .isInstanceOf(NotFoundException.class); + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(TABLE)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest body = (UpdateTableRequest) req.body(); + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) body.updates().get(0); + assertThatThrownBy( + () -> table.io().newInputFile(addSnapshot.snapshot().manifestListLocation())) + .isInstanceOf(NotFoundException.class); + }); } @Test @@ -2465,21 +2285,26 @@ public void testNoCleanupForNonCleanableExceptions() { catalog.createTable(TABLE, SCHEMA); Table table = catalog.loadTable(TABLE); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); Mockito.doThrow(new ServiceFailureException("some service failure")) .when(adapter) - .post(any(), any(), any(), any(Map.class), any()); + .execute(reqMatcher(HTTPMethod.POST), any(), any(), any()); assertThatThrownBy(() -> catalog.loadTable(TABLE).newFastAppend().appendFile(FILE_A).commit()) .isInstanceOf(ServiceFailureException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(TABLE)), captor.capture(), any(), any(Map.class), any()); // Extract the UpdateTableRequest to determine the path of the manifest list that should still // exist even though the commit failed - UpdateTableRequest request = captor.getValue(); - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) request.updates().get(0); - assertThat(table.io().newInputFile(addSnapshot.snapshot().manifestListLocation()).exists()) - .isTrue(); + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(TABLE)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest body = (UpdateTableRequest) req.body(); + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) body.updates().get(0); + String manifestListLocation = addSnapshot.snapshot().manifestListLocation(); + assertThat(table.io().newInputFile(manifestListLocation).exists()).isTrue(); + }); } @Test @@ -2493,32 +2318,38 @@ public void testCleanupCleanableExceptionsCreate() { catalog.createTable(TABLE, SCHEMA); TableIdentifier newTable = TableIdentifier.of(TABLE.namespace(), "some_table"); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); Mockito.doThrow(new NotAuthorizedException("not authorized")) .when(adapter) - .post(eq(RESOURCE_PATHS.table(newTable)), any(), any(), any(Map.class), any()); + .execute(reqMatcher(HTTPMethod.POST, RESOURCE_PATHS.table(newTable)), any(), any(), any()); Transaction createTableTransaction = catalog.newCreateTableTransaction(newTable, SCHEMA); createTableTransaction.newAppend().appendFile(FILE_A).commit(); assertThatThrownBy(createTableTransaction::commitTransaction) .isInstanceOf(NotAuthorizedException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(newTable)), captor.capture(), any(), any(Map.class), any()); - UpdateTableRequest request = captor.getValue(); - Optional appendSnapshot = - request.updates().stream() - .filter(update -> update instanceof MetadataUpdate.AddSnapshot) - .findFirst(); - - assertThat(appendSnapshot).isPresent(); - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) appendSnapshot.get(); - assertThatThrownBy( - () -> - catalog - .loadTable(TABLE) - .io() - .newInputFile(addSnapshot.snapshot().manifestListLocation())) - .isInstanceOf(NotFoundException.class); + + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(newTable)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest body = (UpdateTableRequest) req.body(); + Optional appendSnapshot = + body.updates().stream() + .filter(update -> update instanceof MetadataUpdate.AddSnapshot) + .findFirst(); + + assertThat(appendSnapshot).isPresent(); + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) appendSnapshot.get(); + assertThatThrownBy( + () -> + catalog + .loadTable(TABLE) + .io() + .newInputFile(addSnapshot.snapshot().manifestListLocation())) + .isInstanceOf(NotFoundException.class); + }); } @Test @@ -2534,29 +2365,32 @@ public void testNoCleanupForNonCleanableCreateTransaction() { TableIdentifier newTable = TableIdentifier.of(TABLE.namespace(), "some_table"); Mockito.doThrow(new ServiceFailureException("some service failure")) .when(adapter) - .post(eq(RESOURCE_PATHS.table(newTable)), any(), any(), any(Map.class), any()); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); + .execute(reqMatcher(HTTPMethod.POST, RESOURCE_PATHS.table(newTable)), any(), any(), any()); + Transaction createTableTransaction = catalog.newCreateTableTransaction(newTable, SCHEMA); createTableTransaction.newAppend().appendFile(FILE_A).commit(); assertThatThrownBy(createTableTransaction::commitTransaction) .isInstanceOf(ServiceFailureException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(newTable)), captor.capture(), any(), any(Map.class), any()); - UpdateTableRequest request = captor.getValue(); - Optional appendSnapshot = - request.updates().stream() - .filter(update -> update instanceof MetadataUpdate.AddSnapshot) - .findFirst(); - assertThat(appendSnapshot).isPresent(); - - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) appendSnapshot.get(); - assertThat( - catalog - .loadTable(TABLE) - .io() - .newInputFile(addSnapshot.snapshot().manifestListLocation()) - .exists()) - .isTrue(); + + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(newTable)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest body = (UpdateTableRequest) req.body(); + Optional appendSnapshot = + body.updates().stream() + .filter(update -> update instanceof MetadataUpdate.AddSnapshot) + .findFirst(); + assertThat(appendSnapshot).isPresent(); + + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) appendSnapshot.get(); + String manifestListLocation = addSnapshot.snapshot().manifestListLocation(); + assertThat(catalog.loadTable(TABLE).io().newInputFile(manifestListLocation).exists()) + .isTrue(); + }); } @Test @@ -2569,32 +2403,35 @@ public void testCleanupCleanableExceptionsReplace() { } catalog.createTable(TABLE, SCHEMA); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); Mockito.doThrow(new NotAuthorizedException("not authorized")) .when(adapter) - .post(eq(RESOURCE_PATHS.table(TABLE)), any(), any(), any(Map.class), any()); + .execute(reqMatcher(HTTPMethod.POST, RESOURCE_PATHS.table(TABLE)), any(), any(), any()); Transaction replaceTableTransaction = catalog.newReplaceTableTransaction(TABLE, SCHEMA, false); replaceTableTransaction.newAppend().appendFile(FILE_A).commit(); assertThatThrownBy(replaceTableTransaction::commitTransaction) .isInstanceOf(NotAuthorizedException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(TABLE)), captor.capture(), any(), any(Map.class), any()); - UpdateTableRequest request = captor.getValue(); - Optional appendSnapshot = - request.updates().stream() - .filter(update -> update instanceof MetadataUpdate.AddSnapshot) - .findFirst(); - - assertThat(appendSnapshot).isPresent(); - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) appendSnapshot.get(); - assertThatThrownBy( - () -> - catalog - .loadTable(TABLE) - .io() - .newInputFile(addSnapshot.snapshot().manifestListLocation())) - .isInstanceOf(NotFoundException.class); + + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(TABLE)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest request = (UpdateTableRequest) req.body(); + Optional appendSnapshot = + request.updates().stream() + .filter(update -> update instanceof MetadataUpdate.AddSnapshot) + .findFirst(); + + assertThat(appendSnapshot).isPresent(); + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) appendSnapshot.get(); + String manifestListLocation = addSnapshot.snapshot().manifestListLocation(); + assertThatThrownBy( + () -> catalog.loadTable(TABLE).io().newInputFile(manifestListLocation)) + .isInstanceOf(NotFoundException.class); + }); } @Test @@ -2609,29 +2446,32 @@ public void testNoCleanupForNonCleanableReplaceTransaction() { catalog.createTable(TABLE, SCHEMA); Mockito.doThrow(new ServiceFailureException("some service failure")) .when(adapter) - .post(eq(RESOURCE_PATHS.table(TABLE)), any(), any(), any(Map.class), any()); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); + .execute(reqMatcher(HTTPMethod.POST, RESOURCE_PATHS.table(TABLE)), any(), any(), any()); + Transaction replaceTableTransaction = catalog.newReplaceTableTransaction(TABLE, SCHEMA, false); replaceTableTransaction.newAppend().appendFile(FILE_A).commit(); assertThatThrownBy(replaceTableTransaction::commitTransaction) .isInstanceOf(ServiceFailureException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(TABLE)), captor.capture(), any(), any(Map.class), any()); - UpdateTableRequest request = captor.getValue(); - Optional appendSnapshot = - request.updates().stream() - .filter(update -> update instanceof MetadataUpdate.AddSnapshot) - .findFirst(); - assertThat(appendSnapshot).isPresent(); - - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) appendSnapshot.get(); - assertThat( - catalog - .loadTable(TABLE) - .io() - .newInputFile(addSnapshot.snapshot().manifestListLocation()) - .exists()) - .isTrue(); + + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(TABLE)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest request = (UpdateTableRequest) req.body(); + Optional appendSnapshot = + request.updates().stream() + .filter(update -> update instanceof MetadataUpdate.AddSnapshot) + .findFirst(); + assertThat(appendSnapshot).isPresent(); + + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) appendSnapshot.get(); + String manifestListLocation = addSnapshot.snapshot().manifestListLocation(); + assertThat(catalog.loadTable(TABLE).io().newInputFile(manifestListLocation).exists()) + .isTrue(); + }); } private RESTCatalog catalog(RESTCatalogAdapter adapter) { @@ -2643,4 +2483,56 @@ private RESTCatalog catalog(RESTCatalogAdapter adapter) { CatalogProperties.FILE_IO_IMPL, "org.apache.iceberg.inmemory.InMemoryFileIO")); return catalog; } + + static HTTPRequest reqMatcher(HTTPMethod method) { + return argThat(req -> req.method() == method); + } + + static HTTPRequest reqMatcher(HTTPMethod method, String path) { + return argThat(req -> req.method() == method && req.path().equals(path)); + } + + static HTTPRequest reqMatcher(HTTPMethod method, String path, Map headers) { + return argThat( + req -> + req.method() == method + && req.path().equals(path) + && req.headers().equals(toMultiMap(headers))); + } + + static HTTPRequest reqMatcher( + HTTPMethod method, String path, Map headers, Map parameters) { + return argThat( + req -> + req.method() == method + && req.path().equals(path) + && req.headers().equals(toMultiMap(headers)) + && req.queryParameters().equals(parameters)); + } + + static HTTPRequest reqMatcher( + HTTPMethod method, + String path, + Map headers, + Map parameters, + Object body) { + return argThat( + req -> + req.method() == method + && req.path().equals(path) + && req.headers().equals(toMultiMap(headers)) + && req.queryParameters().equals(parameters) + && Objects.equals(req.body(), body)); + } + + private static List allRequests(RESTCatalogAdapter adapter) { + ArgumentCaptor captor = ArgumentCaptor.forClass(HTTPRequest.class); + verify(adapter, atLeastOnce()).execute(captor.capture(), any(), any(), any()); + return captor.getAllValues(); + } + + private static Map> toMultiMap(Map map) { + return map.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> List.of(e.getValue()))); + } } diff --git a/core/src/test/java/org/apache/iceberg/rest/TestRESTUtil.java b/core/src/test/java/org/apache/iceberg/rest/TestRESTUtil.java index c7667d90ac6f..1ca4aabb822c 100644 --- a/core/src/test/java/org/apache/iceberg/rest/TestRESTUtil.java +++ b/core/src/test/java/org/apache/iceberg/rest/TestRESTUtil.java @@ -20,11 +20,20 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import java.net.URI; import java.util.Map; +import java.util.stream.Stream; import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.exceptions.RESTException; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; +import org.apache.iceberg.rest.requests.CreateNamespaceRequest; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; public class TestRESTUtil { @@ -139,4 +148,102 @@ public void testOAuth2FormDataDecoding() { assertThat(RESTUtil.decodeFormData(formString)).isEqualTo(expected); } + + @ParameterizedTest + @MethodSource("validRequestUris") + public void validRequestUris(HTTPRequest request, URI expected) { + assertThat(RESTUtil.buildRequestUri(request)).isEqualTo(expected); + } + + public static Stream validRequestUris() { + return Stream.of( + Arguments.of( + ImmutableHTTPRequest.builder() + .baseUri(URI.create("http://localhost:8080/foo")) + .method(HTTPMethod.GET) + .path("v1/namespaces/ns/tables/") // trailing slash should be removed + .putQueryParameter("pageToken", "1234") + .putQueryParameter("pageSize", "10") + .build(), + URI.create( + "http://localhost:8080/foo/v1/namespaces/ns/tables?pageToken=1234&pageSize=10")), + Arguments.of( + ImmutableHTTPRequest.builder() + .baseUri(URI.create("http://localhost:8080/foo")) + .method(HTTPMethod.GET) + .path("https://authserver.com/token") // absolute path HTTPS + .build(), + URI.create("https://authserver.com/token")), + Arguments.of( + ImmutableHTTPRequest.builder() + .baseUri(URI.create("http://localhost:8080/foo")) + .method(HTTPMethod.GET) + .path("http://authserver.com/token") // absolute path HTTP + .build(), + URI.create("http://authserver.com/token"))); + } + + @Test + public void buildRequestUriFailures() { + HTTPRequest request = + ImmutableHTTPRequest.builder() + .baseUri(URI.create("http://localhost")) + .method(HTTPMethod.GET) + .path("/v1/namespaces") // wrong leading slash + .build(); + assertThatThrownBy(() -> RESTUtil.buildRequestUri(request)) + .isInstanceOf(RESTException.class) + .hasMessage( + "Received a malformed path for a REST request: /v1/namespaces. Paths should not start with /"); + HTTPRequest request2 = + ImmutableHTTPRequest.builder() + .baseUri(URI.create("http://localhost")) + .method(HTTPMethod.GET) + .path(" not a valid path") // wrong path + .build(); + assertThatThrownBy(() -> RESTUtil.buildRequestUri(request2)) + .isInstanceOf(RESTException.class) + .hasMessage( + "Failed to create request URI from base http://localhost/ not a valid path, params {}"); + } + + @ParameterizedTest + @MethodSource("encodeRequestBody") + public void encodeRequestBody(HTTPRequest request, String expected) { + assertThat(RESTUtil.encodeRequestBody(request)).isEqualTo(expected); + } + + public static Stream encodeRequestBody() { + return Stream.of( + // form data + Arguments.of( + ImmutableHTTPRequest.builder() + .baseUri(URI.create("http://localhost")) + .method(HTTPMethod.POST) + .path("token") + .body( + ImmutableMap.of( + "grant_type", "urn:ietf:params:oauth:grant-type:token-exchange", + "subject_token", "token", + "subject_token_type", "urn:ietf:params:oauth:token-type:access_token", + "scope", "catalog")) + .build(), + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&" + + "subject_token=token&" + + "subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&" + + "scope=catalog"), + // JSON + Arguments.of( + ImmutableHTTPRequest.builder() + .baseUri(URI.create("http://localhost")) + .method(HTTPMethod.POST) + .path("v1/namespaces/ns") // trailing slash should be removed + .body( + CreateNamespaceRequest.builder() + .withNamespace(Namespace.of("ns")) + .setProperties(ImmutableMap.of("prop1", "value1")) + .build()) + .build(), + "{\"namespace\":[\"ns\"],\"properties\":{\"prop1\":\"value1\"}}")); + } } diff --git a/core/src/test/java/org/apache/iceberg/rest/TestRESTViewCatalog.java b/core/src/test/java/org/apache/iceberg/rest/TestRESTViewCatalog.java index 85ccdc8f5ddd..81061ebb5a8f 100644 --- a/core/src/test/java/org/apache/iceberg/rest/TestRESTViewCatalog.java +++ b/core/src/test/java/org/apache/iceberg/rest/TestRESTViewCatalog.java @@ -18,6 +18,7 @@ */ package org.apache.iceberg.rest; +import static org.apache.iceberg.rest.TestRESTCatalog.reqMatcher; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -38,7 +39,7 @@ import org.apache.iceberg.catalog.TableIdentifier; import org.apache.iceberg.inmemory.InMemoryCatalog; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; -import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; import org.apache.iceberg.rest.responses.ConfigResponse; import org.apache.iceberg.rest.responses.ErrorResponse; import org.apache.iceberg.rest.responses.ListTablesResponse; @@ -77,17 +78,13 @@ public void createCatalog() throws Exception { new RESTCatalogAdapter(backendCatalog) { @Override public T execute( - HTTPMethod method, - String path, - Map queryParams, - Object body, + HTTPRequest request, Class responseType, - Map headers, - Consumer errorHandler) { - Object request = roundTripSerialize(body, "request"); - T response = - super.execute( - method, path, queryParams, request, responseType, headers, errorHandler); + Consumer errorHandler, + Consumer> responseHeaders) { + Object body = roundTripSerialize(request.body(), "request"); + HTTPRequest req = ImmutableHTTPRequest.builder().from(request).body(body).build(); + T response = super.execute(req, responseType, errorHandler, responseHeaders); T responseAfterSerialization = roundTripSerialize(response, "response"); return responseAfterSerialization; } @@ -182,21 +179,11 @@ public void testPaginationForListViews(int numberOfItems) { assertThat(views).hasSize(numberOfItems); Mockito.verify(adapter) - .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), - eq(ConfigResponse.class), - any(), - any()); + .execute(reqMatcher(HTTPMethod.GET, "v1/config"), eq(ConfigResponse.class), any(), any()); Mockito.verify(adapter, times(numberOfItems)) .execute( - eq(HTTPMethod.POST), - eq(String.format("v1/namespaces/%s/views", namespaceName)), - any(), - any(), + reqMatcher(HTTPMethod.POST, String.format("v1/namespaces/%s/views", namespaceName)), eq(LoadViewResponse.class), any(), any()); diff --git a/core/src/test/java/org/apache/iceberg/rest/auth/TestAuthManagers.java b/core/src/test/java/org/apache/iceberg/rest/auth/TestAuthManagers.java new file mode 100644 index 000000000000..6a171c7b3667 --- /dev/null +++ b/core/src/test/java/org/apache/iceberg/rest/auth/TestAuthManagers.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.util.Map; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class TestAuthManagers { + + private final PrintStream standardErr = System.err; + private final ByteArrayOutputStream streamCaptor = new ByteArrayOutputStream(); + + @BeforeEach + public void before() { + System.setErr(new PrintStream(streamCaptor)); + } + + @AfterEach + public void after() { + System.setErr(standardErr); + } + + @Test + void oauth2Explicit() { + try (AuthManager manager = + AuthManagers.loadAuthManager( + "test", Map.of(AuthProperties.AUTH_TYPE, AuthProperties.AUTH_TYPE_OAUTH2))) { + assertThat(manager).isInstanceOf(OAuth2Manager.class); + } + assertThat(streamCaptor.toString()) + .contains("Loading AuthManager implementation: org.apache.iceberg.rest.auth.OAuth2Manager"); + } + + @Test + void oauth2InferredFromToken() { + try (AuthManager manager = + AuthManagers.loadAuthManager("test", Map.of(OAuth2Properties.TOKEN, "irrelevant"))) { + assertThat(manager).isInstanceOf(OAuth2Manager.class); + } + assertThat(streamCaptor.toString()) + .contains( + "Inferring rest.auth.type=oauth2 since property token was provided. " + + "Please explicitly set rest.auth.type to avoid this warning."); + assertThat(streamCaptor.toString()) + .contains("Loading AuthManager implementation: org.apache.iceberg.rest.auth.OAuth2Manager"); + } + + @Test + void oauth2InferredFromCredential() { + try (AuthManager manager = + AuthManagers.loadAuthManager("test", Map.of(OAuth2Properties.CREDENTIAL, "irrelevant"))) { + assertThat(manager).isInstanceOf(OAuth2Manager.class); + } + assertThat(streamCaptor.toString()) + .contains( + "Inferring rest.auth.type=oauth2 since property credential was provided. " + + "Please explicitly set rest.auth.type to avoid this warning."); + assertThat(streamCaptor.toString()) + .contains("Loading AuthManager implementation: org.apache.iceberg.rest.auth.OAuth2Manager"); + } + + @Test + @SuppressWarnings("resource") + void sigV4Explicit() { + assertThatThrownBy( + () -> + AuthManagers.loadAuthManager( + "test", Map.of(AuthProperties.AUTH_TYPE, AuthProperties.AUTH_TYPE_SIGV4))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining( + "Cannot initialize AuthManager implementation org.apache.iceberg.aws.RESTSigV4AuthManager"); + assertThat(streamCaptor.toString()) + .contains( + "Loading AuthManager implementation: org.apache.iceberg.aws.RESTSigV4AuthManager"); + } + + @Test + @SuppressWarnings("resource") + void sigV4Legacy() { + assertThatThrownBy( + () -> AuthManagers.loadAuthManager("test", Map.of("rest.sigv4-enabled", "true"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining( + "Cannot initialize AuthManager implementation org.apache.iceberg.aws.RESTSigV4AuthManager"); + assertThat(streamCaptor.toString()) + .contains( + "The property rest.sigv4-enabled is deprecated and will be removed in a future release. " + + "Please use the property rest.auth.type=sigv4 instead."); + assertThat(streamCaptor.toString()) + .contains( + "Loading AuthManager implementation: org.apache.iceberg.aws.RESTSigV4AuthManager"); + } + + @Test + void noop() { + try (AuthManager manager = AuthManagers.loadAuthManager("test", Map.of())) { + assertThat(manager).isInstanceOf(NoopAuthManager.class); + } + assertThat(streamCaptor.toString()) + .contains( + "Loading AuthManager implementation: org.apache.iceberg.rest.auth.NoopAuthManager"); + } + + @Test + void noopExplicit() { + try (AuthManager manager = + AuthManagers.loadAuthManager( + "test", Map.of(AuthProperties.AUTH_TYPE, AuthProperties.AUTH_TYPE_NONE))) { + assertThat(manager).isInstanceOf(NoopAuthManager.class); + } + assertThat(streamCaptor.toString()) + .contains( + "Loading AuthManager implementation: org.apache.iceberg.rest.auth.NoopAuthManager"); + } + + @Test + void basicExplicit() { + try (AuthManager manager = + AuthManagers.loadAuthManager( + "test", Map.of(AuthProperties.AUTH_TYPE, AuthProperties.AUTH_TYPE_BASIC))) { + assertThat(manager).isInstanceOf(BasicAuthManager.class); + } + assertThat(streamCaptor.toString()) + .contains( + "Loading AuthManager implementation: org.apache.iceberg.rest.auth.BasicAuthManager"); + } + + @Test + @SuppressWarnings("resource") + void nonExistentAuthManager() { + assertThatThrownBy( + () -> AuthManagers.loadAuthManager("test", Map.of(AuthProperties.AUTH_TYPE, "unknown"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Cannot initialize AuthManager implementation unknown"); + } +} diff --git a/core/src/test/java/org/apache/iceberg/rest/auth/TestBasicAuthManager.java b/core/src/test/java/org/apache/iceberg/rest/auth/TestBasicAuthManager.java new file mode 100644 index 000000000000..00231421e29b --- /dev/null +++ b/core/src/test/java/org/apache/iceberg/rest/auth/TestBasicAuthManager.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.InstanceOfAssertFactories.type; + +import java.util.Map; +import org.junit.jupiter.api.Test; + +class TestBasicAuthManager { + + @Test + void missingUsername() { + try (AuthManager authManager = new BasicAuthManager()) { + assertThatThrownBy(() -> authManager.catalogSession(null, Map.of())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid username: missing required property %s", AuthProperties.BASIC_USERNAME); + } + } + + @Test + void missingPassword() { + try (AuthManager authManager = new BasicAuthManager()) { + Map properties = Map.of(AuthProperties.BASIC_USERNAME, "alice"); + assertThatThrownBy(() -> authManager.catalogSession(null, properties)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid password: missing required property %s", AuthProperties.BASIC_PASSWORD); + } + } + + @Test + void success() { + Map properties = + Map.of(AuthProperties.BASIC_USERNAME, "alice", AuthProperties.BASIC_PASSWORD, "secret"); + try (AuthManager authManager = new BasicAuthManager(); + AuthSession session = authManager.catalogSession(null, properties)) { + assertThat(session) + .isNotNull() + .asInstanceOf(type(DefaultAuthSession.class)) + .extracting(DefaultAuthSession::headers) + .isEqualTo(Map.of("Authorization", "Basic YWxpY2U6c2VjcmV0")); + } + } +} diff --git a/gcp/src/test/java/org/apache/iceberg/gcp/gcs/OAuth2RefreshCredentialsHandlerTest.java b/gcp/src/test/java/org/apache/iceberg/gcp/gcs/OAuth2RefreshCredentialsHandlerTest.java index c538745f2767..ad27f005690e 100644 --- a/gcp/src/test/java/org/apache/iceberg/gcp/gcs/OAuth2RefreshCredentialsHandlerTest.java +++ b/gcp/src/test/java/org/apache/iceberg/gcp/gcs/OAuth2RefreshCredentialsHandlerTest.java @@ -26,7 +26,6 @@ import java.time.Instant; import java.time.temporal.ChronoUnit; import org.apache.iceberg.exceptions.BadRequestException; -import org.apache.iceberg.exceptions.RESTException; import org.apache.iceberg.gcp.GCPProperties; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.rest.HttpMethod; @@ -75,8 +74,8 @@ public void invalidOrMissingUri() { ImmutableMap.of( GCPProperties.GCS_OAUTH2_REFRESH_CREDENTIALS_ENDPOINT, "invalid uri")) .refreshAccessToken()) - .isInstanceOf(RESTException.class) - .hasMessageStartingWith("Failed to create request URI from base invalid uri"); + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("Illegal character in path at index 7: invalid uri"); } @Test