diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/InputStreamResponseTransformer.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/InputStreamResponseTransformer.java index 084a293f6344..434894a44c8c 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/InputStreamResponseTransformer.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/InputStreamResponseTransformer.java @@ -22,7 +22,7 @@ import software.amazon.awssdk.core.SdkResponse; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.SdkPublisher; -import software.amazon.awssdk.utils.async.InputStreamSubscriber; +import software.amazon.awssdk.http.async.AbortableInputStreamSubscriber; /** * A {@link AsyncResponseTransformer} that allows performing blocking reads on the response data. @@ -50,7 +50,7 @@ public void onResponse(ResponseT response) { @Override public void onStream(SdkPublisher publisher) { - InputStreamSubscriber inputStreamSubscriber = new InputStreamSubscriber(); + AbortableInputStreamSubscriber inputStreamSubscriber = AbortableInputStreamSubscriber.builder().build(); publisher.subscribe(inputStreamSubscriber); future.complete(new ResponseInputStream<>(response, inputStreamSubscriber)); } diff --git a/http-client-spi/pom.xml b/http-client-spi/pom.xml index 1c7bd33ed322..a34067f02554 100644 --- a/http-client-spi/pom.xml +++ b/http-client-spi/pom.xml @@ -90,6 +90,16 @@ byte-buddy test + + org.mockito + mockito-junit-jupiter + test + + + org.mockito + mockito-inline + test + diff --git a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/AbortableInputStreamSubscriber.java b/http-client-spi/src/main/java/software/amazon/awssdk/http/async/AbortableInputStreamSubscriber.java similarity index 56% rename from http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/AbortableInputStreamSubscriber.java rename to http-client-spi/src/main/java/software/amazon/awssdk/http/async/AbortableInputStreamSubscriber.java index c6c2ad6151b0..630869825700 100644 --- a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/AbortableInputStreamSubscriber.java +++ b/http-client-spi/src/main/java/software/amazon/awssdk/http/async/AbortableInputStreamSubscriber.java @@ -13,30 +13,43 @@ * permissions and limitations under the License. */ -package software.amazon.awssdk.http.crt.internal.response; +package software.amazon.awssdk.http.async; import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; -import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.annotations.SdkTestInternalApi; import software.amazon.awssdk.http.Abortable; +import software.amazon.awssdk.utils.FunctionalUtils; import software.amazon.awssdk.utils.async.InputStreamSubscriber; /** - * Wrapper of {@link InputStreamSubscriber} that also implements {@link Abortable} and closes the underlying connections when - * {@link #close()} or {@link #abort()} is invoked. + * Wrapper of {@link InputStreamSubscriber} that also implements {@link Abortable}. It will invoke {@link #close()} + * when {@link #abort()} is invoked. Upon closing, the underlying {@link InputStreamSubscriber} will be closed, and additional + * action can be added via {@link Builder#doAfterClose(Runnable)}. + * */ -@SdkInternalApi +@SdkProtectedApi public final class AbortableInputStreamSubscriber extends InputStream implements Subscriber, Abortable { - private final InputStreamSubscriber delegate; - private final Runnable closeConnection; - public AbortableInputStreamSubscriber(Runnable onClose, InputStreamSubscriber inputStreamSubscriber) { - this.delegate = inputStreamSubscriber; - this.closeConnection = onClose; + private final Runnable doAfterClose; + + private AbortableInputStreamSubscriber(Builder builder) { + this(builder, new InputStreamSubscriber()); + } + + @SdkTestInternalApi + AbortableInputStreamSubscriber(Builder builder, InputStreamSubscriber delegate) { + this.delegate = delegate; + this.doAfterClose = builder.doAfterClose == null ? FunctionalUtils.noOpRunnable() : builder.doAfterClose; + } + + public static Builder builder() { + return new Builder(); } @Override @@ -81,7 +94,23 @@ public void onComplete() { @Override public void close() { - closeConnection.run(); delegate.close(); + FunctionalUtils.invokeSafely(() -> doAfterClose.run()); + } + + public static final class Builder { + private Runnable doAfterClose; + + /** + * Additional action to run when {@link #close()} is invoked + */ + public Builder doAfterClose(Runnable doAfterClose) { + this.doAfterClose = doAfterClose; + return this; + } + + public AbortableInputStreamSubscriber build() { + return new AbortableInputStreamSubscriber(this); + } } } diff --git a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/AbortableInputStreamSubscriberTest.java b/http-client-spi/src/test/java/software/amazon/awssdk/http/async/AbortableInputStreamSubscriberTest.java similarity index 53% rename from http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/AbortableInputStreamSubscriberTest.java rename to http-client-spi/src/test/java/software/amazon/awssdk/http/async/AbortableInputStreamSubscriberTest.java index ca9ea61cecb2..dc3e49ff7205 100644 --- a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/AbortableInputStreamSubscriberTest.java +++ b/http-client-spi/src/test/java/software/amazon/awssdk/http/async/AbortableInputStreamSubscriberTest.java @@ -13,7 +13,7 @@ * permissions and limitations under the License. */ -package software.amazon.awssdk.http.crt.internal; +package software.amazon.awssdk.http.async; import static org.mockito.Mockito.verify; @@ -22,7 +22,6 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import software.amazon.awssdk.http.crt.internal.response.AbortableInputStreamSubscriber; import software.amazon.awssdk.utils.async.InputStreamSubscriber; @ExtendWith(MockitoExtension.class) @@ -33,20 +32,39 @@ public class AbortableInputStreamSubscriberTest { @Mock private Runnable onClose; + @Mock + private InputStreamSubscriber inputStreamSubscriber; + @BeforeEach void setUp() { - abortableInputStreamSubscriber = new AbortableInputStreamSubscriber(onClose, new InputStreamSubscriber()); + abortableInputStreamSubscriber = new AbortableInputStreamSubscriber(AbortableInputStreamSubscriber.builder() + .doAfterClose(onClose), + inputStreamSubscriber); + + } @Test - void close_shouldInvokeOnClose() { + void close_closeConfigured_shouldInvokeOnClose() { abortableInputStreamSubscriber.close(); + verify(inputStreamSubscriber).close(); verify(onClose).run(); } @Test void abort_shouldInvokeOnClose() { + abortableInputStreamSubscriber = new AbortableInputStreamSubscriber(AbortableInputStreamSubscriber.builder() + .doAfterClose(onClose), + inputStreamSubscriber); abortableInputStreamSubscriber.abort(); verify(onClose).run(); } + + @Test + void close_closeNotConfigured_shouldCloseDelegate() { + abortableInputStreamSubscriber = new AbortableInputStreamSubscriber(AbortableInputStreamSubscriber.builder(), + inputStreamSubscriber); + abortableInputStreamSubscriber.close(); + verify(inputStreamSubscriber).close(); + } } diff --git a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/InputStreamAdaptingHttpStreamResponseHandler.java b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/InputStreamAdaptingHttpStreamResponseHandler.java index b6b95307722e..66568efc2b6f 100644 --- a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/InputStreamAdaptingHttpStreamResponseHandler.java +++ b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/InputStreamAdaptingHttpStreamResponseHandler.java @@ -31,9 +31,9 @@ import software.amazon.awssdk.http.AbortableInputStream; import software.amazon.awssdk.http.SdkHttpFullResponse; import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.http.async.AbortableInputStreamSubscriber; import software.amazon.awssdk.http.crt.AwsCrtHttpClient; import software.amazon.awssdk.utils.Logger; -import software.amazon.awssdk.utils.async.InputStreamSubscriber; import software.amazon.awssdk.utils.async.SimplePublisher; /** @@ -87,8 +87,10 @@ public void onResponseHeaders(HttpStream stream, int responseStatusCode, int blo @Override public int onResponseBody(HttpStream stream, byte[] bodyBytesIn) { if (inputStreamSubscriber == null) { - inputStreamSubscriber = new AbortableInputStreamSubscriber(() -> responseHandlerHelper.closeConnection(stream), - new InputStreamSubscriber()); + inputStreamSubscriber = + AbortableInputStreamSubscriber.builder() + .doAfterClose(() -> responseHandlerHelper.closeConnection(stream)) + .build(); simplePublisher.subscribe(inputStreamSubscriber); // For response with a payload, we need to complete the future here to allow downstream to retrieve the data from // the stream directly. diff --git a/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/ResponseHandler.java b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/ResponseHandler.java index 4d653a45f729..eb3ecd09eb3c 100644 --- a/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/ResponseHandler.java +++ b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/ResponseHandler.java @@ -314,7 +314,7 @@ private void onCancel() { try { SdkCancellationException e = new SdkCancellationException( "Subscriber cancelled before all events were published"); - log.warn(channelContext.channel(), () -> "Subscriber cancelled before all events were published"); + log.debug(channelContext.channel(), () -> "Subscriber cancelled before all events were published"); executeFuture.completeExceptionally(e); } finally { runAndLogError(channelContext.channel(), () -> "Could not release channel back to the pool", diff --git a/test/codegen-generated-classes-test/pom.xml b/test/codegen-generated-classes-test/pom.xml index e88d0f537a24..e9fd9c2a6635 100644 --- a/test/codegen-generated-classes-test/pom.xml +++ b/test/codegen-generated-classes-test/pom.xml @@ -235,6 +235,12 @@ ${awsjavasdk.version} test + + software.amazon.awssdk + netty-nio-client + ${awsjavasdk.version} + test + io.reactivex.rxjava2 rxjava @@ -305,6 +311,17 @@ true + + + org.apache.maven.plugins + maven-dependency-plugin + ${maven-dependency-plugin.version} + + false + false + + + diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/BlockingAsyncRequestResponseBodyResourceManagementTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/BlockingAsyncRequestResponseBodyResourceManagementTest.java new file mode 100644 index 000000000000..8e5c61457859 --- /dev/null +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/BlockingAsyncRequestResponseBodyResourceManagementTest.java @@ -0,0 +1,203 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services; + +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE; +import static io.netty.handler.codec.http.HttpHeaderValues.TEXT_PLAIN; +import static io.netty.handler.codec.http.HttpResponseStatus.OK; +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.ServerSocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.DefaultHttpResponse; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.logging.LogLevel; +import io.netty.handler.logging.LoggingHandler; +import java.io.IOException; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.retry.RetryPolicy; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonAsyncClient; +import software.amazon.awssdk.services.protocolrestjson.model.StreamingOutputOperationRequest; +import software.amazon.awssdk.services.protocolrestjson.model.StreamingOutputOperationResponse; + +@Timeout(10) +public class BlockingAsyncRequestResponseBodyResourceManagementTest { + private ProtocolRestJsonAsyncClient client; + private Server server; + + + @AfterEach + void tearDownPerTest() throws InterruptedException { + server.shutdown(); + server = null; + client.close();; + + } + + @BeforeEach + void setUpPerTest() throws Exception { + server = new Server(); + server.init(); + + client = ProtocolRestJsonAsyncClient.builder() + .region(Region.US_WEST_2) + .credentialsProvider(AnonymousCredentialsProvider.create()) + .endpointOverride(URI.create("http://localhost:" + server.port())) + .overrideConfiguration(o -> o.retryPolicy(RetryPolicy.none())) + .build(); + } + + + @Test + void blockingResponseTransformer_abort_shouldCloseUnderlyingConnection() throws IOException { + verifyConnection(r -> r.abort()); + } + + @Test + void blockingResponseTransformer_close_shouldCloseUnderlyingConnection() throws IOException { + Consumer> closeInputStream = closeInputStraem(); + verifyConnection(closeInputStream); + } + + + private static Consumer> closeInputStraem() { + Consumer> closeInputStream = r -> { + try { + r.close(); + } catch (IOException e) { + throw new RuntimeException(e); + } + }; + return closeInputStream; + } + + + void verifyConnection(Consumer> consumer) throws IOException { + + CompletableFuture> responseFuture = + client.streamingOutputOperation(StreamingOutputOperationRequest.builder().build(), + AsyncResponseTransformer.toBlockingInputStream()); + ResponseInputStream responseStream = responseFuture.join(); + + + consumer.accept(responseStream); + + try { + client.headOperation().join(); + } catch (Exception exception) { + // Doesn't matter if the request succeeds or not + } + + // Total of 2 connections got established. + assertThat(server.channels.size()).isEqualTo(2); + } + + private static class Server extends ChannelInitializer { + private static final byte[] CONTENT = ("{ " + + "\"foo\": " + RandomStringUtils.randomAscii(1024) + + "}").getBytes(StandardCharsets.UTF_8); + private ServerBootstrap bootstrap; + private ServerSocketChannel serverSock; + private Set channels = ConcurrentHashMap.newKeySet(); + private final NioEventLoopGroup group = new NioEventLoopGroup(3); + + public void init() throws Exception { + bootstrap = new ServerBootstrap() + .channel(NioServerSocketChannel.class) + .group(group) + .childHandler(this); + + serverSock = (ServerSocketChannel) bootstrap.bind(0).sync().channel(); + } + + public void shutdown() throws InterruptedException { + group.shutdownGracefully().await(); + serverSock.close(); + } + + public int port() { + return serverSock.localAddress().getPort(); + } + + @Override + protected void initChannel(Channel ch) { + channels.add(ch); + ChannelPipeline pipeline = ch.pipeline(); + pipeline.addLast(new HttpServerCodec()); + pipeline.addLast(new BehaviorTestChannelHandler()); + pipeline.addLast(new LoggingHandler(LogLevel.INFO)); + } + + private class BehaviorTestChannelHandler extends ChannelDuplexHandler { + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + + if (!(msg instanceof HttpRequest)) { + return; + } + + HttpMethod method = ((HttpRequest) msg).method(); + + if (Objects.equals(method, HttpMethod.HEAD)) { + DefaultHttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, OK); + ctx.writeAndFlush(response); + return; + } + + FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, OK, + Unpooled.wrappedBuffer(CONTENT)); + + response.headers() + .set(CONTENT_TYPE, TEXT_PLAIN) + .setInt(CONTENT_LENGTH, response.content().readableBytes()); + + ctx.writeAndFlush(response); + } + } + } +} diff --git a/test/tests-coverage-reporting/pom.xml b/test/tests-coverage-reporting/pom.xml index 37219de349c2..4f578e4d4e39 100644 --- a/test/tests-coverage-reporting/pom.xml +++ b/test/tests-coverage-reporting/pom.xml @@ -291,6 +291,11 @@ imds ${awsjavasdk.version} + + software.amazon.awssdk + http-client-spi + ${awsjavasdk.version} +