diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/AsyncBufferingSubscriber.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/AsyncBufferingSubscriber.java index f4394331563f..4245ff5591f3 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/AsyncBufferingSubscriber.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/AsyncBufferingSubscriber.java @@ -15,9 +15,7 @@ package software.amazon.awssdk.transfer.s3.internal; -import java.util.Optional; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import org.reactivestreams.Subscriber; @@ -25,8 +23,6 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.utils.Logger; import software.amazon.awssdk.utils.Validate; -import software.amazon.awssdk.utils.async.DemandIgnoringSubscription; -import software.amazon.awssdk.utils.async.StoringSubscriber; /** * An implementation of {@link Subscriber} that execute the provided function for every event and limits the number of concurrent @@ -41,12 +37,9 @@ public class AsyncBufferingSubscriber implements Subscriber { private final Function> consumer; private final int maxConcurrentExecutions; private final AtomicInteger numRequestsInFlight; - private final AtomicBoolean isDelivering = new AtomicBoolean(false); - private volatile boolean isStreamingDone; + private volatile boolean upstreamDone; private Subscription subscription; - private final StoringSubscriber storingSubscriber; - public AsyncBufferingSubscriber(Function> consumer, CompletableFuture returnFuture, int maxConcurrentExecutions) { @@ -54,7 +47,6 @@ public AsyncBufferingSubscriber(Function> consumer, this.consumer = consumer; this.maxConcurrentExecutions = maxConcurrentExecutions; this.numRequestsInFlight = new AtomicInteger(0); - this.storingSubscriber = new StoringSubscriber<>(Integer.MAX_VALUE); } @Override @@ -65,89 +57,41 @@ public void onSubscribe(Subscription subscription) { subscription.cancel(); return; } - storingSubscriber.onSubscribe(new DemandIgnoringSubscription(subscription)); this.subscription = subscription; subscription.request(maxConcurrentExecutions); } @Override public void onNext(T item) { - storingSubscriber.onNext(item); - flushBufferIfNeeded(); - } - - private void flushBufferIfNeeded() { - if (isDelivering.compareAndSet(false, true)) { - try { - Optional> next = storingSubscriber.peek(); - while (numRequestsInFlight.get() < maxConcurrentExecutions) { - if (!next.isPresent()) { - subscription.request(1); - break; - } - - switch (next.get().type()) { - case ON_COMPLETE: - handleCompleteEvent(); - break; - case ON_ERROR: - handleError(next.get().runtimeError()); - break; - case ON_NEXT: - handleOnNext(next.get().value()); - break; - default: - handleError(new IllegalStateException("Unknown stored type: " + next.get().type())); - break; - } - - next = storingSubscriber.peek(); - } - } finally { - isDelivering.set(false); - } - } - } - - private void handleOnNext(T item) { - storingSubscriber.poll(); - - int numberOfRequestInFlight = numRequestsInFlight.incrementAndGet(); - log.debug(() -> "Delivering next item, numRequestInFlight=" + numberOfRequestInFlight); - + numRequestsInFlight.incrementAndGet(); consumer.apply(item).whenComplete((r, t) -> { - numRequestsInFlight.decrementAndGet(); - if (!isStreamingDone) { + checkForCompletion(numRequestsInFlight.decrementAndGet()); + synchronized (this) { subscription.request(1); - } else { - flushBufferIfNeeded(); } }); } - private void handleCompleteEvent() { - if (numRequestsInFlight.get() == 0) { - returnFuture.complete(null); - storingSubscriber.poll(); - } - } - @Override public void onError(Throwable t) { - handleError(t); - storingSubscriber.onError(t); - } - - private void handleError(Throwable t) { + // Need to complete future exceptionally first to prevent + // accidental successful completion by a concurrent checkForCompletion. returnFuture.completeExceptionally(t); - storingSubscriber.poll(); + upstreamDone = true; } @Override public void onComplete() { - isStreamingDone = true; - storingSubscriber.onComplete(); - flushBufferIfNeeded(); + upstreamDone = true; + checkForCompletion(numRequestsInFlight.get()); + } + + private void checkForCompletion(int requestsInFlight) { + if (upstreamDone && requestsInFlight == 0) { + // This could get invoked multiple times, but it doesn't matter + // because future.complete is idempotent. + returnFuture.complete(null); + } } /** diff --git a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/TransferManagerLoggingTest.java b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/TransferManagerLoggingTest.java index a1998973c7ea..4f2be4d00063 100644 --- a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/TransferManagerLoggingTest.java +++ b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/TransferManagerLoggingTest.java @@ -17,42 +17,52 @@ import static org.assertj.core.api.Assertions.assertThat; +import java.util.HashSet; import java.util.List; +import java.util.Set; import org.apache.logging.log4j.Level; import org.apache.logging.log4j.core.LogEvent; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.testutils.LogCaptor; import software.amazon.awssdk.transfer.s3.S3TransferManager; -public class TransferManagerLoggingTest { +class TransferManagerLoggingTest { @Test - public void transferManager_withCrtClient_shouldNotLogWarnMessages(){ - LogCaptor logCaptor = LogCaptor.create(Level.WARN); - S3AsyncClient s3Crt = S3AsyncClient.crtCreate(); - S3TransferManager tm = S3TransferManager.builder().s3Client(s3Crt).build(); + void transferManager_withCrtClient_shouldNotLogWarnMessages() { - List events = logCaptor.loggedEvents(); - assertThat(events).isEmpty(); - logCaptor.clear(); - logCaptor.close(); + try (S3AsyncClient s3Crt = S3AsyncClient.crtBuilder() + .region(Region.US_WEST_2) + .credentialsProvider(() -> AwsBasicCredentials.create("foo", "bar")) + .build(); + LogCaptor logCaptor = LogCaptor.create(Level.WARN); + S3TransferManager tm = S3TransferManager.builder().s3Client(s3Crt).build()) { + List events = logCaptor.loggedEvents(); + assertThat(events).isEmpty(); + } } @Test - public void transferManager_withJavaClient_shouldLogWarnMessage(){ - LogCaptor logCaptor = LogCaptor.create(Level.WARN); - S3AsyncClient s3Java = S3AsyncClient.create(); - S3TransferManager tm = S3TransferManager.builder().s3Client(s3Java).build(); + void transferManager_withJavaClient_shouldLogWarnMessage() { - List events = logCaptor.loggedEvents(); - assertLogged(events, Level.WARN, "The provided DefaultS3AsyncClient is not an instance of S3CrtAsyncClient, and " - + "thus multipart upload/download feature is not enabled and resumable file upload is " - + "not supported. To benefit from maximum throughput, consider using " - + "S3AsyncClient.crtBuilder().build() instead."); - logCaptor.clear(); - logCaptor.close(); + + try (S3AsyncClient s3Crt = S3AsyncClient.builder() + .region(Region.US_WEST_2) + .credentialsProvider(() -> AwsBasicCredentials.create("foo", "bar")) + .build(); + LogCaptor logCaptor = LogCaptor.create(Level.WARN); + S3TransferManager tm = S3TransferManager.builder().s3Client(s3Crt).build()) { + List events = logCaptor.loggedEvents(); + assertLogged(events, Level.WARN, "The provided DefaultS3AsyncClient is not an instance of S3CrtAsyncClient, and " + + "thus multipart upload/download feature is not enabled and resumable file upload" + + " is " + + "not supported. To benefit from maximum throughput, consider using " + + "S3AsyncClient.crtBuilder().build() instead."); + } } private static void assertLogged(List events, org.apache.logging.log4j.Level level, String message) {