diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/S3IntegrationTestBase.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/S3IntegrationTestBase.java index 63dcf2ddc88f..0bf0277ea242 100644 --- a/services/s3/src/it/java/software/amazon/awssdk/services/s3/S3IntegrationTestBase.java +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/S3IntegrationTestBase.java @@ -93,6 +93,7 @@ private static void createBucket(String bucketName, int retryCount) { .build()) .build()); } catch (S3Exception e) { + e.printStackTrace(); System.err.println("Error attempting to create bucket: " + bucketName); if (e.awsErrorDetails().errorCode().equals("BucketAlreadyOwnedByYou")) { System.err.printf("%s bucket already exists, likely leaked by a previous run\n", bucketName); diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClient.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClient.java index b427a24aad1f..42b12954d574 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClient.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClient.java @@ -15,29 +15,38 @@ package software.amazon.awssdk.services.s3.internal.crossregion; +import static software.amazon.awssdk.services.s3.internal.crossregion.utils.CrossRegionUtils.getBucketRegionFromException; +import static software.amazon.awssdk.services.s3.internal.crossregion.utils.CrossRegionUtils.isS3RedirectException; +import static software.amazon.awssdk.services.s3.internal.crossregion.utils.CrossRegionUtils.requestWithDecoratedEndpointProvider; import static software.amazon.awssdk.services.s3.internal.crossregion.utils.CrossRegionUtils.updateUserAgentInConfig; +import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.BiConsumer; import java.util.function.Function; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; -import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.DelegatingS3AsyncClient; import software.amazon.awssdk.services.s3.S3AsyncClient; -import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams; -import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; +import software.amazon.awssdk.services.s3.model.S3Exception; import software.amazon.awssdk.services.s3.model.S3Request; +import software.amazon.awssdk.utils.CompletableFutureUtils; @SdkInternalApi public final class S3CrossRegionAsyncClient extends DelegatingS3AsyncClient { + + private final Map bucketToRegionCache = new ConcurrentHashMap<>(); + public S3CrossRegionAsyncClient(S3AsyncClient s3Client) { super(s3Client); } @Override - protected CompletableFuture - invokeOperation(T request, Function> operation) { + protected CompletableFuture invokeOperation( + T request, Function> operation) { Optional bucket = request.getValueForField("Bucket", String.class); @@ -47,53 +56,94 @@ public S3CrossRegionAsyncClient(S3AsyncClient s3Client) { if (!bucket.isPresent()) { return operation.apply(userAgentUpdatedRequest); } - - return operation.apply(requestWithDecoratedEndpointProvider(userAgentUpdatedRequest, bucket.get())) - .whenComplete((r, t) -> handleOperationFailure(t, bucket.get())); + String bucketName = bucket.get(); + + CompletableFuture returnFuture = new CompletableFuture<>(); + CompletableFuture apiOperationFuture = bucketToRegionCache.containsKey(bucketName) ? + operation.apply( + requestWithDecoratedEndpointProvider( + userAgentUpdatedRequest, + () -> bucketToRegionCache.get(bucketName), + serviceClientConfiguration().endpointProvider().get() + ) + ) : + operation.apply(userAgentUpdatedRequest); + + apiOperationFuture.whenComplete(redirectToCrossRegionIfRedirectException(operation, + userAgentUpdatedRequest, + bucketName, + returnFuture)); + return returnFuture; } - private void handleOperationFailure(Throwable t, String bucket) { - //TODO: handle failure case + private BiConsumer redirectToCrossRegionIfRedirectException( + Function> operation, + T userAgentUpdatedRequest, String bucketName, + CompletableFuture returnFuture) { + + return (response, throwable) -> { + if (throwable != null) { + if (isS3RedirectException(throwable)) { + bucketToRegionCache.remove(bucketName); + requestWithCrossRegion(userAgentUpdatedRequest, operation, bucketName, returnFuture, throwable); + } else { + returnFuture.completeExceptionally(throwable); + } + } else { + returnFuture.complete(response); + } + }; } - //Cannot avoid unchecked cast without upstream changes to supply builder function - @SuppressWarnings("unchecked") - private T requestWithDecoratedEndpointProvider(T request, String bucket) { - return (T) request.toBuilder() - .overrideConfiguration(getOrCreateConfigWithEndpointProvider(request, bucket)) - .build(); + private void requestWithCrossRegion(T request, + Function> operation, + String bucketName, + CompletableFuture returnFuture, + Throwable throwable) { + + Optional bucketRegionFromException = getBucketRegionFromException((S3Exception) throwable.getCause()); + if (bucketRegionFromException.isPresent()) { + sendRequestWithRightRegion(request, operation, bucketName, returnFuture, bucketRegionFromException); + } else { + fetchRegionAndSendRequest(request, operation, bucketName, returnFuture); + } } - //TODO: optimize shared sync/async code - private AwsRequestOverrideConfiguration getOrCreateConfigWithEndpointProvider(S3Request request, String bucket) { - AwsRequestOverrideConfiguration requestOverrideConfig = - request.overrideConfiguration().orElseGet(() -> AwsRequestOverrideConfiguration.builder().build()); - - S3EndpointProvider delegateEndpointProvider = (S3EndpointProvider) - requestOverrideConfig.endpointProvider().orElseGet(() -> serviceClientConfiguration().endpointProvider().get()); - - return requestOverrideConfig.toBuilder() - .endpointProvider(BucketEndpointProvider.create(delegateEndpointProvider, bucket)) - .build(); + private void fetchRegionAndSendRequest(T request, + Function> operation, + String bucketName, + CompletableFuture returnFuture) { + // // TODO: will fix the casts with separate PR + ((S3AsyncClient) delegate()).headBucket(b -> b.bucket(bucketName)).whenComplete((response, + throwable) -> { + if (throwable != null) { + if (isS3RedirectException(throwable)) { + bucketToRegionCache.remove(bucketName); + Optional bucketRegion = getBucketRegionFromException((S3Exception) throwable.getCause()); + if (bucketRegion.isPresent()) { + sendRequestWithRightRegion(request, operation, bucketName, returnFuture, bucketRegion); + } else { + returnFuture.completeExceptionally(throwable); + } + } else { + returnFuture.completeExceptionally(throwable); + } + } + }); } - //TODO: add cross region logic - static final class BucketEndpointProvider implements S3EndpointProvider { - private final S3EndpointProvider delegate; - private final String bucket; - - private BucketEndpointProvider(S3EndpointProvider delegate, String bucket) { - this.delegate = delegate; - this.bucket = bucket; - } - - public static BucketEndpointProvider create(S3EndpointProvider delegate, String bucket) { - return new BucketEndpointProvider(delegate, bucket); - } - - @Override - public CompletableFuture resolveEndpoint(S3EndpointParams endpointParams) { - return delegate.resolveEndpoint(endpointParams); - } + private void sendRequestWithRightRegion(T request, + Function> operation, + String bucketName, + CompletableFuture returnFuture, + Optional bucketRegionFromException) { + String region = bucketRegionFromException.get(); + bucketToRegionCache.put(bucketName, Region.of(region)); + CompletableFuture newFuture = operation.apply( + requestWithDecoratedEndpointProvider(request, + () -> Region.of(region), + serviceClientConfiguration().endpointProvider().get())); + CompletableFutureUtils.forwardResultTo(newFuture, returnFuture); + CompletableFutureUtils.forwardExceptionTo(returnFuture, newFuture); } -} +} \ No newline at end of file diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionSyncClient.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionSyncClient.java index e472da136370..10ac10af5816 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionSyncClient.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionSyncClient.java @@ -15,85 +15,93 @@ package software.amazon.awssdk.services.s3.internal.crossregion; +import static software.amazon.awssdk.services.s3.internal.crossregion.utils.CrossRegionUtils.getBucketRegionFromException; +import static software.amazon.awssdk.services.s3.internal.crossregion.utils.CrossRegionUtils.isS3RedirectException; +import static software.amazon.awssdk.services.s3.internal.crossregion.utils.CrossRegionUtils.requestWithDecoratedEndpointProvider; import static software.amazon.awssdk.services.s3.internal.crossregion.utils.CrossRegionUtils.updateUserAgentInConfig; +import java.util.Map; import java.util.Optional; -import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; -import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.DelegatingS3Client; import software.amazon.awssdk.services.s3.S3Client; -import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams; -import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; +import software.amazon.awssdk.services.s3.model.HeadBucketRequest; +import software.amazon.awssdk.services.s3.model.S3Exception; import software.amazon.awssdk.services.s3.model.S3Request; +/** + * Decorator S3 Sync client that will fetch the region name whenever there is Redirect 301 error due to cross region bucket + * access. + */ @SdkInternalApi public final class S3CrossRegionSyncClient extends DelegatingS3Client { + + private final Map bucketToRegionCache = new ConcurrentHashMap<>(); + public S3CrossRegionSyncClient(S3Client s3Client) { super(s3Client); } + private static Optional bucketNameFromRequest(T request) { + return request.getValueForField("Bucket", String.class); + } + @Override protected ReturnT invokeOperation(T request, Function operation) { - Optional bucket = request.getValueForField("Bucket", String.class); + Optional bucketRequest = bucketNameFromRequest(request); AwsRequestOverrideConfiguration overrideConfiguration = updateUserAgentInConfig(request); T userAgentUpdatedRequest = (T) request.toBuilder().overrideConfiguration(overrideConfiguration).build(); - if (bucket.isPresent()) { - try { - return operation.apply(requestWithDecoratedEndpointProvider(userAgentUpdatedRequest, bucket.get())); - } catch (Exception e) { - handleOperationFailure(e, bucket.get()); + + if (!bucketRequest.isPresent()) { + return operation.apply(userAgentUpdatedRequest); + } + String bucketName = bucketRequest.get(); + try { + if (bucketToRegionCache.containsKey(bucketName)) { + return operation.apply( + requestWithDecoratedEndpointProvider(userAgentUpdatedRequest, + () -> bucketToRegionCache.get(bucketName), + serviceClientConfiguration().endpointProvider().get())); + } + return operation.apply(userAgentUpdatedRequest); + } catch (S3Exception exception) { + if (isS3RedirectException(exception)) { + updateCacheFromRedirectException(exception, bucketName); + return operation.apply( + requestWithDecoratedEndpointProvider( + userAgentUpdatedRequest, + () -> bucketToRegionCache.computeIfAbsent(bucketName, this::fetchBucketRegion), + serviceClientConfiguration().endpointProvider().get())); } + throw exception; } - - return operation.apply(userAgentUpdatedRequest); - } - - private void handleOperationFailure(Throwable t, String bucket) { - //TODO: handle failure case } - @SuppressWarnings("unchecked") - private T requestWithDecoratedEndpointProvider(T request, String bucket) { - return (T) request.toBuilder() - .overrideConfiguration(getOrCreateConfigWithEndpointProvider(request, bucket)) - .build(); + private void updateCacheFromRedirectException(S3Exception exception, String bucketName) { + Optional regionStr = getBucketRegionFromException(exception); + // If redirected, clear previous values due to region change. + bucketToRegionCache.remove(bucketName); + regionStr.ifPresent(region -> bucketToRegionCache.put(bucketName, Region.of(region))); } - //TODO: optimize shared sync/async code - private AwsRequestOverrideConfiguration getOrCreateConfigWithEndpointProvider(S3Request request, String bucket) { - AwsRequestOverrideConfiguration requestOverrideConfig = - request.overrideConfiguration().orElseGet(() -> AwsRequestOverrideConfiguration.builder().build()); - - S3EndpointProvider delegateEndpointProvider = (S3EndpointProvider) - requestOverrideConfig.endpointProvider().orElseGet(() -> serviceClientConfiguration().endpointProvider().get()); - - return requestOverrideConfig.toBuilder() - .endpointProvider(BucketEndpointProvider.create(delegateEndpointProvider, bucket)) - .build(); - } - - static final class BucketEndpointProvider implements S3EndpointProvider { - private final S3EndpointProvider delegate; - private final String bucket; - - private BucketEndpointProvider(S3EndpointProvider delegate, String bucket) { - this.delegate = delegate; - this.bucket = bucket; + private Region fetchBucketRegion(String bucketName) { + try { + ((S3Client) delegate()).headBucket(HeadBucketRequest.builder().bucket(bucketName).build()); + } catch (S3Exception exception) { + if (isS3RedirectException(exception)) { + return Region.of(getBucketRegionFromException(exception).orElseThrow(() -> exception)); + } + throw exception; } + return null; + } - public static BucketEndpointProvider create(S3EndpointProvider delegate, String bucket) { - return new BucketEndpointProvider(delegate, bucket); - } - @Override - public CompletableFuture resolveEndpoint(S3EndpointParams endpointParams) { - return delegate.resolveEndpoint(endpointParams); - } - } } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/endpointprovider/BucketEndpointProvider.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/endpointprovider/BucketEndpointProvider.java new file mode 100644 index 000000000000..ce89341753b5 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/endpointprovider/BucketEndpointProvider.java @@ -0,0 +1,50 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.crossregion.endpointprovider; + +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams; +import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; + +/** + * Decorator S3EndpointProvider which updates the region with the one that is supplied during its instantiation. + */ +@SdkInternalApi +public class BucketEndpointProvider implements S3EndpointProvider { + private final S3EndpointProvider delegateEndPointProvider; + private final Supplier regionSupplier; + + private BucketEndpointProvider(S3EndpointProvider delegateEndPointProvider, Supplier regionSupplier) { + this.delegateEndPointProvider = delegateEndPointProvider; + this.regionSupplier = regionSupplier; + } + + public static BucketEndpointProvider create(S3EndpointProvider delegateEndPointProvider, Supplier regionSupplier) { + return new BucketEndpointProvider(delegateEndPointProvider, regionSupplier); + } + + @Override + public CompletableFuture resolveEndpoint(S3EndpointParams endpointParams) { + Region crossRegion = regionSupplier.get(); + return delegateEndPointProvider.resolveEndpoint( + crossRegion != null ? endpointParams.copy(c -> c.region(crossRegion)) : endpointParams); + } +} + diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/utils/CrossRegionUtils.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/utils/CrossRegionUtils.java index 83a10438c49b..5c4413997671 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/utils/CrossRegionUtils.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/utils/CrossRegionUtils.java @@ -16,20 +16,61 @@ package software.amazon.awssdk.services.s3.internal.crossregion.utils; +import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; +import java.util.function.Supplier; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.core.ApiName; +import software.amazon.awssdk.endpoints.EndpointProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; +import software.amazon.awssdk.services.s3.internal.crossregion.endpointprovider.BucketEndpointProvider; +import software.amazon.awssdk.services.s3.model.S3Exception; import software.amazon.awssdk.services.s3.model.S3Request; @SdkInternalApi public final class CrossRegionUtils { + public static final int REDIRECT_STATUS_CODE = 301; + public static final String AMZ_BUCKET_REGION_HEADER = "x-amz-bucket-region"; private static final ApiName API_NAME = ApiName.builder().version("cross-region").name("hll").build(); private static final Consumer USER_AGENT_APPLIER = b -> b.addApiName(API_NAME); + private CrossRegionUtils() { } + public static Optional getBucketRegionFromException(S3Exception exception) { + return exception.awsErrorDetails() + .sdkHttpResponse() + .firstMatchingHeader(AMZ_BUCKET_REGION_HEADER); + } + + public static boolean isS3RedirectException(Throwable exception) { + Throwable exceptionToBeChecked = exception instanceof CompletionException ? exception.getCause() : exception ; + return exceptionToBeChecked instanceof S3Exception + && ((S3Exception) exceptionToBeChecked).statusCode() == REDIRECT_STATUS_CODE; + } + + + @SuppressWarnings("unchecked") + public static T requestWithDecoratedEndpointProvider(T request, Supplier regionSupplier, + EndpointProvider clientEndpointProvider) { + AwsRequestOverrideConfiguration requestOverrideConfig = + request.overrideConfiguration().orElseGet(() -> AwsRequestOverrideConfiguration.builder().build()); + + S3EndpointProvider delegateEndpointProvider = (S3EndpointProvider) requestOverrideConfig.endpointProvider() + .orElse(clientEndpointProvider); + return (T) request.toBuilder() + .overrideConfiguration( + requestOverrideConfig.toBuilder() + .endpointProvider( + BucketEndpointProvider.create(delegateEndpointProvider, regionSupplier)) + .build()) + .build(); + } + public static AwsRequestOverrideConfiguration updateUserAgentInConfig(T request) { AwsRequestOverrideConfiguration overrideConfiguration = request.overrideConfiguration().map(c -> c.toBuilder() diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClientRedirectTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClientRedirectTest.java new file mode 100644 index 000000000000..b3859a682c03 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClientRedirectTest.java @@ -0,0 +1,143 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.crossregion; + +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.function.Consumer; +import org.junit.jupiter.api.BeforeEach; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.HeadBucketRequest; +import software.amazon.awssdk.services.s3.model.ListBucketsRequest; +import software.amazon.awssdk.services.s3.model.ListBucketsResponse; +import software.amazon.awssdk.services.s3.model.ListObjectsRequest; +import software.amazon.awssdk.services.s3.model.ListObjectsResponse; +import software.amazon.awssdk.services.s3.model.S3Exception; +import software.amazon.awssdk.utils.CompletableFutureUtils; + +public class S3CrossRegionAsyncClientRedirectTest extends S3DecoratorRedirectTestBase { + private static S3AsyncClient mockDelegateAsyncClient; + private S3AsyncClient decoratedS3AsyncClient; + + @BeforeEach + public void setup() { + mockDelegateAsyncClient = Mockito.mock(S3AsyncClient.class); + decoratedS3AsyncClient = new S3CrossRegionAsyncClient(mockDelegateAsyncClient); + } + + @Override + protected void stubRedirectSuccessSuccess() { + when(mockDelegateAsyncClient.listObjects(any(ListObjectsRequest.class))) + .thenReturn(CompletableFutureUtils.failedFuture(new CompletionException(redirectException(301, CROSS_REGION.id(), null, null)))) + .thenReturn(CompletableFuture.completedFuture(ListObjectsResponse.builder().contents(S3_OBJECTS).build())) + .thenReturn(CompletableFuture.completedFuture(ListObjectsResponse.builder().contents(S3_OBJECTS).build())); + } + + @Override + protected ListObjectsResponse apiCallToService() throws Throwable { + try{ + return decoratedS3AsyncClient.listObjects(i -> i.bucket(CROSS_REGION_BUCKET)).join(); + }catch (CompletionException exception){ + throw exception.getCause(); + } + } + + @Override + protected void verifyTheApiServiceCall(int times, ArgumentCaptor requestArgumentCaptor) { + verify(mockDelegateAsyncClient, times(times)).listObjects(requestArgumentCaptor.capture()); + } + + @Override + protected void stubServiceClientConfiguration() { + when(mockDelegateAsyncClient.serviceClientConfiguration()).thenReturn(CONFIGURED_ENDPOINT_PROVIDER); + } + + @Override + protected void stubClientAPICallWithFirstRedirectThenSuccessWithRegionInErrorResponse() { + when(mockDelegateAsyncClient.listObjects(any(ListObjectsRequest.class))) + .thenReturn(CompletableFutureUtils.failedFuture(new CompletionException(redirectException(301, CROSS_REGION.id(), null, + null)))) + .thenReturn(CompletableFuture.completedFuture(ListObjectsResponse.builder().contents(S3_OBJECTS).build() + )); + } + + @Override + protected void verifyNoBucketApiCall(int times, ArgumentCaptor requestArgumentCaptor) { + verify(mockDelegateAsyncClient, times(times)).listBuckets(requestArgumentCaptor.capture()); + } + + @Override + protected ListBucketsResponse noBucketCallToService() throws Throwable { + return decoratedS3AsyncClient.listBuckets(ListBucketsRequest.builder().build()).join(); + } + + @Override + protected void stubApiWithNoBucketField() { + when(mockDelegateAsyncClient.listBuckets(any(ListBucketsRequest.class))) + .thenReturn(CompletableFutureUtils.failedFuture(new CompletionException(redirectException(301, CROSS_REGION.id(), null, + "Redirect")))) + .thenReturn(CompletableFuture.completedFuture(ListBucketsResponse.builder().build() + )); + } + + @Override + protected void stubHeadBucketRedirect() { + when(mockDelegateAsyncClient.headBucket(any(HeadBucketRequest.class))) + .thenReturn(CompletableFutureUtils.failedFuture(new CompletionException(redirectException(301,CROSS_REGION.id(), null, null)))); + when(mockDelegateAsyncClient.headBucket(any(Consumer.class))) + .thenReturn(CompletableFutureUtils.failedFuture(new CompletionException(redirectException(301,CROSS_REGION.id(), null, null)))); + } + + @Override + protected void stubRedirectWithNoRegionAndThenSuccess() { + when(mockDelegateAsyncClient.listObjects(any(ListObjectsRequest.class))) + .thenReturn(CompletableFutureUtils.failedFuture(new CompletionException(redirectException(301, null, null, null)))) + .thenReturn(CompletableFuture.completedFuture(ListObjectsResponse.builder().contents(S3_OBJECTS).build())) + .thenReturn(CompletableFuture.completedFuture(ListObjectsResponse.builder().contents(S3_OBJECTS).build())); + } + + @Override + protected void stubRedirectThenError() { + when(mockDelegateAsyncClient.listObjects(any(ListObjectsRequest.class))) + .thenReturn(CompletableFutureUtils.failedFuture(new CompletionException(redirectException(301, CROSS_REGION.id(), null, + null)))) + .thenReturn(CompletableFutureUtils.failedFuture(new CompletionException(redirectException(400, null, + "InvalidArgument", "Invalid id")))); + } + + @Override + protected void verifyHeadBucketServiceCall(int times) { + verify(mockDelegateAsyncClient, times(times)).headBucket(any(Consumer.class)); + } + + @Override + protected void verifyNoBucketCall() { + assertThatExceptionOfType(CompletionException.class) + .isThrownBy( + () -> noBucketCallToService()) + + .withCauseInstanceOf(S3Exception.class) + .withMessage("software.amazon.awssdk.services.s3.model.S3Exception: Redirect (Service: S3, Status Code: 301, Request ID: 1, Extended Request ID: A1)"); + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClientTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClientTest.java index 01fdbbff47c8..59a0bea1bbe7 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClientTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClientTest.java @@ -16,12 +16,29 @@ package software.amazon.awssdk.services.s3.internal.crossregion; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static software.amazon.awssdk.services.s3.internal.crossregion.S3DecoratorRedirectTestBase.CHANGED_CROSS_REGION; +import static software.amazon.awssdk.services.s3.internal.crossregion.S3DecoratorRedirectTestBase.CROSS_REGION; +import static software.amazon.awssdk.services.s3.internal.crossregion.S3DecoratorRedirectTestBase.OVERRIDE_CONFIGURED_REGION; +import static software.amazon.awssdk.services.s3.internal.crossregion.S3DecoratorRedirectTestBase.X_AMZ_BUCKET_REGION; import java.net.URI; +import java.util.Arrays; +import java.util.List; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutionException; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import software.amazon.awssdk.core.ResponseBytes; import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; @@ -29,14 +46,24 @@ import software.amazon.awssdk.endpoints.EndpointProvider; import software.amazon.awssdk.http.AbortableInputStream; import software.amazon.awssdk.http.HttpExecuteResponse; +import software.amazon.awssdk.http.SdkHttpFullResponse; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3AsyncClientBuilder; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.endpoints.internal.DefaultS3EndpointProvider; +import software.amazon.awssdk.services.s3.internal.crossregion.endpointprovider.BucketEndpointProvider; import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.ListObjectsV2Response; +import software.amazon.awssdk.services.s3.model.S3Exception; import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Publisher; import software.amazon.awssdk.testutils.service.http.MockAsyncHttpClient; import software.amazon.awssdk.utils.StringInputStream; +import software.amazon.awssdk.utils.StringUtils; class S3CrossRegionAsyncClientTest { @@ -44,32 +71,45 @@ class S3CrossRegionAsyncClientTest { private static final String BUCKET = "bucket"; private static final String KEY = "key"; private static final String TOKEN = "token"; - - private final MockAsyncHttpClient mockAsyncHttpClient = new MockAsyncHttpClient(); + private MockAsyncHttpClient mockAsyncHttpClient ; private CaptureInterceptor captureInterceptor; private S3AsyncClient s3Client; @BeforeEach void before() { - mockAsyncHttpClient.stubNextResponse( - HttpExecuteResponse.builder() - .response(SdkHttpResponse.builder().statusCode(200).build()) - .responseBody(AbortableInputStream.create(new StringInputStream(RESPONSE))) - .build()); - + mockAsyncHttpClient = new MockAsyncHttpClient(); captureInterceptor = new CaptureInterceptor(); s3Client = clientBuilder().build(); } - @Test - void standardOp_crossRegionClient_noOverrideConfig_SuccessfullyIntercepts() { + public static Stream stubResponses() { + Consumer redirectStubConsumer = mockSyncHttpClient -> + mockSyncHttpClient.stubResponses(customHttpResponse(301, CROSS_REGION.id()), successHttpResponse()); + + Consumer successStubConsumer = mockSyncHttpClient -> + mockSyncHttpClient.stubResponses(successHttpResponse(), successHttpResponse()); + + return Stream.of( + Arguments.of(redirectStubConsumer, BucketEndpointProvider.class), + Arguments.of(successStubConsumer, DefaultS3EndpointProvider.class) + ); + } + + @ParameterizedTest + @MethodSource("stubResponses") + void standardOp_crossRegionClient_noOverrideConfig_SuccessfullyIntercepts(Consumer stubConsumer, + Class endpointProviderType) { + stubConsumer.accept(mockAsyncHttpClient); S3AsyncClient crossRegionClient = new S3CrossRegionAsyncClient(s3Client); crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join(); - assertThat(captureInterceptor.endpointProvider).isInstanceOf(S3CrossRegionAsyncClient.BucketEndpointProvider.class); + assertThat(captureInterceptor.endpointProvider).isInstanceOf(endpointProviderType); } - @Test - void standardOp_crossRegionClient_existingOverrideConfig_SuccessfullyIntercepts() { + @ParameterizedTest + @MethodSource("stubResponses") + void standardOp_crossRegionClient_existingOverrideConfig_SuccessfullyIntercepts(Consumer stubConsumer, + Class endpointProviderType) { + stubConsumer.accept(mockAsyncHttpClient); S3AsyncClient crossRegionClient = new S3CrossRegionAsyncClient(s3Client); GetObjectRequest request = GetObjectRequest.builder() .bucket(BUCKET) @@ -77,29 +117,206 @@ void standardOp_crossRegionClient_existingOverrideConfig_SuccessfullyIntercepts( .overrideConfiguration(o -> o.putHeader("someheader", "somevalue")) .build(); crossRegionClient.getObject(request, AsyncResponseTransformer.toBytes()).join(); - assertThat(captureInterceptor.endpointProvider).isInstanceOf(S3CrossRegionAsyncClient.BucketEndpointProvider.class); + assertThat(captureInterceptor.endpointProvider).isInstanceOf(endpointProviderType); assertThat(mockAsyncHttpClient.getLastRequest().headers().get("someheader")).isNotNull(); } - @Test - void paginatedOp_crossRegionClient_DoesNotIntercept() throws Exception { + @ParameterizedTest + @MethodSource("stubResponses") + void paginatedOp_crossRegionClient_DoesIntercept(Consumer stubConsumer, + Class endpointProviderType) throws Exception { + stubConsumer.accept(mockAsyncHttpClient); S3AsyncClient crossRegionClient = new S3CrossRegionAsyncClient(s3Client); ListObjectsV2Publisher publisher = crossRegionClient.listObjectsV2Paginator(r -> r.bucket(BUCKET).continuationToken(TOKEN).build()); CompletableFuture future = publisher.subscribe(ListObjectsV2Response::contents); future.get(); - assertThat(captureInterceptor.endpointProvider).isInstanceOf(S3CrossRegionAsyncClient.BucketEndpointProvider.class); + assertThat(captureInterceptor.endpointProvider).isInstanceOf(endpointProviderType); } - @Test - void crossRegionClient_createdWithWrapping_SuccessfullyIntercepts() { + @ParameterizedTest + @MethodSource("stubResponses") + void crossRegionClient_createdWithWrapping_SuccessfullyIntercepts(Consumer stubConsumer, + Class endpointProviderType) { + stubConsumer.accept(mockAsyncHttpClient); S3AsyncClient crossRegionClient = clientBuilder().serviceConfiguration(c -> c.crossRegionAccessEnabled(true)).build(); crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join(); - assertThat(captureInterceptor.endpointProvider).isInstanceOf(S3CrossRegionAsyncClient.BucketEndpointProvider.class); + assertThat(captureInterceptor.endpointProvider).isInstanceOf(endpointProviderType); + } + + @Test + void crossRegionClient_CallsHeadObject_when_regionNameNotPresentInFallBackCall(){ + mockAsyncHttpClient.reset(); + mockAsyncHttpClient.stubResponses(customHttpResponse(301, null), + customHttpResponse(301, CROSS_REGION.id()), + successHttpResponse(), successHttpResponse()); + S3AsyncClient crossRegionClient = + clientBuilder().endpointOverride(null).region(OVERRIDE_CONFIGURED_REGION).serviceConfiguration(c -> c.crossRegionAccessEnabled(true)).build(); + crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join(); + assertThat(captureInterceptor.endpointProvider).isInstanceOf(BucketEndpointProvider.class); + + List requests = mockAsyncHttpClient.getRequests(); + assertThat(requests).hasSize(3); + + assertThat(requests.stream().map(req -> req.host().substring(10,req.host().length() - 14 )).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(OVERRIDE_CONFIGURED_REGION.id(), + OVERRIDE_CONFIGURED_REGION.id(), + CROSS_REGION.id())); + + assertThat(requests.stream().map(req -> req.method()).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(SdkHttpMethod.GET, + SdkHttpMethod.HEAD, + SdkHttpMethod.GET)); + + // Resetting the mock client to capture the new API request for second S3 Call. + mockAsyncHttpClient.reset(); + mockAsyncHttpClient.stubResponses(successHttpResponse(), successHttpResponse()); + crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join(); + List postCacheRequests = mockAsyncHttpClient.getRequests(); + + assertThat(postCacheRequests.stream() + .map(req -> req.host().substring(10,req.host().length() - 14 )) + .collect(Collectors.toList())) + .isEqualTo(Arrays.asList(CROSS_REGION.id())); + assertThat(postCacheRequests.stream().map(req -> req.method()).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(SdkHttpMethod.GET)); + assertThat(captureInterceptor.endpointProvider).isInstanceOf(BucketEndpointProvider.class); + } + @Test + void crossRegionClient_CallsHeadObjectErrors_shouldTerminateTheAPI() { + mockAsyncHttpClient.stubResponses(customHttpResponse(301, null ), + customHttpResponse(400, null ), + successHttpResponse(), successHttpResponse()); + S3AsyncClient crossRegionClient = + clientBuilder().endpointOverride(null).region(OVERRIDE_CONFIGURED_REGION) + .serviceConfiguration(c -> c.crossRegionAccessEnabled(true)).build(); + + assertThatExceptionOfType(CompletionException.class) + .isThrownBy(() -> crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join()) + .withMessageContaining("software.amazon.awssdk.services.s3.model.S3Exception: null (Service: S3, Status Code: 400, Request ID: null)"); + + List requests = mockAsyncHttpClient.getRequests(); + assertThat(requests).hasSize(2); + + assertThat(requests.stream().map(req -> req.host().substring(10,req.host().length() - 14 )).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(OVERRIDE_CONFIGURED_REGION.toString(), + OVERRIDE_CONFIGURED_REGION.toString())); + + assertThat(requests.stream().map(req -> req.method()).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(SdkHttpMethod.GET, + SdkHttpMethod.HEAD)); + } + + @Test + void crossRegionClient_CallsHeadObjectWithNoRegion_shouldTerminateHeadBucketAPI() { + mockAsyncHttpClient.stubResponses(customHttpResponse(301, null ), + customHttpResponse(301, null ), + successHttpResponse(), successHttpResponse()); + S3AsyncClient crossRegionClient = + clientBuilder().endpointOverride(null).region(OVERRIDE_CONFIGURED_REGION) + .serviceConfiguration(c -> c.crossRegionAccessEnabled(true)).build(); + + assertThatExceptionOfType(CompletionException.class) + .isThrownBy(() -> crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join()) + .withMessageContaining("software.amazon.awssdk.services.s3.model.S3Exception: null (Service: S3, Status Code: 301, Request ID: null)") + .withCauseInstanceOf(S3Exception.class).withRootCauseExactlyInstanceOf(S3Exception.class); + + List requests = mockAsyncHttpClient.getRequests(); + assertThat(requests).hasSize(2); + + assertThat(requests.stream().map(req -> req.host().substring(10,req.host().length() - 14 )).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(OVERRIDE_CONFIGURED_REGION.toString(), + OVERRIDE_CONFIGURED_REGION.toString())); + + assertThat(requests.stream().map(req -> req.method()).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(SdkHttpMethod.GET, + SdkHttpMethod.HEAD)); + } + + + @Test + void crossRegionClient_cancelsTheThread_when_futureIsCancelled(){ + mockAsyncHttpClient.reset(); + mockAsyncHttpClient.stubResponses(customHttpResponse(301, null), + customHttpResponse(301, CROSS_REGION.id()), + successHttpResponse(), successHttpResponse()); + S3AsyncClient crossRegionClient = + clientBuilder().endpointOverride(null).region(OVERRIDE_CONFIGURED_REGION).serviceConfiguration(c -> c.crossRegionAccessEnabled(true)).build(); + CompletableFuture> completableFuture = crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY) + , AsyncResponseTransformer.toBytes()); + + completableFuture.cancel(true); + assertThat(completableFuture.isCancelled()).isTrue(); + } + + @Test + void crossRegionClient_when_redirectsAfterCaching() { + mockAsyncHttpClient.stubResponses(customHttpResponse(301, CROSS_REGION.id()), + successHttpResponse(), + successHttpResponse(), + customHttpResponse(301, CHANGED_CROSS_REGION.id()), + successHttpResponse()); + S3AsyncClient crossRegionClient = + clientBuilder().endpointOverride(null).region(OVERRIDE_CONFIGURED_REGION).serviceConfiguration(c -> c.crossRegionAccessEnabled(true)).build(); + crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join(); + crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join(); + crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join(); + + assertThat(captureInterceptor.endpointProvider).isInstanceOf(BucketEndpointProvider.class); + + List requests = mockAsyncHttpClient.getRequests(); + assertThat(requests).hasSize(5); + + assertThat(requests.stream().map(req -> req.host().substring(10,req.host().length() - 14 )).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(OVERRIDE_CONFIGURED_REGION.toString(), + CROSS_REGION.toString(), + CROSS_REGION.toString(), + CROSS_REGION.toString(), + CHANGED_CROSS_REGION.toString())); + + assertThat(requests.stream().map(req -> req.method()).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(SdkHttpMethod.GET,SdkHttpMethod.GET,SdkHttpMethod.GET,SdkHttpMethod.GET,SdkHttpMethod.GET)); + } + + @Test + void crossRegionClient_when_redirectsAfterCaching_withFallBackRedirectWithNoRegion() { + mockAsyncHttpClient.stubResponses(customHttpResponse(301, null ), + customHttpResponse(301, CROSS_REGION.id()), + successHttpResponse(), + successHttpResponse(), + customHttpResponse(301, null), + customHttpResponse(301, CHANGED_CROSS_REGION.id()), + successHttpResponse()); + S3AsyncClient crossRegionClient = + clientBuilder().endpointOverride(null).region(OVERRIDE_CONFIGURED_REGION).serviceConfiguration(c -> c.crossRegionAccessEnabled(true)).build(); + crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join(); + crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join(); + crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join(); + + assertThat(captureInterceptor.endpointProvider).isInstanceOf(BucketEndpointProvider.class); + + List requests = mockAsyncHttpClient.getRequests(); + assertThat(requests).hasSize(7); + + assertThat(requests.stream().map(req -> req.host().substring(10,req.host().length() - 14 )).collect(Collectors.toList())) + .isEqualTo(Arrays.asList( + OVERRIDE_CONFIGURED_REGION.toString(), OVERRIDE_CONFIGURED_REGION.toString(), CROSS_REGION.toString(), + CROSS_REGION.toString(), + CROSS_REGION.toString(), OVERRIDE_CONFIGURED_REGION.toString(), + CHANGED_CROSS_REGION.toString())); + + assertThat(requests.stream().map(req -> req.method()).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(SdkHttpMethod.GET, SdkHttpMethod.HEAD, SdkHttpMethod.GET, + SdkHttpMethod.GET, + SdkHttpMethod.GET, SdkHttpMethod.HEAD, SdkHttpMethod.GET)); + } + + @Test void standardOp_crossRegionClient_containUserAgent() { + mockAsyncHttpClient.stubResponses(successHttpResponse()); S3AsyncClient crossRegionClient = clientBuilder().serviceConfiguration(c -> c.crossRegionAccessEnabled(true)).build(); crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join(); assertThat(mockAsyncHttpClient.getLastRequest().firstMatchingHeader("User-Agent").get()).contains("hll/cross-region"); @@ -107,6 +324,7 @@ void standardOp_crossRegionClient_containUserAgent() { @Test void standardOp_simpleClient_doesNotContainCrossRegionUserAgent() { + mockAsyncHttpClient.stubResponses(successHttpResponse()); S3AsyncClient crossRegionClient = clientBuilder().serviceConfiguration(c -> c.crossRegionAccessEnabled(false)).build(); crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join(); assertThat(mockAsyncHttpClient.getLastRequest().firstMatchingHeader("User-Agent").get()).doesNotContain("hll/cross-region"); @@ -127,4 +345,25 @@ public void beforeMarshalling(Context.BeforeMarshalling context, ExecutionAttrib endpointProvider = executionAttributes.getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); } } + + + public static HttpExecuteResponse successHttpResponse() { + return HttpExecuteResponse.builder() + .response(SdkHttpResponse.builder() + .statusCode(200) + .build()) + .responseBody(AbortableInputStream.create(new StringInputStream(RESPONSE))) + .build(); + } + + public static HttpExecuteResponse customHttpResponse(int statusCode, String bucket_region) { + SdkHttpFullResponse.Builder httpResponseBuilder = SdkHttpResponse.builder(); + if (StringUtils.isNotBlank(bucket_region)) { + httpResponseBuilder.appendHeader(X_AMZ_BUCKET_REGION, bucket_region); + } + return HttpExecuteResponse.builder() + .response(httpResponseBuilder.statusCode(statusCode).build()) + .responseBody(AbortableInputStream.create(new StringInputStream(RESPONSE))) + .build(); + } } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionSyncClientRedirectTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionSyncClientRedirectTest.java new file mode 100644 index 000000000000..d295f098354c --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionSyncClientRedirectTest.java @@ -0,0 +1,128 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.crossregion; + +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import org.junit.jupiter.api.BeforeEach; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.HeadBucketRequest; +import software.amazon.awssdk.services.s3.model.HeadBucketResponse; +import software.amazon.awssdk.services.s3.model.ListBucketsRequest; +import software.amazon.awssdk.services.s3.model.ListBucketsResponse; +import software.amazon.awssdk.services.s3.model.ListObjectsRequest; +import software.amazon.awssdk.services.s3.model.ListObjectsResponse; +import software.amazon.awssdk.services.s3.model.S3Exception; + +public class S3CrossRegionSyncClientRedirectTest extends S3DecoratorRedirectTestBase { + + private static S3Client mockDelegateClient; + private S3Client decoratedS3Client; + + @BeforeEach + public void setup() { + mockDelegateClient = Mockito.mock(S3Client.class); + decoratedS3Client = new S3CrossRegionSyncClient(mockDelegateClient); + } + + @Override + protected void verifyNoBucketCall() { + assertThatExceptionOfType(S3Exception.class) + .isThrownBy( + () -> noBucketCallToService()) + .withMessage("Redirect (Service: S3, Status Code: 301, Request ID: 1, " + + "Extended Request ID: A1)"); + } + + @Override + protected void verifyNoBucketApiCall(int times, ArgumentCaptor requestArgumentCaptor) { + verify(mockDelegateClient, times(times)).listBuckets(requestArgumentCaptor.capture()); + } + + @Override + protected ListBucketsResponse noBucketCallToService() { + return decoratedS3Client.listBuckets(ListBucketsRequest.builder().build()); + } + + @Override + protected void stubApiWithNoBucketField() { + when(mockDelegateClient.listBuckets(any(ListBucketsRequest.class))) + .thenThrow(redirectException(301, CROSS_REGION.id(), null, "Redirect")) + .thenReturn(ListBucketsResponse.builder().build()); + } + + @Override + protected void stubHeadBucketRedirect() { + when(mockDelegateClient.headBucket(any(HeadBucketRequest.class))) + .thenThrow(redirectException(301, CROSS_REGION.id(), null, null)) + .thenReturn(HeadBucketResponse.builder().build()); + } + + @Override + protected void stubRedirectWithNoRegionAndThenSuccess() { + when(mockDelegateClient.listObjects(any(ListObjectsRequest.class))) + .thenThrow(redirectException(301, null, null, null)) + .thenReturn(ListObjectsResponse.builder().contents(S3_OBJECTS).build()); + } + + @Override + protected void stubRedirectThenError() { + when(mockDelegateClient.listObjects(any(ListObjectsRequest.class))) + .thenThrow(redirectException(301, CROSS_REGION.id(), null, null)) + .thenThrow(redirectException(400, null, "InvalidArgument", "Invalid id")); + } + + @Override + protected void stubRedirectSuccessSuccess() { + when(mockDelegateClient.listObjects(any(ListObjectsRequest.class))) + .thenThrow(redirectException(301, CROSS_REGION.id(), null, null)) + .thenReturn(ListObjectsResponse.builder().contents(S3_OBJECTS).build()) + .thenReturn(ListObjectsResponse.builder().contents(S3_OBJECTS).build()); + } + + @Override + protected ListObjectsResponse apiCallToService() { + return decoratedS3Client.listObjects(i -> i.bucket(CROSS_REGION_BUCKET)); + } + + @Override + protected void verifyTheApiServiceCall(int times, ArgumentCaptor requestArgumentCaptor) { + verify(mockDelegateClient, times(times)).listObjects(requestArgumentCaptor.capture()); + } + + @Override + protected void verifyHeadBucketServiceCall(int times) { + verify(mockDelegateClient, times(times)).headBucket(any(HeadBucketRequest.class)); + } + + @Override + protected void stubServiceClientConfiguration() { + when(mockDelegateClient.serviceClientConfiguration()).thenReturn(CONFIGURED_ENDPOINT_PROVIDER); + } + + @Override + protected void stubClientAPICallWithFirstRedirectThenSuccessWithRegionInErrorResponse() { + when(mockDelegateClient.listObjects(any(ListObjectsRequest.class))) + .thenThrow(redirectException(301, CROSS_REGION.id(), null, null)) + .thenReturn(ListObjectsResponse.builder().contents(S3_OBJECTS).build()); + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionSyncClientTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionSyncClientTest.java index 6ec901f788e9..c163d89c3b24 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionSyncClientTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionSyncClientTest.java @@ -16,60 +16,102 @@ package software.amazon.awssdk.services.s3.internal.crossregion; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.Mockito.verify; +import static software.amazon.awssdk.services.s3.internal.crossregion.S3CrossRegionAsyncClientTest.customHttpResponse; +import static software.amazon.awssdk.services.s3.internal.crossregion.S3CrossRegionAsyncClientTest.successHttpResponse; +import static software.amazon.awssdk.services.s3.internal.crossregion.S3DecoratorRedirectTestBase.CHANGED_CROSS_REGION; +import static software.amazon.awssdk.services.s3.internal.crossregion.S3DecoratorRedirectTestBase.CROSS_REGION; +import static software.amazon.awssdk.services.s3.internal.crossregion.S3DecoratorRedirectTestBase.OVERRIDE_CONFIGURED_REGION; -import java.net.URI; +import java.util.Arrays; +import java.util.List; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.endpoints.EndpointProvider; -import software.amazon.awssdk.http.AbortableInputStream; -import software.amazon.awssdk.http.HttpExecuteResponse; -import software.amazon.awssdk.http.SdkHttpResponse; -import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.http.SdkHttpRequest; +import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.S3ClientBuilder; +import software.amazon.awssdk.services.s3.endpoints.internal.DefaultS3EndpointProvider; +import software.amazon.awssdk.services.s3.internal.crossregion.endpointprovider.BucketEndpointProvider; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.ListObjectsV2Response; +import software.amazon.awssdk.services.s3.model.S3Exception; import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable; import software.amazon.awssdk.testutils.service.http.MockSyncHttpClient; -import software.amazon.awssdk.utils.StringInputStream; class S3CrossRegionSyncClientTest { - private static final String RESPONSE = "response"; private static final String BUCKET = "bucket"; private static final String KEY = "key"; private static final String TOKEN = "token"; - private final MockSyncHttpClient mockSyncHttpClient = new MockSyncHttpClient(); + private MockSyncHttpClient mockSyncHttpClient ; private CaptureInterceptor captureInterceptor; private S3Client defaultS3Client; @BeforeEach void before() { - mockSyncHttpClient.stubNextResponse( - HttpExecuteResponse.builder() - .response(SdkHttpResponse.builder().statusCode(200).build()) - .responseBody(AbortableInputStream.create(new StringInputStream(RESPONSE))) - .build()); - + mockSyncHttpClient = new MockSyncHttpClient(); captureInterceptor = new CaptureInterceptor(); defaultS3Client = clientBuilder().build(); } - @Test - void standardOp_crossRegionClient_noOverrideConfig_SuccessfullyIntercepts() { + + private static Stream stubResponses() { + Consumer redirectStubConsumer = mockSyncHttpClient -> + mockSyncHttpClient.stubResponses(customHttpResponse(301, CROSS_REGION.id()), successHttpResponse()); + + Consumer successStubConsumer = mockSyncHttpClient -> + mockSyncHttpClient.stubResponses(successHttpResponse(), successHttpResponse()); + + return Stream.of( + Arguments.of(redirectStubConsumer, BucketEndpointProvider.class), + Arguments.of(successStubConsumer, DefaultS3EndpointProvider.class) + ); + } + + private static Stream stubOverriddenEndpointProviderResponses() { + Consumer redirectStubConsumer = mockSyncHttpClient -> + mockSyncHttpClient.stubResponses(customHttpResponse(301, CROSS_REGION.id()), successHttpResponse()); + + Consumer successStubConsumer = mockSyncHttpClient -> + mockSyncHttpClient.stubResponses(successHttpResponse(), successHttpResponse()); + + return Stream.of( + Arguments.of(redirectStubConsumer, BucketEndpointProvider.class, CROSS_REGION), + Arguments.of(successStubConsumer, TestEndpointProvider.class, OVERRIDE_CONFIGURED_REGION) + ); + } + + @ParameterizedTest + @MethodSource("stubResponses") + void standardOp_crossRegionClient_noOverrideConfig_SuccessfullyIntercepts(Consumer stubConsumer, + Class endpointProviderType) { + stubConsumer.accept(mockSyncHttpClient); S3Client crossRegionClient = new S3CrossRegionSyncClient(defaultS3Client); crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY)); - assertThat(captureInterceptor.endpointProvider).isInstanceOf(S3CrossRegionSyncClient.BucketEndpointProvider.class); + assertThat(captureInterceptor.endpointProvider).isInstanceOf(endpointProviderType); } - @Test - void standardOp_crossRegionClient_existingOverrideConfig_SuccessfullyIntercepts() { + @ParameterizedTest + @MethodSource("stubResponses") + void standardOp_crossRegionClient_existingOverrideConfig_SuccessfullyIntercepts(Consumer stubConsumer, + Class endpointProviderType) { + stubConsumer.accept(mockSyncHttpClient); S3Client crossRegionClient = new S3CrossRegionSyncClient(defaultS3Client); GetObjectRequest request = GetObjectRequest.builder() .bucket(BUCKET) @@ -77,28 +119,228 @@ void standardOp_crossRegionClient_existingOverrideConfig_SuccessfullyIntercepts( .overrideConfiguration(o -> o.putHeader("someheader", "somevalue")) .build(); crossRegionClient.getObject(request); - assertThat(captureInterceptor.endpointProvider).isInstanceOf(S3CrossRegionSyncClient.BucketEndpointProvider.class); + assertThat(captureInterceptor.endpointProvider).isInstanceOf(endpointProviderType); assertThat(mockSyncHttpClient.getLastRequest().headers().get("someheader")).isNotNull(); } - @Test - void paginatedOp_crossRegionClient_DoesNotIntercept() { + @ParameterizedTest + @MethodSource("stubResponses") + void paginatedOp_crossRegionClient_DoesNotIntercept(Consumer stubConsumer, + Class endpointProviderType) { + stubConsumer.accept(mockSyncHttpClient); S3Client crossRegionClient = new S3CrossRegionSyncClient(defaultS3Client); ListObjectsV2Iterable iterable = crossRegionClient.listObjectsV2Paginator(r -> r.bucket(BUCKET).continuationToken(TOKEN).build()); iterable.forEach(ListObjectsV2Response::contents); - assertThat(captureInterceptor.endpointProvider).isInstanceOf(S3CrossRegionSyncClient.BucketEndpointProvider.class); + assertThat(captureInterceptor.endpointProvider).isInstanceOf(endpointProviderType); } - @Test - void crossRegionClient_createdWithWrapping_SuccessfullyIntercepts() { + @ParameterizedTest + @MethodSource("stubResponses") + void crossRegionClient_createdWithWrapping_SuccessfullyIntercepts(Consumer stubConsumer, + Class endpointProviderType) { + stubConsumer.accept(mockSyncHttpClient); S3Client crossRegionClient = clientBuilder().serviceConfiguration(c -> c.crossRegionAccessEnabled(true)).build(); crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY)); - assertThat(captureInterceptor.endpointProvider).isInstanceOf(S3CrossRegionSyncClient.BucketEndpointProvider.class); + assertThat(captureInterceptor.endpointProvider).isInstanceOf(endpointProviderType); + } + + @ParameterizedTest + @MethodSource("stubOverriddenEndpointProviderResponses") + void standardOp_crossRegionClient_takesCustomEndpointProviderInRequest(Consumer stubConsumer, + Class endpointProviderType, + Region region) { + stubConsumer.accept(mockSyncHttpClient); + S3Client crossRegionClient = clientBuilder().serviceConfiguration(c -> c.crossRegionAccessEnabled(true)) + .endpointProvider(new TestEndpointProvider()) + .region(OVERRIDE_CONFIGURED_REGION) + .build(); + GetObjectRequest request = GetObjectRequest.builder() + .bucket(BUCKET) + .key(KEY) + .overrideConfiguration(o -> o.putHeader("someheader", "somevalue") + .endpointProvider(new TestEndpointProvider())) + .build(); + crossRegionClient.getObject(request); + assertThat(captureInterceptor.endpointProvider).isInstanceOf(endpointProviderType); + assertThat(mockSyncHttpClient.getLastRequest().headers().get("someheader")).isNotNull(); + assertThat(mockSyncHttpClient.getLastRequest().encodedPath()).contains("test_prefix_"); + assertThat(mockSyncHttpClient.getLastRequest().host()).contains(region.id()); } + @ParameterizedTest + @MethodSource("stubOverriddenEndpointProviderResponses") + void standardOp_crossRegionClient_takesCustomEndpointProviderInClient(Consumer stubConsumer, + Class endpointProviderType, + Region region) { + stubConsumer.accept(mockSyncHttpClient); + S3Client crossRegionClient = clientBuilder().serviceConfiguration(c -> c.crossRegionAccessEnabled(true)) + .endpointProvider(new TestEndpointProvider()) + .region(OVERRIDE_CONFIGURED_REGION) + .build(); + GetObjectRequest request = GetObjectRequest.builder() + .bucket(BUCKET) + .key(KEY) + .overrideConfiguration(o -> o.putHeader("someheader", "somevalue")) + .build(); + crossRegionClient.getObject(request); + assertThat(captureInterceptor.endpointProvider).isInstanceOf(endpointProviderType); + assertThat(mockSyncHttpClient.getLastRequest().headers().get("someheader")).isNotNull(); + assertThat(mockSyncHttpClient.getLastRequest().encodedPath()).contains("test_prefix_"); + assertThat(mockSyncHttpClient.getLastRequest().host()).contains(region.id()); + } + + @Test + void crossRegionClient_CallsHeadObject_when_regionNameNotPresentInFallBackCall() { + mockSyncHttpClient.stubResponses(customHttpResponse(301, null ), + customHttpResponse(301, CROSS_REGION.id() ), + successHttpResponse(), successHttpResponse()); + S3Client crossRegionClient = + clientBuilder().endpointOverride(null).region(OVERRIDE_CONFIGURED_REGION).serviceConfiguration(c -> c.crossRegionAccessEnabled(true)).build(); + crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY)); + assertThat(captureInterceptor.endpointProvider).isInstanceOf(BucketEndpointProvider.class); + + List requests = mockSyncHttpClient.getRequests(); + assertThat(requests).hasSize(3); + + assertThat(requests.stream().map(req -> req.host().substring(10,req.host().length() - 14 )).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(OVERRIDE_CONFIGURED_REGION.toString(), + OVERRIDE_CONFIGURED_REGION.toString(), + CROSS_REGION.id())); + + assertThat(requests.stream().map(req -> req.method()).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(SdkHttpMethod.GET, + SdkHttpMethod.HEAD, + SdkHttpMethod.GET)); + + // Resetting the mock client to capture the new API request for second S3 Call. + mockSyncHttpClient.reset(); + mockSyncHttpClient.stubResponses(successHttpResponse(), successHttpResponse()); + crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY)); + List postCacheRequests = mockSyncHttpClient.getRequests(); + + assertThat(postCacheRequests.stream() + .map(req -> req.host().substring(10,req.host().length() - 14 )) + .collect(Collectors.toList())) + .isEqualTo(Arrays.asList(CROSS_REGION.id())); + assertThat(postCacheRequests.stream().map(req -> req.method()).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(SdkHttpMethod.GET)); + assertThat(captureInterceptor.endpointProvider).isInstanceOf(BucketEndpointProvider.class); + + } + + @Test + void crossRegionClient_when_redirectsAfterCaching() { + mockSyncHttpClient.stubResponses(customHttpResponse(301, CROSS_REGION.id()), + successHttpResponse(), + successHttpResponse(), + customHttpResponse(301, CHANGED_CROSS_REGION.id()), + successHttpResponse()); + S3Client crossRegionClient = + clientBuilder().endpointOverride(null).region(OVERRIDE_CONFIGURED_REGION).serviceConfiguration(c -> c.crossRegionAccessEnabled(true)).build(); + crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY)); + crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY)); + crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY)); + assertThat(captureInterceptor.endpointProvider).isInstanceOf(BucketEndpointProvider.class); + + List requests = mockSyncHttpClient.getRequests(); + assertThat(requests).hasSize(5); + + assertThat(requests.stream().map(req -> req.host().substring(10,req.host().length() - 14 )).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(OVERRIDE_CONFIGURED_REGION.toString(), + CROSS_REGION.toString(), + CROSS_REGION.toString(), + CROSS_REGION.toString(), + CHANGED_CROSS_REGION.toString())); + + assertThat(requests.stream().map(req -> req.method()).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(SdkHttpMethod.GET,SdkHttpMethod.GET,SdkHttpMethod.GET,SdkHttpMethod.GET,SdkHttpMethod.GET)); + } + + @Test + void crossRegionClient_when_redirectsAfterCaching_withFallBackRedirectWithNoRegion() { + mockSyncHttpClient.stubResponses(customHttpResponse(301, null ), + customHttpResponse(301, CROSS_REGION.id() ), + successHttpResponse(), + successHttpResponse(), + customHttpResponse(301, null), + customHttpResponse(301, CHANGED_CROSS_REGION.id()), + successHttpResponse()); + S3Client crossRegionClient = + clientBuilder().endpointOverride(null).region(OVERRIDE_CONFIGURED_REGION).serviceConfiguration(c -> c.crossRegionAccessEnabled(true)).build(); + crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY)); + crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY)); + crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY)); + assertThat(captureInterceptor.endpointProvider).isInstanceOf(BucketEndpointProvider.class); + + List requests = mockSyncHttpClient.getRequests(); + assertThat(requests).hasSize(7); + + assertThat(requests.stream().map(req -> req.host().substring(10,req.host().length() - 14 )).collect(Collectors.toList())) + .isEqualTo(Arrays.asList( + OVERRIDE_CONFIGURED_REGION.toString(), OVERRIDE_CONFIGURED_REGION.toString(), CROSS_REGION.toString(), + CROSS_REGION.toString(), + CROSS_REGION.toString(), OVERRIDE_CONFIGURED_REGION.toString(), + CHANGED_CROSS_REGION.toString())); + + assertThat(requests.stream().map(req -> req.method()).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(SdkHttpMethod.GET, SdkHttpMethod.HEAD, SdkHttpMethod.GET, + SdkHttpMethod.GET, + SdkHttpMethod.GET, SdkHttpMethod.HEAD, SdkHttpMethod.GET)); + } + + @Test + void crossRegionClient_CallsHeadObjectErrors_shouldTerminateTheAPI() { + mockSyncHttpClient.stubResponses(customHttpResponse(301, null ), + customHttpResponse(400, null ), + successHttpResponse(), successHttpResponse()); + S3Client crossRegionClient = + clientBuilder().endpointOverride(null).region(OVERRIDE_CONFIGURED_REGION).serviceConfiguration(c -> c.crossRegionAccessEnabled(true)).build(); + + assertThatExceptionOfType(S3Exception.class) + .isThrownBy(() -> crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY))) + .withMessageContaining("Status Code: 400"); + + List requests = mockSyncHttpClient.getRequests(); + assertThat(requests).hasSize(2); + + assertThat(requests.stream().map(req -> req.host().substring(10,req.host().length() - 14 )).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(OVERRIDE_CONFIGURED_REGION.toString(), + OVERRIDE_CONFIGURED_REGION.toString())); + + assertThat(requests.stream().map(req -> req.method()).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(SdkHttpMethod.GET, + SdkHttpMethod.HEAD)); + } + + @Test + void crossRegionClient_CallsHeadObjectWithNoRegion_shouldTerminateHeadBucketAPI() { + mockSyncHttpClient.stubResponses(customHttpResponse(301, null ), + customHttpResponse(301, null ), + successHttpResponse(), successHttpResponse()); + S3Client crossRegionClient = + clientBuilder().endpointOverride(null).region(OVERRIDE_CONFIGURED_REGION).serviceConfiguration(c -> c.crossRegionAccessEnabled(true)).build(); + + assertThatExceptionOfType(S3Exception.class) + .isThrownBy(() -> crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY))) + .withMessageContaining("Status Code: 301"); + + List requests = mockSyncHttpClient.getRequests(); + assertThat(requests).hasSize(2); + + assertThat(requests.stream().map(req -> req.host().substring(10,req.host().length() - 14 )).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(OVERRIDE_CONFIGURED_REGION.toString(), + OVERRIDE_CONFIGURED_REGION.toString())); + + assertThat(requests.stream().map(req -> req.method()).collect(Collectors.toList())) + .isEqualTo(Arrays.asList(SdkHttpMethod.GET, + SdkHttpMethod.HEAD)); + } + + @Test void standardOp_crossRegionClient_containUserAgent() { + mockSyncHttpClient.stubResponses(successHttpResponse()); S3Client crossRegionClient = clientBuilder().serviceConfiguration(c -> c.crossRegionAccessEnabled(true)).build(); crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY)); assertThat(mockSyncHttpClient.getLastRequest().firstMatchingHeader("User-Agent").get()).contains("hll/cross-region"); @@ -106,6 +348,7 @@ void standardOp_crossRegionClient_containUserAgent() { @Test void standardOp_simpleClient_doesNotContainCrossRegionUserAgent() { + mockSyncHttpClient.stubResponses(successHttpResponse()); S3Client crossRegionClient = clientBuilder().serviceConfiguration(c -> c.crossRegionAccessEnabled(false)).build(); crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY)); assertThat(mockSyncHttpClient.getLastRequest().firstMatchingHeader("User-Agent").get()) @@ -115,7 +358,6 @@ void standardOp_simpleClient_doesNotContainCrossRegionUserAgent() { private S3ClientBuilder clientBuilder() { return S3Client.builder() .httpClient(mockSyncHttpClient) - .endpointOverride(URI.create("http://localhost")) .overrideConfiguration(c -> c.addExecutionInterceptor(captureInterceptor)); } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3DecoratorRedirectTestBase.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3DecoratorRedirectTestBase.java new file mode 100644 index 000000000000..10d0709b377f --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3DecoratorRedirectTestBase.java @@ -0,0 +1,214 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.crossregion; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +import java.util.Collections; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.awscore.exception.AwsErrorDetails; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.endpoints.EndpointProvider; +import software.amazon.awssdk.http.SdkHttpFullResponse; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3ServiceClientConfiguration; +import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; +import software.amazon.awssdk.services.s3.internal.crossregion.endpointprovider.BucketEndpointProvider; +import software.amazon.awssdk.services.s3.model.ListBucketsRequest; +import software.amazon.awssdk.services.s3.model.ListBucketsResponse; +import software.amazon.awssdk.services.s3.model.ListObjectsRequest; +import software.amazon.awssdk.services.s3.model.ListObjectsResponse; +import software.amazon.awssdk.services.s3.model.S3Exception; +import software.amazon.awssdk.services.s3.model.S3Object; + +public abstract class S3DecoratorRedirectTestBase { + + public static final String X_AMZ_BUCKET_REGION = "x-amz-bucket-region"; + protected static final String CROSS_REGION_BUCKET = "anyBucket"; + protected static final Region CROSS_REGION = Region.EU_CENTRAL_1; + protected static final Region CHANGED_CROSS_REGION = Region.US_WEST_1; + + public static final Region OVERRIDE_CONFIGURED_REGION = Region.US_WEST_2; + + protected static final List S3_OBJECTS = Collections.singletonList(S3Object.builder().key("keyObject").build()); + + protected static final S3ServiceClientConfiguration CONFIGURED_ENDPOINT_PROVIDER = + S3ServiceClientConfiguration.builder().endpointProvider(S3EndpointProvider.defaultProvider()).build(); + + @Test + void decoratorAttemptsToRetryWithRegionNameInErrorResponse() throws Throwable { + stubServiceClientConfiguration(); + stubClientAPICallWithFirstRedirectThenSuccessWithRegionInErrorResponse(); + // Assert retrieved listObject + ListObjectsResponse listObjectsResponse = apiCallToService(); + assertThat(listObjectsResponse.contents()).isEqualTo(S3_OBJECTS); + + ArgumentCaptor requestArgumentCaptor = ArgumentCaptor.forClass(ListObjectsRequest.class); + verifyTheApiServiceCall(2, requestArgumentCaptor); + + assertThat(requestArgumentCaptor.getAllValues().get(0).overrideConfiguration().get().endpointProvider()).isNotPresent(); + verifyTheEndPointProviderOverridden(1, requestArgumentCaptor, CROSS_REGION.id()); + + verifyHeadBucketServiceCall(0); + } + + @Test + void decoratorUsesCache_when_CrossRegionAlreadyPresent() throws Throwable { + stubServiceClientConfiguration(); + stubRedirectSuccessSuccess(); + + ListObjectsResponse listObjectsResponse = apiCallToService(); + assertThat(listObjectsResponse.contents()).isEqualTo(S3_OBJECTS); + + ListObjectsResponse listObjectsResponseSecondCall = apiCallToService(); + assertThat(listObjectsResponseSecondCall.contents()).isEqualTo(S3_OBJECTS); + + ArgumentCaptor requestArgumentCaptor = ArgumentCaptor.forClass(ListObjectsRequest.class); + verifyTheApiServiceCall(3, requestArgumentCaptor); + + assertThat(requestArgumentCaptor.getAllValues().get(0).overrideConfiguration().get().endpointProvider()).isNotPresent(); + verifyTheEndPointProviderOverridden(1, requestArgumentCaptor, CROSS_REGION.id()); + verifyTheEndPointProviderOverridden(2, requestArgumentCaptor, CROSS_REGION.id()); + verifyHeadBucketServiceCall(0); + } + + /** + * Call is redirected to actual end point + * The redirected call fails because of incorrect parameters passed + * This exception should be reported correctly + */ + @Test + void apiCallFailure_when_CallFailsAfterRedirection() { + stubServiceClientConfiguration(); + stubRedirectThenError(); + assertThatExceptionOfType(S3Exception.class) + .isThrownBy(() -> apiCallToService()) + .withMessageContaining("Invalid id (Service: S3, Status Code: 400, Request ID: 1, Extended Request ID: A1)"); + verifyHeadBucketServiceCall(0); + } + + @Test + void headBucketCalled_when_RedirectDoesNotHasRegionName() throws Throwable { + stubServiceClientConfiguration(); + stubRedirectWithNoRegionAndThenSuccess(); + stubHeadBucketRedirect(); + ListObjectsResponse listObjectsResponse = apiCallToService(); + assertThat(listObjectsResponse.contents()).isEqualTo(S3_OBJECTS); + + ArgumentCaptor requestArgumentCaptor = ArgumentCaptor.forClass(ListObjectsRequest.class); + verifyTheApiServiceCall(2, requestArgumentCaptor); + + assertThat(requestArgumentCaptor.getAllValues().get(0).overrideConfiguration().get().endpointProvider()).isNotPresent(); + verifyTheEndPointProviderOverridden(1, requestArgumentCaptor, CROSS_REGION.id()); + verifyHeadBucketServiceCall(1); + } + + @Test + void headBucketCalledAndCached__when_RedirectDoesNotHasRegionName() throws Throwable { + stubServiceClientConfiguration(); + stubRedirectWithNoRegionAndThenSuccess(); + stubHeadBucketRedirect(); + ListObjectsResponse listObjectsResponse = apiCallToService(); + assertThat(listObjectsResponse.contents()).isEqualTo(S3_OBJECTS); + + ArgumentCaptor preCacheCaptor = ArgumentCaptor.forClass(ListObjectsRequest.class); + verifyTheApiServiceCall(2, preCacheCaptor); + // We need to get the BucketEndpointProvider in order to update the cache + verifyTheEndPointProviderOverridden(1, preCacheCaptor, CROSS_REGION.id()); + listObjectsResponse = apiCallToService(); + assertThat(listObjectsResponse.contents()).isEqualTo(S3_OBJECTS); + // We need to captor again so that we get the args used in second API Call + ArgumentCaptor overAllPostCacheCaptor = ArgumentCaptor.forClass(ListObjectsRequest.class); + verifyTheApiServiceCall(3, overAllPostCacheCaptor); + assertThat(overAllPostCacheCaptor.getAllValues().get(0).overrideConfiguration().get().endpointProvider()).isNotPresent(); + verifyTheEndPointProviderOverridden(1, overAllPostCacheCaptor, CROSS_REGION.id()); + verifyTheEndPointProviderOverridden(2, overAllPostCacheCaptor, CROSS_REGION.id()); + verifyHeadBucketServiceCall(1); + } + + @Test + void requestsAreNotOverridden_when_NoBucketInRequest() throws Throwable { + stubServiceClientConfiguration(); + stubApiWithNoBucketField(); + stubHeadBucketRedirect(); + verifyNoBucketCall(); + ArgumentCaptor requestArgumentCaptor = ArgumentCaptor.forClass(ListBucketsRequest.class); + verifyHeadBucketServiceCall(0); + verifyNoBucketApiCall(1, requestArgumentCaptor); + assertThat(requestArgumentCaptor.getAllValues().get(0).overrideConfiguration().get().endpointProvider()).isNotPresent(); + verifyHeadBucketServiceCall(0); + } + + protected abstract void verifyNoBucketCall(); + + protected abstract void verifyNoBucketApiCall(int i, ArgumentCaptor requestArgumentCaptor); + + protected abstract ListBucketsResponse noBucketCallToService() throws Throwable; + + protected abstract void stubApiWithNoBucketField(); + + protected abstract void stubHeadBucketRedirect(); + + protected abstract void stubRedirectWithNoRegionAndThenSuccess(); + + protected abstract void stubRedirectThenError(); + + protected abstract void stubRedirectSuccessSuccess(); + + protected AwsServiceException redirectException(int statusCode, String region, String errorCode, String errorMessage) { + SdkHttpFullResponse.Builder sdkHttpFullResponseBuilder = SdkHttpFullResponse.builder(); + if (region != null) { + sdkHttpFullResponseBuilder.appendHeader(X_AMZ_BUCKET_REGION, region); + } + return S3Exception.builder() + .statusCode(statusCode) + .requestId("1") + .extendedRequestId("A1") + .awsErrorDetails(AwsErrorDetails.builder() + .errorMessage(errorMessage) + .sdkHttpResponse(sdkHttpFullResponseBuilder.build()) + .errorCode(errorCode) + .serviceName("S3") + .build()) + .build(); + } + + void verifyTheEndPointProviderOverridden(int attempt, + ArgumentCaptor requestArgumentCaptor, + String expectedRegion) throws Exception { + EndpointProvider overridenEndpointProvider = + requestArgumentCaptor.getAllValues().get(attempt).overrideConfiguration().get().endpointProvider().get(); + assertThat(overridenEndpointProvider).isInstanceOf(BucketEndpointProvider.class); + assertThat(((S3EndpointProvider) overridenEndpointProvider).resolveEndpoint(e -> e.region(Region.US_WEST_2) + .bucket(CROSS_REGION_BUCKET) + .build()) + .get().url().getHost()) + .isEqualTo("s3." + expectedRegion + ".amazonaws.com"); + } + + protected abstract ListObjectsResponse apiCallToService() throws Throwable; + + protected abstract void verifyTheApiServiceCall(int times, ArgumentCaptor requestArgumentCaptor); + + protected abstract void verifyHeadBucketServiceCall(int times); + + protected abstract void stubServiceClientConfiguration(); + + protected abstract void stubClientAPICallWithFirstRedirectThenSuccessWithRegionInErrorResponse(); +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/TestEndpointProvider.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/TestEndpointProvider.java new file mode 100644 index 000000000000..e7ab65e6fdde --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/TestEndpointProvider.java @@ -0,0 +1,31 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.crossregion; + + +import java.util.concurrent.CompletableFuture; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams; +import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; + +public class TestEndpointProvider implements S3EndpointProvider { + S3EndpointProvider s3EndpointProvider = S3EndpointProvider.defaultProvider(); + @Override + public CompletableFuture resolveEndpoint(S3EndpointParams endpointParams) { + return s3EndpointProvider.resolveEndpoint(endpointParams.copy(c -> c.bucket("test_prefix_"+endpointParams.bucket()))); + + } +} \ No newline at end of file