diff --git a/aws/src/main/java/org/apache/iceberg/aws/s3/VendedCredentialsProvider.java b/aws/src/main/java/org/apache/iceberg/aws/s3/VendedCredentialsProvider.java index e249d3ff1dec..792a6ff27bea 100644 --- a/aws/src/main/java/org/apache/iceberg/aws/s3/VendedCredentialsProvider.java +++ b/aws/src/main/java/org/apache/iceberg/aws/s3/VendedCredentialsProvider.java @@ -26,7 +26,9 @@ import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.rest.ErrorHandlers; import org.apache.iceberg.rest.HTTPClient; +import org.apache.iceberg.rest.HTTPHeaders; import org.apache.iceberg.rest.RESTClient; +import org.apache.iceberg.rest.auth.DefaultAuthSession; import org.apache.iceberg.rest.auth.OAuth2Properties; import org.apache.iceberg.rest.auth.OAuth2Util; import org.apache.iceberg.rest.credentials.Credential; @@ -74,7 +76,14 @@ private RESTClient httpClient() { if (null == client) { synchronized (this) { if (null == client) { - client = HTTPClient.builder(properties).uri(properties.get(URI)).build(); + DefaultAuthSession authSession = + DefaultAuthSession.of( + HTTPHeaders.of(OAuth2Util.authHeaders(properties.get(OAuth2Properties.TOKEN)))); + client = + HTTPClient.builder(properties) + .uri(properties.get(URI)) + .withAuthSession(authSession) + .build(); } } } @@ -88,7 +97,7 @@ private LoadCredentialsResponse fetchCredentials() { properties.get(URI), null, LoadCredentialsResponse.class, - OAuth2Util.authHeaders(properties.get(OAuth2Properties.TOKEN)), + Map.of(), ErrorHandlers.defaultErrorHandler()); } diff --git a/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java b/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java index 806c52420f89..8f5a00b49daf 100644 --- a/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java +++ b/aws/src/main/java/org/apache/iceberg/aws/s3/signer/S3V4RestSignerClient.java @@ -206,21 +206,26 @@ private AuthSession authSession() { return authSessionCache() .get( token, - id -> - AuthSession.fromAccessToken( - httpClient(), - tokenRefreshExecutor(), - token, - expiresAtMillis(properties()), - new AuthSession( - ImmutableMap.of(), - AuthConfig.builder() - .token(token) - .credential(credential()) - .scope(SCOPE) - .oauth2ServerUri(oauth2ServerUri()) - .optionalOAuthParams(optionalOAuthParams()) - .build()))); + id -> { + // this client will be reused for token refreshes; it must contain an empty auth + // session in order to avoid interfering with refreshed tokens + RESTClient refreshClient = + httpClient().withAuthSession(org.apache.iceberg.rest.auth.AuthSession.EMPTY); + return AuthSession.fromAccessToken( + refreshClient, + tokenRefreshExecutor(), + token, + expiresAtMillis(properties()), + new AuthSession( + ImmutableMap.of(), + AuthConfig.builder() + .token(token) + .credential(credential()) + .scope(SCOPE) + .oauth2ServerUri(oauth2ServerUri()) + .optionalOAuthParams(optionalOAuthParams()) + .build())); + }); } if (credentialProvided()) { @@ -238,16 +243,20 @@ private AuthSession authSession() { .optionalOAuthParams(optionalOAuthParams()) .build()); long startTimeMillis = System.currentTimeMillis(); + // this client will be reused for token refreshes; it must contain an empty auth + // session in order to avoid interfering with refreshed tokens + RESTClient refreshClient = + httpClient().withAuthSession(org.apache.iceberg.rest.auth.AuthSession.EMPTY); OAuthTokenResponse authResponse = OAuth2Util.fetchToken( - httpClient(), + refreshClient, session.headers(), credential(), SCOPE, oauth2ServerUri(), optionalOAuthParams()); return AuthSession.fromTokenResponse( - httpClient(), tokenRefreshExecutor(), authResponse, startTimeMillis, session); + refreshClient, tokenRefreshExecutor(), authResponse, startTimeMillis, session); }); } @@ -338,11 +347,12 @@ public SdkHttpFullRequest sign( Consumer> responseHeadersConsumer = responseHeaders::putAll; S3SignResponse s3SignResponse = httpClient() + .withAuthSession(authSession()) .post( endpoint(), remoteSigningRequest, S3SignResponse.class, - () -> authSession().headers(), + Map.of(), ErrorHandlers.defaultErrorHandler(), responseHeadersConsumer); diff --git a/aws/src/test/java/org/apache/iceberg/aws/TestRESTSigV4Signer.java b/aws/src/test/java/org/apache/iceberg/aws/TestRESTSigV4Signer.java index 88623edd9334..cc8873e30ea7 100644 --- a/aws/src/test/java/org/apache/iceberg/aws/TestRESTSigV4Signer.java +++ b/aws/src/test/java/org/apache/iceberg/aws/TestRESTSigV4Signer.java @@ -27,6 +27,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.rest.HTTPClient; +import org.apache.iceberg.rest.auth.AuthSession; import org.apache.iceberg.rest.auth.OAuth2Util; import org.apache.iceberg.rest.responses.ConfigResponse; import org.apache.iceberg.rest.responses.OAuthTokenResponse; @@ -64,6 +65,7 @@ public static void beforeClass() { HTTPClient.builder(properties) .uri("http://localhost:" + mockServer.getLocalPort()) .withHeader(HttpHeaders.AUTHORIZATION, "Bearer existing_token") + .withAuthSession(AuthSession.EMPTY) .build(); } diff --git a/core/src/main/java/org/apache/iceberg/rest/BaseHTTPClient.java b/core/src/main/java/org/apache/iceberg/rest/BaseHTTPClient.java new file mode 100644 index 000000000000..7be6648ddef0 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/BaseHTTPClient.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest; + +import java.util.Map; +import java.util.function.Consumer; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; +import org.apache.iceberg.rest.auth.AuthSession; +import org.apache.iceberg.rest.responses.ErrorResponse; + +/** + * A base class for {@link RESTClient} implementations. + * + *

All methods in {@link RESTClient} are implemented in the same way: first, an {@link + * HTTPRequest} is {@linkplain #buildRequest(HTTPMethod, String, Map, Map, Object) built from the + * method arguments}, then {@linkplain #execute(HTTPRequest, Class, Consumer, Consumer) executed}. + * + *

This allows subclasses to provide a consistent way to execute all requests, regardless of the + * method or arguments. + */ +public abstract class BaseHTTPClient implements RESTClient { + + @Override + public abstract RESTClient withAuthSession(AuthSession session); + + @Override + public void head(String path, Map headers, Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.HEAD, path, null, headers, null); + execute(request, null, errorHandler, h -> {}); + } + + @Override + public T delete( + String path, + Class responseType, + Map headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.DELETE, path, null, headers, null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T delete( + String path, + Map queryParams, + Class responseType, + Map headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.DELETE, path, queryParams, headers, null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T get( + String path, + Map queryParams, + Class responseType, + Map headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.GET, path, queryParams, headers, null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T post( + String path, + RESTRequest body, + Class responseType, + Map headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers, body); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T post( + String path, + RESTRequest body, + Class responseType, + Map headers, + Consumer errorHandler, + Consumer> responseHeaders) { + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers, body); + return execute(request, responseType, errorHandler, responseHeaders); + } + + @Override + public T postForm( + String path, + Map formData, + Class responseType, + Map headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers, formData); + return execute(request, responseType, errorHandler, h -> {}); + } + + protected abstract HTTPRequest buildRequest( + HTTPMethod method, + String path, + Map queryParams, + Map headers, + Object body); + + protected abstract T execute( + HTTPRequest request, + Class responseType, + Consumer errorHandler, + Consumer> responseHeaders); +} diff --git a/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java b/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java index 4391d75ec710..dfb1d5827ef9 100644 --- a/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java +++ b/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java @@ -23,14 +23,11 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.net.URI; -import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; -import java.util.stream.Collectors; import org.apache.hc.client5.http.auth.CredentialsProvider; -import org.apache.hc.client5.http.classic.methods.HttpUriRequest; import org.apache.hc.client5.http.classic.methods.HttpUriRequestBase; import org.apache.hc.client5.http.config.ConnectionConfig; import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; @@ -45,14 +42,11 @@ import org.apache.hc.core5.http.HttpHost; import org.apache.hc.core5.http.HttpRequestInterceptor; import org.apache.hc.core5.http.HttpStatus; -import org.apache.hc.core5.http.Method; import org.apache.hc.core5.http.ParseException; import org.apache.hc.core5.http.impl.EnglishReasonPhraseCatalog; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.io.entity.StringEntity; -import org.apache.hc.core5.http.message.BasicHeader; import org.apache.hc.core5.io.CloseMode; -import org.apache.hc.core5.net.URIBuilder; import org.apache.iceberg.IcebergBuild; import org.apache.iceberg.common.DynConstructors; import org.apache.iceberg.common.DynMethods; @@ -60,13 +54,15 @@ import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; +import org.apache.iceberg.rest.auth.AuthSession; import org.apache.iceberg.rest.responses.ErrorResponse; import org.apache.iceberg.util.PropertyUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** An HttpClient for usage with the REST catalog. */ -public class HTTPClient implements RESTClient { +public class HTTPClient extends BaseHTTPClient { private static final Logger LOG = LoggerFactory.getLogger(HTTPClient.class); private static final String SIGV4_ENABLED = "rest.sigv4-enabled"; @@ -88,33 +84,31 @@ public class HTTPClient implements RESTClient { @VisibleForTesting static final String REST_SOCKET_TIMEOUT_MS = "rest.client.socket-timeout-ms"; - private final String uri; + private final URI baseUri; private final CloseableHttpClient httpClient; + private final Map baseHeaders; private final ObjectMapper mapper; + private final AuthSession authSession; private HTTPClient( - String uri, + URI baseUri, HttpHost proxy, CredentialsProvider proxyCredsProvider, Map baseHeaders, ObjectMapper objectMapper, HttpRequestInterceptor requestInterceptor, Map properties, - HttpClientConnectionManager connectionManager) { - this.uri = uri; + HttpClientConnectionManager connectionManager, + AuthSession session) { + this.baseUri = baseUri; + this.baseHeaders = baseHeaders; this.mapper = objectMapper; + this.authSession = session; HttpClientBuilder clientBuilder = HttpClients.custom(); clientBuilder.setConnectionManager(connectionManager); - if (baseHeaders != null) { - clientBuilder.setDefaultHeaders( - baseHeaders.entrySet().stream() - .map(e -> new BasicHeader(e.getKey(), e.getValue())) - .collect(Collectors.toList())); - } - if (requestInterceptor != null) { clientBuilder.addRequestInterceptorLast(requestInterceptor); } @@ -133,6 +127,25 @@ private HTTPClient( this.httpClient = clientBuilder.build(); } + /** + * Constructor for creating a child HTTPClient associated with an AuthSession. The returned child + * shares the same base uri, mapper, and HTTP client as the parent, thus not requiring any + * additional resource allocation. + */ + private HTTPClient(HTTPClient parent, AuthSession authSession) { + this.baseUri = parent.baseUri; + this.httpClient = parent.httpClient; + this.mapper = parent.mapper; + this.baseHeaders = parent.baseHeaders; + this.authSession = authSession; + } + + @Override + public HTTPClient withAuthSession(AuthSession session) { + Preconditions.checkNotNull(session, "Invalid auth session: null"); + return new HTTPClient(this, session); + } + private static String extractResponseBodyAsString(CloseableHttpResponse response) { try { if (response.getEntity() == null) { @@ -214,92 +227,64 @@ private static void throwFailure( throw new RESTException("Unhandled error: %s", errorResponse); } - private URI buildUri(String path, Map params) { - // if full path is provided, use the input path as path - if (path.startsWith("/")) { - throw new RESTException( - "Received a malformed path for a REST request: %s. Paths should not start with /", path); - } - String fullPath = - (path.startsWith("https://") || path.startsWith("http://")) - ? path - : String.format("%s/%s", uri, path); - try { - URIBuilder builder = new URIBuilder(fullPath); - if (params != null) { - params.forEach(builder::addParameter); - } - return builder.build(); - } catch (URISyntaxException e) { - throw new RESTException( - "Failed to create request URI from base %s, params %s", fullPath, params); - } - } - - /** - * Method to execute an HTTP request and process the corresponding response. - * - * @param method - HTTP method, such as GET, POST, HEAD, etc. - * @param queryParams - A map of query parameters - * @param path - URI to send the request to - * @param requestBody - Content to place in the request body - * @param responseType - Class of the Response type. Needs to have serializer registered with - * ObjectMapper - * @param errorHandler - Error handler delegated for HTTP responses which handles server error - * responses - * @param - Class type of the response for deserialization. Must be registered with the - * ObjectMapper. - * @return The response entity, parsed and converted to its type T - */ - private T execute( - Method method, + @Override + protected HTTPRequest buildRequest( + HTTPMethod method, String path, Map queryParams, - Object requestBody, - Class responseType, Map headers, - Consumer errorHandler) { - return execute( - method, path, queryParams, requestBody, responseType, headers, errorHandler, h -> {}); + Object body) { + + ImmutableHTTPRequest.Builder builder = + ImmutableHTTPRequest.builder() + .baseUri(baseUri) + .mapper(mapper) + .method(method) + .path(path) + .body(body) + .queryParameters(queryParams == null ? Map.of() : queryParams); + + Map allHeaders = Maps.newLinkedHashMap(); + if (headers != null) { + allHeaders.putAll(headers); + } + + allHeaders.putIfAbsent(HttpHeaders.ACCEPT, ContentType.APPLICATION_JSON.getMimeType()); + + // Many systems require that content type is set regardless and will fail, + // even on an empty bodied request. + // Encode maps as form data (application/x-www-form-urlencoded), + // and other requests are assumed to contain JSON bodies (application/json). + ContentType mimeType = + body instanceof Map + ? ContentType.APPLICATION_FORM_URLENCODED + : ContentType.APPLICATION_JSON; + allHeaders.putIfAbsent(HttpHeaders.CONTENT_TYPE, mimeType.getMimeType()); + + // Apply base headers now to mimic the behavior of + // org.apache.hc.client5.http.protocol.RequestDefaultHeaders + // We want these headers applied *before* the AuthSession authenticates the request. + if (baseHeaders != null) { + baseHeaders.forEach(allHeaders::putIfAbsent); + } + + Preconditions.checkState(authSession != null, "Invalid auth session: null"); + return authSession.authenticate(builder.headers(HTTPHeaders.of(allHeaders)).build()); } - /** - * Method to execute an HTTP request and process the corresponding response. - * - * @param method - HTTP method, such as GET, POST, HEAD, etc. - * @param queryParams - A map of query parameters - * @param path - URL to send the request to - * @param requestBody - Content to place in the request body - * @param responseType - Class of the Response type. Needs to have serializer registered with - * ObjectMapper - * @param errorHandler - Error handler delegated for HTTP responses which handles server error - * responses - * @param responseHeaders The consumer of the response headers - * @param - Class type of the response for deserialization. Must be registered with the - * ObjectMapper. - * @return The response entity, parsed and converted to its type T - */ - private T execute( - Method method, - String path, - Map queryParams, - Object requestBody, + @Override + protected T execute( + HTTPRequest req, Class responseType, - Map headers, Consumer errorHandler, Consumer> responseHeaders) { - HttpUriRequestBase request = new HttpUriRequestBase(method.name(), buildUri(path, queryParams)); - - if (requestBody instanceof Map) { - // encode maps as form data, application/x-www-form-urlencoded - addRequestHeaders(request, headers, ContentType.APPLICATION_FORM_URLENCODED.getMimeType()); - request.setEntity(toFormEncoding((Map) requestBody)); - } else if (requestBody != null) { - // other request bodies are serialized as JSON, application/json - addRequestHeaders(request, headers, ContentType.APPLICATION_JSON.getMimeType()); - request.setEntity(toJson(requestBody)); - } else { - addRequestHeaders(request, headers, ContentType.APPLICATION_JSON.getMimeType()); + HttpUriRequestBase request = new HttpUriRequestBase(req.method().name(), req.requestUri()); + + req.headers().entries().forEach(e -> request.addHeader(e.name(), e.value())); + + String encodedBody = req.encodedBody(); + if (encodedBody != null) { + request.setEntity(new StringEntity(encodedBody)); } try (CloseableHttpResponse response = httpClient.execute(request)) { @@ -326,7 +311,7 @@ private T execute( if (responseBody == null) { throw new RESTException( "Invalid (null) response body for request (expected %s): method=%s, path=%s, status=%d", - responseType.getSimpleName(), method.name(), path, response.getCode()); + responseType.getSimpleName(), req.method(), req.path(), response.getCode()); } try { @@ -339,88 +324,19 @@ private T execute( responseType.getSimpleName()); } } catch (IOException e) { - throw new RESTException(e, "Error occurred while processing %s request", method); + throw new RESTException(e, "Error occurred while processing %s request", req.method()); } } - @Override - public void head(String path, Map headers, Consumer errorHandler) { - execute(Method.HEAD, path, null, null, null, headers, errorHandler); - } - - @Override - public T get( - String path, - Map queryParams, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(Method.GET, path, queryParams, null, responseType, headers, errorHandler); - } - - @Override - public T post( - String path, - RESTRequest body, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(Method.POST, path, null, body, responseType, headers, errorHandler); - } - - @Override - public T post( - String path, - RESTRequest body, - Class responseType, - Map headers, - Consumer errorHandler, - Consumer> responseHeaders) { - return execute( - Method.POST, path, null, body, responseType, headers, errorHandler, responseHeaders); - } - - @Override - public T delete( - String path, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(Method.DELETE, path, null, null, responseType, headers, errorHandler); - } - - @Override - public T delete( - String path, - Map queryParams, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(Method.DELETE, path, queryParams, null, responseType, headers, errorHandler); - } - - @Override - public T postForm( - String path, - Map formData, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(Method.POST, path, null, formData, responseType, headers, errorHandler); - } - - private void addRequestHeaders( - HttpUriRequest request, Map requestHeaders, String bodyMimeType) { - request.setHeader(HttpHeaders.ACCEPT, ContentType.APPLICATION_JSON.getMimeType()); - // Many systems require that content type is set regardless and will fail, even on an empty - // bodied request. - request.setHeader(HttpHeaders.CONTENT_TYPE, bodyMimeType); - requestHeaders.forEach(request::setHeader); - } - @Override public void close() throws IOException { - httpClient.close(CloseMode.GRACEFUL); + try { + if (authSession != null) { + authSession.close(); + } + } finally { + httpClient.close(CloseMode.GRACEFUL); + } } @VisibleForTesting @@ -510,18 +426,29 @@ public static Builder builder(Map properties) { public static class Builder { private final Map properties; private final Map baseHeaders = Maps.newHashMap(); - private String uri; + private URI uri; private ObjectMapper mapper = RESTObjectMapper.mapper(); private HttpHost proxy; private CredentialsProvider proxyCredentialsProvider; + private AuthSession authSession; private Builder(Map properties) { this.properties = properties; } - public Builder uri(String path) { - Preconditions.checkNotNull(path, "Invalid uri for http client: null"); - this.uri = RESTUtil.stripTrailingSlash(path); + public Builder uri(String baseUri) { + Preconditions.checkNotNull(baseUri, "Invalid uri for http client: null"); + try { + this.uri = URI.create(RESTUtil.stripTrailingSlash(baseUri)); + } catch (IllegalArgumentException e) { + throw new RESTException(e, "Failed to create request URI from base %s", baseUri); + } + return this; + } + + public Builder uri(URI baseUri) { + Preconditions.checkNotNull(baseUri, "Invalid uri for http client: null"); + this.uri = baseUri; return this; } @@ -553,6 +480,11 @@ public Builder withObjectMapper(ObjectMapper objectMapper) { return this; } + public Builder withAuthSession(AuthSession session) { + this.authSession = session; + return this; + } + public HTTPClient build() { withHeader(CLIENT_VERSION_HEADER, IcebergBuild.fullVersion()); withHeader(CLIENT_GIT_COMMIT_SHORT_HEADER, IcebergBuild.gitCommitShortId()); @@ -576,19 +508,8 @@ public HTTPClient build() { mapper, interceptor, properties, - configureConnectionManager(properties)); + configureConnectionManager(properties), + authSession); } } - - private StringEntity toJson(Object requestBody) { - try { - return new StringEntity(mapper.writeValueAsString(requestBody), StandardCharsets.UTF_8); - } catch (JsonProcessingException e) { - throw new RESTException(e, "Failed to write request body: %s", requestBody); - } - } - - private StringEntity toFormEncoding(Map formData) { - return new StringEntity(RESTUtil.encodeFormData(formData), StandardCharsets.UTF_8); - } } diff --git a/core/src/main/java/org/apache/iceberg/rest/RESTClient.java b/core/src/main/java/org/apache/iceberg/rest/RESTClient.java index 0f17d9a127e2..2843972fee45 100644 --- a/core/src/main/java/org/apache/iceberg/rest/RESTClient.java +++ b/core/src/main/java/org/apache/iceberg/rest/RESTClient.java @@ -23,6 +23,7 @@ import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.rest.auth.AuthSession; import org.apache.iceberg.rest.responses.ErrorResponse; /** Interface for a basic HTTP Client for interfacing with the REST catalog. */ @@ -158,4 +159,9 @@ T postForm( Class responseType, Map headers, Consumer errorHandler); + + /** Returns a REST client that authenticates requests using the given session. */ + default RESTClient withAuthSession(AuthSession session) { + return this; + } } diff --git a/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java b/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java index d98dc99495f6..33c3d3396c6b 100644 --- a/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java +++ b/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java @@ -71,6 +71,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.rest.auth.AuthConfig; +import org.apache.iceberg.rest.auth.DefaultAuthSession; import org.apache.iceberg.rest.auth.OAuth2Properties; import org.apache.iceberg.rest.auth.OAuth2Util; import org.apache.iceberg.rest.auth.OAuth2Util.AuthSession; @@ -242,9 +243,10 @@ public void initialize(String name, Map unresolved) { } String oauth2ServerUri = props.getOrDefault(OAuth2Properties.OAUTH2_SERVER_URI, ResourcePaths.tokens()); - try (RESTClient initClient = clientBuilder.apply(props)) { - Map initHeaders = - RESTUtil.merge(configHeaders(props), OAuth2Util.authHeaders(initToken)); + try (DefaultAuthSession initSession = + DefaultAuthSession.of(HTTPHeaders.of(OAuth2Util.authHeaders(initToken))); + RESTClient initClient = clientBuilder.apply(props).withAuthSession(initSession)) { + Map initHeaders = configHeaders(props); if (hasCredential) { authResponse = OAuth2Util.fetchToken( @@ -283,7 +285,6 @@ public void initialize(String name, Map unresolved) { mergedProps, OAuth2Properties.TOKEN_REFRESH_ENABLED, OAuth2Properties.TOKEN_REFRESH_ENABLED_DEFAULT); - this.client = clientBuilder.apply(mergedProps); this.paths = ResourcePaths.forCatalogProperties(mergedProps); String token = mergedProps.get(OAuth2Properties.TOKEN); @@ -296,14 +297,19 @@ public void initialize(String name, Map unresolved) { .oauth2ServerUri(oauth2ServerUri) .optionalOAuthParams(optionalOAuthParams) .build()); + + this.client = clientBuilder.apply(mergedProps).withAuthSession(catalogAuth); + if (authResponse != null) { this.catalogAuth = AuthSession.fromTokenResponse( client, tokenRefreshExecutor(name), authResponse, startTimeMillis, catalogAuth); + this.client = client.withAuthSession(catalogAuth); } else if (token != null) { this.catalogAuth = AuthSession.fromAccessToken( client, tokenRefreshExecutor(name), token, expiresAtMillis(mergedProps), catalogAuth); + this.client = client.withAuthSession(catalogAuth); } this.pageSize = PropertyUtil.propertyAsNullableInt(mergedProps, REST_PAGE_SIZE); diff --git a/core/src/test/java/org/apache/iceberg/rest/RESTCatalogAdapter.java b/core/src/test/java/org/apache/iceberg/rest/RESTCatalogAdapter.java index 2fb4defd1224..3d0984a09092 100644 --- a/core/src/test/java/org/apache/iceberg/rest/RESTCatalogAdapter.java +++ b/core/src/test/java/org/apache/iceberg/rest/RESTCatalogAdapter.java @@ -18,7 +18,9 @@ */ package org.apache.iceberg.rest; +import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; +import java.net.URI; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -50,6 +52,8 @@ import org.apache.iceberg.relocated.com.google.common.base.Splitter; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; +import org.apache.iceberg.rest.auth.AuthSession; import org.apache.iceberg.rest.requests.CommitTransactionRequest; import org.apache.iceberg.rest.requests.CreateNamespaceRequest; import org.apache.iceberg.rest.requests.CreateTableRequest; @@ -73,7 +77,7 @@ import org.apache.iceberg.util.PropertyUtil; /** Adaptor class to translate REST requests into {@link Catalog} API calls. */ -public class RESTCatalogAdapter implements RESTClient { +public class RESTCatalogAdapter extends BaseHTTPClient { private static final Splitter SLASH = Splitter.on('/'); private static final Map, Integer> EXCEPTION_ERROR_CODES = @@ -98,6 +102,8 @@ public class RESTCatalogAdapter implements RESTClient { private final SupportsNamespaces asNamespaceCatalog; private final ViewCatalog asViewCatalog; + private AuthSession authSession = AuthSession.EMPTY; + public RESTCatalogAdapter(Catalog catalog) { this.catalog = catalog; this.asNamespaceCatalog = @@ -105,13 +111,6 @@ public RESTCatalogAdapter(Catalog catalog) { this.asViewCatalog = catalog instanceof ViewCatalog ? (ViewCatalog) catalog : null; } - enum HTTPMethod { - GET, - HEAD, - POST, - DELETE - } - enum Route { TOKENS(HTTPMethod.POST, "v1/oauth/tokens", null, OAuthTokenResponse.class), SEPARATE_AUTH_TOKENS_URI( @@ -280,6 +279,12 @@ private static OAuthTokenResponse handleOAuthRequest(Object body) { } } + @Override + public RESTClient withAuthSession(AuthSession session) { + this.authSession = session; + return this; + } + @SuppressWarnings({"MethodLength", "checkstyle:CyclomaticComplexity"}) public T handleRequest( Route route, Map vars, Object body, Class responseType) { @@ -567,25 +572,49 @@ private static void commitTransaction(Catalog catalog, CommitTransactionRequest transactions.forEach(Transaction::commitTransaction); } - public T execute( + @Override + protected HTTPRequest buildRequest( HTTPMethod method, String path, Map queryParams, - Object body, - Class responseType, Map headers, - Consumer errorHandler) { + Object body) { + URI baseUri = URI.create("https://localhost:8080"); + ObjectMapper mapper = RESTObjectMapper.mapper(); + ImmutableHTTPRequest.Builder builder = + ImmutableHTTPRequest.builder() + .baseUri(baseUri) + .mapper(mapper) + .method(method) + .path(path) + .body(body); + + if (queryParams != null) { + builder.queryParameters(queryParams); + } + + if (headers != null) { + builder.headers(HTTPHeaders.of(headers)); + } + + return authSession.authenticate(builder.build()); + } + + @Override + protected T execute( + HTTPRequest request, + Class responseType, + Consumer errorHandler, + Consumer> responseHeaders) { ErrorResponse.Builder errorBuilder = ErrorResponse.builder(); - Pair> routeAndVars = Route.from(method, path); + Pair> routeAndVars = Route.from(request.method(), request.path()); if (routeAndVars != null) { try { ImmutableMap.Builder vars = ImmutableMap.builder(); - if (queryParams != null) { - vars.putAll(queryParams); - } + vars.putAll(request.queryParameters()); vars.putAll(routeAndVars.second()); - return handleRequest(routeAndVars.first(), vars.build(), body, responseType); + return handleRequest(routeAndVars.first(), vars.build(), request.body(), responseType); } catch (RuntimeException e) { configureResponseFromException(e, errorBuilder); @@ -595,7 +624,8 @@ public T execute( errorBuilder .responseCode(400) .withType("BadRequestException") - .withMessage(String.format("No route for request: %s %s", method, path)); + .withMessage( + String.format("No route for request: %s %s", request.method(), request.path())); } ErrorResponse error = errorBuilder.build(); @@ -605,60 +635,6 @@ public T execute( throw new RESTException("Unhandled error: %s", error); } - @Override - public T delete( - String path, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(HTTPMethod.DELETE, path, null, null, responseType, headers, errorHandler); - } - - @Override - public T delete( - String path, - Map queryParams, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(HTTPMethod.DELETE, path, queryParams, null, responseType, headers, errorHandler); - } - - @Override - public T post( - String path, - RESTRequest body, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(HTTPMethod.POST, path, null, body, responseType, headers, errorHandler); - } - - @Override - public T get( - String path, - Map queryParams, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(HTTPMethod.GET, path, queryParams, null, responseType, headers, errorHandler); - } - - @Override - public void head(String path, Map headers, Consumer errorHandler) { - execute(HTTPMethod.HEAD, path, null, null, null, headers, errorHandler); - } - - @Override - public T postForm( - String path, - Map formData, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(HTTPMethod.POST, path, null, formData, responseType, headers, errorHandler); - } - @Override public void close() throws IOException { // The calling test is responsible for closing the underlying catalog backing this REST catalog diff --git a/core/src/test/java/org/apache/iceberg/rest/RESTCatalogServlet.java b/core/src/test/java/org/apache/iceberg/rest/RESTCatalogServlet.java index f456bb4d354d..97c8b6134482 100644 --- a/core/src/test/java/org/apache/iceberg/rest/RESTCatalogServlet.java +++ b/core/src/test/java/org/apache/iceberg/rest/RESTCatalogServlet.java @@ -38,7 +38,7 @@ import org.apache.iceberg.exceptions.RESTException; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.io.CharStreams; -import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; import org.apache.iceberg.rest.RESTCatalogAdapter.Route; import org.apache.iceberg.rest.responses.ErrorResponse; import org.apache.iceberg.util.Pair; @@ -96,15 +96,17 @@ protected void execute(ServletRequestContext context, HttpServletResponse respon } try { - Object responseBody = - restCatalogAdapter.execute( + + HTTPRequest request = + restCatalogAdapter.buildRequest( context.method(), context.path(), context.queryParams(), - context.body(), - context.route().responseClass(), context.headers(), - handle(response)); + context.body()); + Object responseBody = + restCatalogAdapter.execute( + request, context.route().responseClass(), handle(response), h -> {}); if (responseBody != null) { RESTObjectMapper.mapper().writeValue(response.getWriter(), responseBody); diff --git a/core/src/test/java/org/apache/iceberg/rest/TestHTTPClient.java b/core/src/test/java/org/apache/iceberg/rest/TestHTTPClient.java index a7a0b68c8f40..13d555598a3b 100644 --- a/core/src/test/java/org/apache/iceberg/rest/TestHTTPClient.java +++ b/core/src/test/java/org/apache/iceberg/rest/TestHTTPClient.java @@ -52,6 +52,7 @@ import org.apache.hc.core5.http.protocol.HttpContext; import org.apache.iceberg.IcebergBuild; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.rest.auth.AuthSession; import org.apache.iceberg.rest.responses.ErrorResponse; import org.apache.iceberg.rest.responses.ErrorResponseParser; import org.junit.jupiter.api.AfterAll; @@ -84,7 +85,8 @@ public class TestHTTPClient { @BeforeAll public static void beforeClass() { mockServer = startClientAndServer(PORT); - restClient = HTTPClient.builder(ImmutableMap.of()).uri(URI).build(); + restClient = + HTTPClient.builder(ImmutableMap.of()).uri(URI).withAuthSession(AuthSession.EMPTY).build(); icebergBuildGitCommitShort = IcebergBuild.gitCommitShortId(); icebergBuildFullVersion = IcebergBuild.fullVersion(); } @@ -143,6 +145,7 @@ public void testProxyServer() throws IOException { HTTPClient.builder(ImmutableMap.of()) .uri(URI) .withProxy("localhost", proxyPort) + .withAuthSession(AuthSession.EMPTY) .build()) { String path = "v1/config"; HttpRequest mockRequest = @@ -188,18 +191,19 @@ public void testProxyAuthenticationFailure() throws IOException { new AuthScope(proxy), new UsernamePasswordCredentials(authorizedUsername, invalidPassword.toCharArray())); - try (ClientAndServer proxyServer = - startClientAndServer( - new Configuration() - .proxyAuthenticationUsername(authorizedUsername) - .proxyAuthenticationPassword(authorizedPassword), - proxyPort); - RESTClient clientWithProxy = + try (RESTClient clientWithProxy = HTTPClient.builder(ImmutableMap.of()) .uri(URI) .withProxy(proxyHostName, proxyPort) .withProxyCredentialsProvider(credentialsProvider) - .build()) { + .withAuthSession(AuthSession.EMPTY) + .build(); + ClientAndServer proxyServer = + startClientAndServer( + new Configuration() + .proxyAuthenticationUsername(authorizedUsername) + .proxyAuthenticationPassword(authorizedPassword), + proxyPort)) { ErrorHandler onError = new ErrorHandler() { @@ -291,7 +295,8 @@ public void testSocketTimeout() throws IOException { ImmutableMap.of(HTTPClient.REST_SOCKET_TIMEOUT_MS, String.valueOf(socketTimeoutMs)); String path = "socket/timeout/path"; - try (HTTPClient client = HTTPClient.builder(properties).uri(URI).build()) { + try (HTTPClient client = + HTTPClient.builder(properties).uri(URI).withAuthSession(AuthSession.EMPTY).build()) { HttpRequest mockRequest = request() .withPath("/" + path) diff --git a/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java b/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java index 768d6c3777ee..6a5f22075c6e 100644 --- a/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java +++ b/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java @@ -22,6 +22,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.times; @@ -32,7 +34,9 @@ import java.io.File; import java.io.IOException; import java.nio.file.Path; +import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.UUID; import java.util.concurrent.TimeUnit; @@ -65,7 +69,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Maps; -import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; import org.apache.iceberg.rest.RESTSessionCatalog.SnapshotMode; import org.apache.iceberg.rest.auth.AuthSessionUtil; import org.apache.iceberg.rest.auth.OAuth2Properties; @@ -115,35 +119,31 @@ public void createCatalog() throws Exception { "in-memory", ImmutableMap.of(CatalogProperties.WAREHOUSE_LOCATION, warehouse.getAbsolutePath())); - Map catalogHeaders = - ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); - Map contextHeaders = - ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=user"); + HTTPHeaders catalogHeaders = + HTTPHeaders.of(Map.of("Authorization", "Bearer client-credentials-token:sub=catalog")); + HTTPHeaders contextHeaders = + HTTPHeaders.of(Map.of("Authorization", "Bearer client-credentials-token:sub=user")); RESTCatalogAdapter adaptor = new RESTCatalogAdapter(backendCatalog) { @Override public T execute( - RESTCatalogAdapter.HTTPMethod method, - String path, - Map queryParams, - Object body, + HTTPRequest request, Class responseType, - Map headers, - Consumer errorHandler) { + Consumer errorHandler, + Consumer> responseHeaders) { // this doesn't use a Mockito spy because this is used for catalog tests, which have // different method calls - if (!"v1/oauth/tokens".equals(path)) { - if ("v1/config".equals(path)) { - assertThat(headers).containsAllEntriesOf(catalogHeaders); + if (!"v1/oauth/tokens".equals(request.path())) { + if ("v1/config".equals(request.path())) { + assertThat(request.headers().entries()).containsAll(catalogHeaders.entries()); } else { - assertThat(headers).containsAllEntriesOf(contextHeaders); + assertThat(request.headers().entries()).containsAll(contextHeaders.entries()); } } - Object request = roundTripSerialize(body, "request"); - T response = - super.execute( - method, path, queryParams, request, responseType, headers, errorHandler); + Object body = roundTripSerialize(request.body(), "request"); + HTTPRequest req = ImmutableHTTPRequest.builder().from(request).body(body).build(); + T response = super.execute(req, responseType, errorHandler, responseHeaders); T responseAfterSerialization = roundTripSerialize(response, "response"); return responseAfterSerialization; } @@ -259,13 +259,12 @@ public void testConfigRoute() throws IOException { RESTClient testClient = new RESTCatalogAdapter(backendCatalog) { @Override - public T get( - String path, - Map queryParams, + public T execute( + HTTPRequest request, Class responseType, - Map headers, - Consumer errorHandler) { - if ("v1/config".equals(path)) { + Consumer errorHandler, + Consumer> responseHeaders) { + if ("v1/config".equals(request.path())) { return castResponse( responseType, ConfigResponse.builder() @@ -275,10 +274,11 @@ public T get( CatalogProperties.CACHE_ENABLED, "false", CatalogProperties.WAREHOUSE_LOCATION, - queryParams.get(CatalogProperties.WAREHOUSE_LOCATION) + "warehouse")) + request.queryParameters().get(CatalogProperties.WAREHOUSE_LOCATION) + + "warehouse")) .build()); } - return super.get(path, queryParams, responseType, headers, errorHandler); + return super.execute(request, responseType, errorHandler, responseHeaders); } }; @@ -337,21 +337,15 @@ public void testCatalogBasicBearerToken() { // the bearer token should be used for all interactions Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", catalogHeaders), any(), any(), - any(), - eq(catalogHeaders), any()); } @@ -373,32 +367,23 @@ public void testCatalogCredentialNoOauth2ServerUri() { // no token or credential for catalog token exchange Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq("v1/oauth/tokens"), - any(), - any(), + reqMatcher(HTTPMethod.POST, "v1/oauth/tokens", emptyHeaders), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // no token or credential for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // use the catalog token for all interactions Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", catalogHeaders), any(), any(), - any(), - eq(catalogHeaders), any()); } @@ -428,32 +413,23 @@ public void testCatalogCredential(String oauth2ServerUri) { // no token or credential for catalog token exchange Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, emptyHeaders), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // no token or credential for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // use the catalog token for all interactions Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), - any(), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", catalogHeaders), any(), any(), - eq(catalogHeaders), any()); } @@ -489,32 +465,23 @@ public void testCatalogBearerTokenWithClientCredential(String oauth2ServerUri) { // use the bearer token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // use the bearer token to fetch the context token Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, catalogHeaders), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // use the context token for table existence check Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", contextHeaders), any(), any(), - any(), - eq(contextHeaders), any()); } @@ -552,42 +519,30 @@ public void testCatalogCredentialWithClientCredential(String oauth2ServerUri) { // call client credentials with no initial auth Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, emptyHeaders), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // use the client credential to fetch the context token Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, catalogHeaders), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // use the context token for table existence check Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), - any(), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", contextHeaders), any(), any(), - eq(contextHeaders), any()); } @@ -627,42 +582,30 @@ public void testCatalogBearerTokenAndCredentialWithClientCredential(String oauth // use the bearer token for client credentials Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, initHeaders), eq(OAuthTokenResponse.class), - eq(initHeaders), + any(), any()); // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // use the client credential to fetch the context token Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, catalogHeaders), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // use the context token for table existence check Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", contextHeaders), any(), any(), - any(), - eq(contextHeaders), any()); } @@ -822,12 +765,9 @@ private void testClientAuth( Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // token passes a static token. otherwise, validate a client credentials or token exchange @@ -835,34 +775,22 @@ private void testClientAuth( if (!credentials.containsKey("token")) { Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, catalogHeaders), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); } Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), - any(), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", expectedHeaders), any(), any(), - eq(expectedHeaders), any()); if (!optionalOAuthParams.isEmpty()) { Mockito.verify(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat( - body -> - ((Map) body) - .keySet() - .containsAll(optionalOAuthParams.keySet())), + Mockito.argThat(body -> body.keySet().containsAll(optionalOAuthParams.keySet())), eq(OAuthTokenResponse.class), eq(catalogHeaders), any()); @@ -984,10 +912,7 @@ public void testTableSnapshotLoading() { Mockito.doAnswer(refsAnswer) .when(adapter) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "refs")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "refs")), eq(LoadTableResponse.class), any(), any()); @@ -998,10 +923,7 @@ public void testTableSnapshotLoading() { // verify that the table was loaded with the refs argument verify(adapter, times(1)) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "refs")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "refs")), eq(LoadTableResponse.class), any(), any()); @@ -1010,10 +932,7 @@ public void testTableSnapshotLoading() { assertThat(refsTables.snapshots()).containsExactlyInAnyOrderElementsOf(table.snapshots()); verify(adapter, times(1)) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "all")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "all")), eq(LoadTableResponse.class), any(), any()); @@ -1110,10 +1029,7 @@ public void testTableSnapshotLoadingWithDivergedBranches(String formatVersion) { Mockito.doAnswer(refsAnswer) .when(adapter) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "refs")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "refs")), eq(LoadTableResponse.class), any(), any()); @@ -1124,10 +1040,7 @@ public void testTableSnapshotLoadingWithDivergedBranches(String formatVersion) { // verify that the table was loaded with the refs argument verify(adapter, times(1)) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "refs")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "refs")), eq(LoadTableResponse.class), any(), any()); @@ -1137,10 +1050,7 @@ public void testTableSnapshotLoadingWithDivergedBranches(String formatVersion) { .containsExactlyInAnyOrderElementsOf(table.snapshots()); verify(adapter, times(1)) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "all")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "all")), eq(LoadTableResponse.class), any(), any()); @@ -1226,10 +1136,7 @@ public void lazySnapshotLoadingWithDivergedHistory() { Mockito.doAnswer(refsAnswer) .when(adapter) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "refs")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "refs")), eq(LoadTableResponse.class), any(), any()); @@ -1267,23 +1174,17 @@ public void testTableAuth( Mockito.doAnswer(addTableConfig) .when(adapter) .execute( - eq(HTTPMethod.POST), - eq("v1/namespaces/ns/tables"), - any(), - any(), + reqMatcher(HTTPMethod.POST, "v1/namespaces/ns/tables", expectedContextHeaders), eq(LoadTableResponse.class), - eq(expectedContextHeaders), + any(), any()); Mockito.doAnswer(addTableConfig) .when(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/namespaces/ns/tables/table"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/namespaces/ns/tables/table", expectedContextHeaders), eq(LoadTableResponse.class), - eq(expectedContextHeaders), + any(), any()); SessionCatalog.SessionContext context = @@ -1324,33 +1225,24 @@ public void testTableAuth( Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // session client credentials flow Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, catalogHeaders), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // create table request Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq("v1/namespaces/ns/tables"), - any(), - any(), + reqMatcher(HTTPMethod.POST, "v1/namespaces/ns/tables", expectedContextHeaders), eq(LoadTableResponse.class), - eq(expectedContextHeaders), + any(), any()); // if the table returned a bearer token or a credential, there will be no token request @@ -1358,12 +1250,9 @@ public void testTableAuth( // token exchange to get a table token Mockito.verify(adapter, times(1)) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, expectedContextHeaders), eq(OAuthTokenResponse.class), - eq(expectedContextHeaders), + any(), any()); } @@ -1371,34 +1260,25 @@ public void testTableAuth( // load table from catalog + refresh loaded table Mockito.verify(adapter, times(2)) .execute( - eq(HTTPMethod.GET), - eq("v1/namespaces/ns/tables/table"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/namespaces/ns/tables/table", expectedTableHeaders), eq(LoadTableResponse.class), - eq(expectedTableHeaders), + any(), any()); } else { // load table from catalog Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/namespaces/ns/tables/table"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/namespaces/ns/tables/table", expectedContextHeaders), eq(LoadTableResponse.class), - eq(expectedContextHeaders), + any(), any()); // refresh loaded table Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/namespaces/ns/tables/table"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/namespaces/ns/tables/table", expectedTableHeaders), eq(LoadTableResponse.class), - eq(expectedTableHeaders), + any(), any()); } } @@ -1426,14 +1306,7 @@ public void testCatalogTokenRefresh(String oauth2ServerUri) { Mockito.doAnswer(addOneSecondExpiration) .when(adapter) - .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), - eq(OAuthTokenResponse.class), - any(), - any()); + .postForm(eq(oauth2ServerUri), anyMap(), eq(OAuthTokenResponse.class), anyMap(), any()); Map contextCredentials = ImmutableMap.of(); SessionCatalog.SessionContext context = @@ -1458,23 +1331,17 @@ public void testCatalogTokenRefresh(String oauth2ServerUri) { // call client credentials with no initial auth Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, emptyHeaders), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // verify the first token exchange @@ -1486,12 +1353,14 @@ public void testCatalogTokenRefresh(String oauth2ServerUri) { "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + catalogHeaders, + Map.of(), + firstRefreshRequest), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // verify that a second exchange occurs @@ -1508,12 +1377,14 @@ public void testCatalogTokenRefresh(String oauth2ServerUri) { "Bearer token-exchange-token:sub=client-credentials-token:sub=catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - Mockito.argThat(secondRefreshRequest::equals), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + secondRefreshHeaders, + Map.of(), + secondRefreshRequest), eq(OAuthTokenResponse.class), - eq(secondRefreshHeaders), + any(), any()); }); } @@ -1541,14 +1412,7 @@ public void testCatalogRefreshedTokenIsUsed(String oauth2ServerUri) { Mockito.doAnswer(addOneSecondExpiration) .when(adapter) - .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), - eq(OAuthTokenResponse.class), - any(), - any()); + .postForm(eq(oauth2ServerUri), anyMap(), eq(OAuthTokenResponse.class), anyMap(), any()); Map contextCredentials = ImmutableMap.of(); SessionCatalog.SessionContext context = @@ -1574,25 +1438,30 @@ public void testCatalogRefreshedTokenIsUsed(String oauth2ServerUri) { assertThat(catalog.tableExists(TableIdentifier.of("ns", "table"))).isFalse(); // call client credentials with no initial auth + Map clientCredentialsRequest = + ImmutableMap.of( + "grant_type", "client_credentials", + "client_id", "catalog", + "client_secret", "secret", + "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + emptyHeaders, + Map.of(), + clientCredentialsRequest), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // verify the first token exchange @@ -1604,12 +1473,14 @@ public void testCatalogRefreshedTokenIsUsed(String oauth2ServerUri) { "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + catalogHeaders, + Map.of(), + firstRefreshRequest), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // use the refreshed context token for table existence check @@ -1619,12 +1490,10 @@ public void testCatalogRefreshedTokenIsUsed(String oauth2ServerUri) { "Bearer token-exchange-token:sub=client-credentials-token:sub=catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher( + HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", refreshedCatalogHeader), any(), any(), - any(), - eq(refreshedCatalogHeader), any()); }); } @@ -1680,24 +1549,25 @@ public void testCatalogExpiredBearerTokenIsRefreshedWithCredential(String oauth2 assertThat(catalog.tableExists(TableIdentifier.of("ns", "table"))).isFalse(); // call client credentials with no initial auth + Map clientCredentialsRequest = + ImmutableMap.of( + "grant_type", "client_credentials", + "client_id", "catalog", + "client_secret", "12345", + "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher( + HTTPMethod.POST, oauth2ServerUri, emptyHeaders, Map.of(), clientCredentialsRequest), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); Map firstRefreshRequest = @@ -1708,12 +1578,14 @@ public void testCatalogExpiredBearerTokenIsRefreshedWithCredential(String oauth2 "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + OAuth2Util.basicAuthHeaders(credential), + Map.of(), + firstRefreshRequest), eq(OAuthTokenResponse.class), - eq(OAuth2Util.basicAuthHeaders(credential)), + any(), any()); // verify that a second exchange occurs @@ -1725,22 +1597,24 @@ public void testCatalogExpiredBearerTokenIsRefreshedWithCredential(String oauth2 "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - Mockito.argThat(secondRefreshRequest::equals), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + OAuth2Util.basicAuthHeaders(credential), + Map.of(), + secondRefreshRequest), eq(OAuthTokenResponse.class), - eq(OAuth2Util.basicAuthHeaders(credential)), + any(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), - any(), + reqMatcher( + HTTPMethod.HEAD, + "v1/namespaces/ns/tables/table", + Map.of("Authorization", "Bearer token-exchange-token:sub=" + token)), any(), any(), - eq(ImmutableMap.of("Authorization", "Bearer token-exchange-token:sub=" + token)), any()); } @@ -1767,22 +1641,17 @@ public void testCatalogValidBearerTokenIsNotRefreshed() { Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher( + HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", OAuth2Util.authHeaders(token)), any(), any(), - any(), - eq(OAuth2Util.authHeaders(token)), any()); } @@ -1812,14 +1681,7 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth Mockito.doAnswer(addOneSecondExpiration) .when(adapter) - .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), - eq(OAuthTokenResponse.class), - any(), - any()); + .postForm(eq(oauth2ServerUri), anyMap(), eq(OAuthTokenResponse.class), anyMap(), any()); Map firstRefreshRequest = ImmutableMap.of( @@ -1831,11 +1693,9 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth // simulate that the token expired when it was about to be refreshed Mockito.doThrow(new RuntimeException("token expired")) .when(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + argThat(firstRefreshRequest::equals), eq(OAuthTokenResponse.class), eq(catalogHeaders), any()); @@ -1867,35 +1727,38 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth assertThat(catalog.tableExists(TableIdentifier.of("ns", "table"))).isFalse(); // call client credentials with no initial auth + Map clientCredentialsRequest = + ImmutableMap.of( + "grant_type", "client_credentials", + "client_id", "catalog", + "client_secret", "secret", + "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + emptyHeaders, + Map.of(), + clientCredentialsRequest), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // verify the first token exchange - since an exception is thrown, we're performing // retries Mockito.verify(adapter, times(2)) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + argThat(firstRefreshRequest::equals), eq(OAuthTokenResponse.class), eq(catalogHeaders), any()); @@ -1903,11 +1766,9 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth // here we make sure that the basic auth header is used after token refresh retries // failed Mockito.verify(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + argThat(firstRefreshRequest::equals), eq(OAuthTokenResponse.class), eq(basicHeaders), any()); @@ -1919,12 +1780,10 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth "Bearer token-exchange-token:sub=client-credentials-token:sub=catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), - any(), + reqMatcher( + HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", refreshedCatalogHeader), any(), any(), - eq(refreshedCatalogHeader), any()); }); } @@ -1952,14 +1811,7 @@ public void testCatalogWithCustomTokenScope(String oauth2ServerUri) { Mockito.doAnswer(addOneSecondExpiration) .when(adapter) - .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), - eq(OAuthTokenResponse.class), - any(), - any()); + .postForm(eq(oauth2ServerUri), anyMap(), eq(OAuthTokenResponse.class), anyMap(), any()); Map contextCredentials = ImmutableMap.of(); SessionCatalog.SessionContext context = @@ -1986,11 +1838,9 @@ public void testCatalogWithCustomTokenScope(String oauth2ServerUri) { () -> { // call client credentials with no initial auth Mockito.verify(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - any(), + anyMap(), eq(OAuthTokenResponse.class), eq(emptyHeaders), any()); @@ -1998,12 +1848,9 @@ public void testCatalogWithCustomTokenScope(String oauth2ServerUri) { // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // verify the token exchange uses the right scope @@ -2014,11 +1861,9 @@ public void testCatalogWithCustomTokenScope(String oauth2ServerUri) { "subject_token_type", "urn:ietf:params:oauth:token-type:access_token", "scope", scope); Mockito.verify(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + argThat(firstRefreshRequest::equals), eq(OAuthTokenResponse.class), eq(catalogHeaders), any()); @@ -2047,14 +1892,7 @@ public void testCatalogTokenRefreshDisabledWithToken(String oauth2ServerUri) { Mockito.doAnswer(addOneSecondExpiration) .when(adapter) - .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), - eq(OAuthTokenResponse.class), - any(), - any()); + .postForm(eq(oauth2ServerUri), anyMap(), eq(OAuthTokenResponse.class), anyMap(), any()); Map contextCredentials = ImmutableMap.of(); SessionCatalog.SessionContext context = @@ -2076,12 +1914,9 @@ public void testCatalogTokenRefreshDisabledWithToken(String oauth2ServerUri) { Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); } @@ -2122,23 +1957,18 @@ public void testCatalogTokenRefreshDisabledWithCredential(String oauth2ServerUri "scope", "catalog"); Mockito.verify(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat(fetchTokenFromCredential::equals), + argThat(fetchTokenFromCredential::equals), eq(OAuthTokenResponse.class), eq(ImmutableMap.of()), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); } @@ -2311,20 +2141,14 @@ public void testPaginationForListNamespaces(int numberOfItems) { Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", Map.of(), Map.of()), eq(ConfigResponse.class), any(), any()); Mockito.verify(adapter, times(numberOfItems)) .execute( - eq(HTTPMethod.POST), - eq("v1/namespaces"), - any(), - any(), + reqMatcher(HTTPMethod.POST, "v1/namespaces", Map.of(), Map.of()), eq(CreateNamespaceResponse.class), any(), any()); @@ -2375,20 +2199,18 @@ public void testPaginationForListTables(int numberOfItems) { Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", Map.of(), Map.of()), eq(ConfigResponse.class), any(), any()); Mockito.verify(adapter, times(numberOfItems)) .execute( - eq(HTTPMethod.POST), - eq(String.format("v1/namespaces/%s/tables", namespaceName)), - any(), - any(), + reqMatcher( + HTTPMethod.POST, + String.format("v1/namespaces/%s/tables", namespaceName), + Map.of(), + Map.of()), eq(LoadTableResponse.class), any(), any()); @@ -2436,21 +2258,27 @@ public void testCleanupUncommitedFilesForCleanableFailures() { .build(); Table table = catalog.loadTable(TABLE); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); Mockito.doThrow(new NotAuthorizedException("not authorized")) .when(adapter) - .post(any(), any(), any(), any(Map.class), any()); + .execute(reqMatcher(HTTPMethod.POST), any(), any(), any()); assertThatThrownBy(() -> catalog.loadTable(TABLE).newFastAppend().appendFile(file).commit()) .isInstanceOf(NotAuthorizedException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(TABLE)), captor.capture(), any(), any(Map.class), any()); // Extract the UpdateTableRequest to determine the path of the manifest list that should be // cleaned up - UpdateTableRequest request = captor.getValue(); - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) request.updates().get(0); - assertThatThrownBy(() -> table.io().newInputFile(addSnapshot.snapshot().manifestListLocation())) - .isInstanceOf(NotFoundException.class); + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(TABLE)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest body = (UpdateTableRequest) req.body(); + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) body.updates().get(0); + assertThatThrownBy( + () -> table.io().newInputFile(addSnapshot.snapshot().manifestListLocation())) + .isInstanceOf(NotFoundException.class); + }); } @Test @@ -2465,21 +2293,26 @@ public void testNoCleanupForNonCleanableExceptions() { catalog.createTable(TABLE, SCHEMA); Table table = catalog.loadTable(TABLE); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); Mockito.doThrow(new ServiceFailureException("some service failure")) .when(adapter) - .post(any(), any(), any(), any(Map.class), any()); + .execute(reqMatcher(HTTPMethod.POST), any(), any(), any()); assertThatThrownBy(() -> catalog.loadTable(TABLE).newFastAppend().appendFile(FILE_A).commit()) .isInstanceOf(ServiceFailureException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(TABLE)), captor.capture(), any(), any(Map.class), any()); // Extract the UpdateTableRequest to determine the path of the manifest list that should still // exist even though the commit failed - UpdateTableRequest request = captor.getValue(); - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) request.updates().get(0); - assertThat(table.io().newInputFile(addSnapshot.snapshot().manifestListLocation()).exists()) - .isTrue(); + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(TABLE)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest body = (UpdateTableRequest) req.body(); + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) body.updates().get(0); + String manifestListLocation = addSnapshot.snapshot().manifestListLocation(); + assertThat(table.io().newInputFile(manifestListLocation).exists()).isTrue(); + }); } @Test @@ -2493,32 +2326,38 @@ public void testCleanupCleanableExceptionsCreate() { catalog.createTable(TABLE, SCHEMA); TableIdentifier newTable = TableIdentifier.of(TABLE.namespace(), "some_table"); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); Mockito.doThrow(new NotAuthorizedException("not authorized")) .when(adapter) - .post(eq(RESOURCE_PATHS.table(newTable)), any(), any(), any(Map.class), any()); + .execute(reqMatcher(HTTPMethod.POST, RESOURCE_PATHS.table(newTable)), any(), any(), any()); Transaction createTableTransaction = catalog.newCreateTableTransaction(newTable, SCHEMA); createTableTransaction.newAppend().appendFile(FILE_A).commit(); assertThatThrownBy(createTableTransaction::commitTransaction) .isInstanceOf(NotAuthorizedException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(newTable)), captor.capture(), any(), any(Map.class), any()); - UpdateTableRequest request = captor.getValue(); - Optional appendSnapshot = - request.updates().stream() - .filter(update -> update instanceof MetadataUpdate.AddSnapshot) - .findFirst(); - - assertThat(appendSnapshot).isPresent(); - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) appendSnapshot.get(); - assertThatThrownBy( - () -> - catalog - .loadTable(TABLE) - .io() - .newInputFile(addSnapshot.snapshot().manifestListLocation())) - .isInstanceOf(NotFoundException.class); + + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(newTable)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest body = (UpdateTableRequest) req.body(); + Optional appendSnapshot = + body.updates().stream() + .filter(update -> update instanceof MetadataUpdate.AddSnapshot) + .findFirst(); + + assertThat(appendSnapshot).isPresent(); + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) appendSnapshot.get(); + assertThatThrownBy( + () -> + catalog + .loadTable(TABLE) + .io() + .newInputFile(addSnapshot.snapshot().manifestListLocation())) + .isInstanceOf(NotFoundException.class); + }); } @Test @@ -2534,29 +2373,32 @@ public void testNoCleanupForNonCleanableCreateTransaction() { TableIdentifier newTable = TableIdentifier.of(TABLE.namespace(), "some_table"); Mockito.doThrow(new ServiceFailureException("some service failure")) .when(adapter) - .post(eq(RESOURCE_PATHS.table(newTable)), any(), any(), any(Map.class), any()); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); + .execute(reqMatcher(HTTPMethod.POST, RESOURCE_PATHS.table(newTable)), any(), any(), any()); + Transaction createTableTransaction = catalog.newCreateTableTransaction(newTable, SCHEMA); createTableTransaction.newAppend().appendFile(FILE_A).commit(); assertThatThrownBy(createTableTransaction::commitTransaction) .isInstanceOf(ServiceFailureException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(newTable)), captor.capture(), any(), any(Map.class), any()); - UpdateTableRequest request = captor.getValue(); - Optional appendSnapshot = - request.updates().stream() - .filter(update -> update instanceof MetadataUpdate.AddSnapshot) - .findFirst(); - assertThat(appendSnapshot).isPresent(); - - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) appendSnapshot.get(); - assertThat( - catalog - .loadTable(TABLE) - .io() - .newInputFile(addSnapshot.snapshot().manifestListLocation()) - .exists()) - .isTrue(); + + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(newTable)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest body = (UpdateTableRequest) req.body(); + Optional appendSnapshot = + body.updates().stream() + .filter(update -> update instanceof MetadataUpdate.AddSnapshot) + .findFirst(); + assertThat(appendSnapshot).isPresent(); + + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) appendSnapshot.get(); + String manifestListLocation = addSnapshot.snapshot().manifestListLocation(); + assertThat(catalog.loadTable(TABLE).io().newInputFile(manifestListLocation).exists()) + .isTrue(); + }); } @Test @@ -2569,32 +2411,35 @@ public void testCleanupCleanableExceptionsReplace() { } catalog.createTable(TABLE, SCHEMA); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); Mockito.doThrow(new NotAuthorizedException("not authorized")) .when(adapter) - .post(eq(RESOURCE_PATHS.table(TABLE)), any(), any(), any(Map.class), any()); + .execute(reqMatcher(HTTPMethod.POST, RESOURCE_PATHS.table(TABLE)), any(), any(), any()); Transaction replaceTableTransaction = catalog.newReplaceTableTransaction(TABLE, SCHEMA, false); replaceTableTransaction.newAppend().appendFile(FILE_A).commit(); assertThatThrownBy(replaceTableTransaction::commitTransaction) .isInstanceOf(NotAuthorizedException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(TABLE)), captor.capture(), any(), any(Map.class), any()); - UpdateTableRequest request = captor.getValue(); - Optional appendSnapshot = - request.updates().stream() - .filter(update -> update instanceof MetadataUpdate.AddSnapshot) - .findFirst(); - - assertThat(appendSnapshot).isPresent(); - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) appendSnapshot.get(); - assertThatThrownBy( - () -> - catalog - .loadTable(TABLE) - .io() - .newInputFile(addSnapshot.snapshot().manifestListLocation())) - .isInstanceOf(NotFoundException.class); + + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(TABLE)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest request = (UpdateTableRequest) req.body(); + Optional appendSnapshot = + request.updates().stream() + .filter(update -> update instanceof MetadataUpdate.AddSnapshot) + .findFirst(); + + assertThat(appendSnapshot).isPresent(); + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) appendSnapshot.get(); + String manifestListLocation = addSnapshot.snapshot().manifestListLocation(); + assertThatThrownBy( + () -> catalog.loadTable(TABLE).io().newInputFile(manifestListLocation)) + .isInstanceOf(NotFoundException.class); + }); } @Test @@ -2609,29 +2454,32 @@ public void testNoCleanupForNonCleanableReplaceTransaction() { catalog.createTable(TABLE, SCHEMA); Mockito.doThrow(new ServiceFailureException("some service failure")) .when(adapter) - .post(eq(RESOURCE_PATHS.table(TABLE)), any(), any(), any(Map.class), any()); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); + .execute(reqMatcher(HTTPMethod.POST, RESOURCE_PATHS.table(TABLE)), any(), any(), any()); + Transaction replaceTableTransaction = catalog.newReplaceTableTransaction(TABLE, SCHEMA, false); replaceTableTransaction.newAppend().appendFile(FILE_A).commit(); assertThatThrownBy(replaceTableTransaction::commitTransaction) .isInstanceOf(ServiceFailureException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(TABLE)), captor.capture(), any(), any(Map.class), any()); - UpdateTableRequest request = captor.getValue(); - Optional appendSnapshot = - request.updates().stream() - .filter(update -> update instanceof MetadataUpdate.AddSnapshot) - .findFirst(); - assertThat(appendSnapshot).isPresent(); - - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) appendSnapshot.get(); - assertThat( - catalog - .loadTable(TABLE) - .io() - .newInputFile(addSnapshot.snapshot().manifestListLocation()) - .exists()) - .isTrue(); + + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(TABLE)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest request = (UpdateTableRequest) req.body(); + Optional appendSnapshot = + request.updates().stream() + .filter(update -> update instanceof MetadataUpdate.AddSnapshot) + .findFirst(); + assertThat(appendSnapshot).isPresent(); + + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) appendSnapshot.get(); + String manifestListLocation = addSnapshot.snapshot().manifestListLocation(); + assertThat(catalog.loadTable(TABLE).io().newInputFile(manifestListLocation).exists()) + .isTrue(); + }); } @Test @@ -2645,19 +2493,13 @@ public void testNamespaceExistsViaHEADRequest() { Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", Map.of(), Map.of()), eq(ConfigResponse.class), any(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/non-existing"), - any(), - any(), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/non-existing", Map.of(), Map.of()), any(), any(), any()); @@ -2672,4 +2514,51 @@ private RESTCatalog catalog(RESTCatalogAdapter adapter) { CatalogProperties.FILE_IO_IMPL, "org.apache.iceberg.inmemory.InMemoryFileIO")); return catalog; } + + static HTTPRequest reqMatcher(HTTPMethod method) { + return argThat(req -> req.method() == method); + } + + static HTTPRequest reqMatcher(HTTPMethod method, String path) { + return argThat(req -> req.method() == method && req.path().equals(path)); + } + + static HTTPRequest reqMatcher(HTTPMethod method, String path, Map headers) { + return argThat( + req -> + req.method() == method + && req.path().equals(path) + && req.headers().equals(HTTPHeaders.of(headers))); + } + + static HTTPRequest reqMatcher( + HTTPMethod method, String path, Map headers, Map parameters) { + return argThat( + req -> + req.method() == method + && req.path().equals(path) + && req.headers().equals(HTTPHeaders.of(headers)) + && req.queryParameters().equals(parameters)); + } + + static HTTPRequest reqMatcher( + HTTPMethod method, + String path, + Map headers, + Map parameters, + Object body) { + return argThat( + req -> + req.method() == method + && req.path().equals(path) + && req.headers().equals(HTTPHeaders.of(headers)) + && req.queryParameters().equals(parameters) + && Objects.equals(req.body(), body)); + } + + private static List allRequests(RESTCatalogAdapter adapter) { + ArgumentCaptor captor = ArgumentCaptor.forClass(HTTPRequest.class); + verify(adapter, atLeastOnce()).execute(captor.capture(), any(), any(), any()); + return captor.getAllValues(); + } } diff --git a/core/src/test/java/org/apache/iceberg/rest/TestRESTViewCatalog.java b/core/src/test/java/org/apache/iceberg/rest/TestRESTViewCatalog.java index cfc22ca767f6..8dfbf0df6dd7 100644 --- a/core/src/test/java/org/apache/iceberg/rest/TestRESTViewCatalog.java +++ b/core/src/test/java/org/apache/iceberg/rest/TestRESTViewCatalog.java @@ -18,6 +18,7 @@ */ package org.apache.iceberg.rest; +import static org.apache.iceberg.rest.TestRESTCatalog.reqMatcher; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -38,7 +39,7 @@ import org.apache.iceberg.catalog.TableIdentifier; import org.apache.iceberg.inmemory.InMemoryCatalog; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; -import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; import org.apache.iceberg.rest.responses.ConfigResponse; import org.apache.iceberg.rest.responses.ErrorResponse; import org.apache.iceberg.rest.responses.ListTablesResponse; @@ -82,17 +83,13 @@ public void createCatalog() throws Exception { new RESTCatalogAdapter(backendCatalog) { @Override public T execute( - HTTPMethod method, - String path, - Map queryParams, - Object body, + HTTPRequest request, Class responseType, - Map headers, - Consumer errorHandler) { - Object request = roundTripSerialize(body, "request"); - T response = - super.execute( - method, path, queryParams, request, responseType, headers, errorHandler); + Consumer errorHandler, + Consumer> responseHeaders) { + Object body = roundTripSerialize(request.body(), "request"); + HTTPRequest req = ImmutableHTTPRequest.builder().from(request).body(body).build(); + T response = super.execute(req, responseType, errorHandler, responseHeaders); T responseAfterSerialization = roundTripSerialize(response, "response"); return responseAfterSerialization; } @@ -187,21 +184,11 @@ public void testPaginationForListViews(int numberOfItems) { assertThat(views).hasSize(numberOfItems); Mockito.verify(adapter) - .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), - eq(ConfigResponse.class), - any(), - any()); + .execute(reqMatcher(HTTPMethod.GET, "v1/config"), eq(ConfigResponse.class), any(), any()); Mockito.verify(adapter, times(numberOfItems)) .execute( - eq(HTTPMethod.POST), - eq(String.format("v1/namespaces/%s/views", namespaceName)), - any(), - any(), + reqMatcher(HTTPMethod.POST, String.format("v1/namespaces/%s/views", namespaceName)), eq(LoadViewResponse.class), any(), any()); @@ -244,19 +231,13 @@ public void viewExistsViaHEADRequest() { Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", Map.of(), Map.of()), eq(ConfigResponse.class), any(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/views/view"), - any(), - any(), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/views/view", Map.of(), Map.of()), any(), any(), any()); diff --git a/gcp/src/main/java/org/apache/iceberg/gcp/gcs/OAuth2RefreshCredentialsHandler.java b/gcp/src/main/java/org/apache/iceberg/gcp/gcs/OAuth2RefreshCredentialsHandler.java index dfaa502c6005..6f7807e8dd6e 100644 --- a/gcp/src/main/java/org/apache/iceberg/gcp/gcs/OAuth2RefreshCredentialsHandler.java +++ b/gcp/src/main/java/org/apache/iceberg/gcp/gcs/OAuth2RefreshCredentialsHandler.java @@ -29,7 +29,10 @@ import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.rest.ErrorHandlers; import org.apache.iceberg.rest.HTTPClient; +import org.apache.iceberg.rest.HTTPHeaders; import org.apache.iceberg.rest.RESTClient; +import org.apache.iceberg.rest.auth.AuthSession; +import org.apache.iceberg.rest.auth.DefaultAuthSession; import org.apache.iceberg.rest.auth.OAuth2Properties; import org.apache.iceberg.rest.auth.OAuth2Util; import org.apache.iceberg.rest.credentials.Credential; @@ -38,12 +41,16 @@ public class OAuth2RefreshCredentialsHandler implements OAuth2CredentialsWithRefresh.OAuth2RefreshHandler { private final Map properties; + private final AuthSession authSession; private OAuth2RefreshCredentialsHandler(Map properties) { Preconditions.checkArgument( null != properties.get(GCPProperties.GCS_OAUTH2_REFRESH_CREDENTIALS_ENDPOINT), "Invalid credentials endpoint: null"); this.properties = properties; + this.authSession = + DefaultAuthSession.of( + HTTPHeaders.of(OAuth2Util.authHeaders(properties.get(OAuth2Properties.TOKEN)))); } @SuppressWarnings("JavaUtilDate") // GCP API uses java.util.Date @@ -56,7 +63,7 @@ public AccessToken refreshAccessToken() { properties.get(GCPProperties.GCS_OAUTH2_REFRESH_CREDENTIALS_ENDPOINT), null, LoadCredentialsResponse.class, - OAuth2Util.authHeaders(properties.get(OAuth2Properties.TOKEN)), + Map.of(), ErrorHandlers.defaultErrorHandler()); } catch (IOException e) { throw new RuntimeException(e); @@ -95,6 +102,7 @@ public static OAuth2RefreshCredentialsHandler create(Map propert private RESTClient httpClient() { return HTTPClient.builder(properties) .uri(properties.get(GCPProperties.GCS_OAUTH2_REFRESH_CREDENTIALS_ENDPOINT)) + .withAuthSession(authSession) .build(); } }