diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ChunkBuffer.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ChunkBuffer.java index c171b0787678..bdf84d549b8f 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ChunkBuffer.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ChunkBuffer.java @@ -21,14 +21,16 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.concurrent.atomic.AtomicLong; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.utils.Logger; -import software.amazon.awssdk.utils.Validate; import software.amazon.awssdk.utils.builder.SdkBuilder; /** - * Class that will buffer incoming BufferBytes of totalBytes length to chunks of bufferSize* + * Class that will buffer incoming BufferBytes to chunks of bufferSize. + * If totalBytes is not provided, i.e. content-length is unknown, {@link #getBufferedData()} should be used in the Subscriber's + * {@code onComplete()} to check for a final chunk that is smaller than the chunk size, and send if present. */ @SdkInternalApi public final class ChunkBuffer { @@ -36,30 +38,26 @@ public final class ChunkBuffer { private final AtomicLong transferredBytes; private final ByteBuffer currentBuffer; private final int chunkSize; - private final long totalBytes; + private final Long totalBytes; private ChunkBuffer(Long totalBytes, Integer bufferSize) { - Validate.notNull(totalBytes, "The totalBytes must not be null"); - int chunkSize = bufferSize != null ? bufferSize : DEFAULT_ASYNC_CHUNK_SIZE; this.chunkSize = chunkSize; this.currentBuffer = ByteBuffer.allocate(chunkSize); - this.totalBytes = totalBytes; this.transferredBytes = new AtomicLong(0); + this.totalBytes = totalBytes; } public static Builder builder() { return new DefaultBuilder(); } - /** * Split the input {@link ByteBuffer} into multiple smaller {@link ByteBuffer}s, each of which contains {@link #chunkSize} * worth of bytes. If the last chunk of the input ByteBuffer contains less than {@link #chunkSize} data, the last chunk will * be buffered. */ public synchronized Iterable split(ByteBuffer inputByteBuffer) { - if (!inputByteBuffer.hasRemaining()) { return Collections.singletonList(inputByteBuffer); } @@ -71,7 +69,7 @@ public synchronized Iterable split(ByteBuffer inputByteBuffer) { fillCurrentBuffer(inputByteBuffer); if (isCurrentBufferFull()) { - addCurrentBufferToIterable(byteBuffers, chunkSize); + addCurrentBufferToIterable(byteBuffers); } } @@ -82,8 +80,7 @@ public synchronized Iterable split(ByteBuffer inputByteBuffer) { // If this is the last chunk, add data buffered to the iterable if (isLastChunk()) { - int remainingBytesInBuffer = currentBuffer.position(); - addCurrentBufferToIterable(byteBuffers, remainingBytesInBuffer); + addCurrentBufferToIterable(byteBuffers); } return byteBuffers; } @@ -111,19 +108,38 @@ private void splitRemainingInputByteBuffer(ByteBuffer inputByteBuffer, List getBufferedData() { + int remainingBytesInBuffer = currentBuffer.position(); + + if (remainingBytesInBuffer == 0) { + return Optional.empty(); + } + + ByteBuffer bufferedChunk = ByteBuffer.allocate(remainingBytesInBuffer); + currentBuffer.flip(); + bufferedChunk.put(currentBuffer); + bufferedChunk.flip(); + return Optional.of(bufferedChunk); + } + private boolean isLastChunk() { + if (totalBytes == null) { + return false; + } long remainingBytes = totalBytes - transferredBytes.get(); return remainingBytes != 0 && remainingBytes == currentBuffer.position(); } - private void addCurrentBufferToIterable(List byteBuffers, int capacity) { - ByteBuffer bufferedChunk = ByteBuffer.allocate(capacity); - currentBuffer.flip(); - bufferedChunk.put(currentBuffer); - bufferedChunk.flip(); - byteBuffers.add(bufferedChunk); - transferredBytes.addAndGet(bufferedChunk.remaining()); - currentBuffer.clear(); + private void addCurrentBufferToIterable(List byteBuffers) { + Optional bufferedChunk = getBufferedData(); + if (bufferedChunk.isPresent()) { + byteBuffers.add(bufferedChunk.get()); + transferredBytes.addAndGet(bufferedChunk.get().remaining()); + currentBuffer.clear(); + } } private void fillCurrentBuffer(ByteBuffer inputByteBuffer) { @@ -151,8 +167,6 @@ public interface Builder extends SdkBuilder { Builder bufferSize(int bufferSize); Builder totalBytes(long totalBytes); - - } private static final class DefaultBuilder implements Builder { diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/CompressionAsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/CompressionAsyncRequestBody.java new file mode 100644 index 000000000000..82da601f0acc --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/CompressionAsyncRequestBody.java @@ -0,0 +1,212 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import static software.amazon.awssdk.core.internal.io.AwsChunkedInputStream.DEFAULT_CHUNK_SIZE; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.SdkPublisher; +import software.amazon.awssdk.core.internal.compression.Compressor; +import software.amazon.awssdk.utils.Validate; +import software.amazon.awssdk.utils.async.DelegatingSubscriber; +import software.amazon.awssdk.utils.async.FlatteningSubscriber; +import software.amazon.awssdk.utils.builder.SdkBuilder; + +/** + * Wrapper class to wrap an AsyncRequestBody. + * This will chunk and compress the payload with the provided {@link Compressor}. + */ +@SdkInternalApi +public class CompressionAsyncRequestBody implements AsyncRequestBody { + + private final AsyncRequestBody wrapped; + private final Compressor compressor; + private final int chunkSize; + + private CompressionAsyncRequestBody(DefaultBuilder builder) { + this.wrapped = Validate.paramNotNull(builder.asyncRequestBody, "asyncRequestBody"); + this.compressor = Validate.paramNotNull(builder.compressor, "compressor"); + this.chunkSize = builder.chunkSize != null ? builder.chunkSize : DEFAULT_CHUNK_SIZE; + } + + @Override + public void subscribe(Subscriber s) { + Validate.notNull(s, "Subscription MUST NOT be null."); + + SdkPublisher> split = split(wrapped); + SdkPublisher flattening = flattening(split); + flattening.map(compressor::compress).subscribe(s); + } + + @Override + public Optional contentLength() { + return wrapped.contentLength(); + } + + @Override + public String contentType() { + return wrapped.contentType(); + } + + private SdkPublisher> split(SdkPublisher source) { + return subscriber -> source.subscribe(new SplittingSubscriber(subscriber, chunkSize)); + } + + private SdkPublisher flattening(SdkPublisher> source) { + return subscriber -> source.subscribe(new FlatteningSubscriber<>(subscriber)); + } + + /** + * @return Builder instance to construct a {@link CompressionAsyncRequestBody}. + */ + public static Builder builder() { + return new DefaultBuilder(); + } + + public interface Builder extends SdkBuilder { + + /** + * Sets the AsyncRequestBody that will be wrapped. + * @param asyncRequestBody + * @return This builder for method chaining. + */ + Builder asyncRequestBody(AsyncRequestBody asyncRequestBody); + + /** + * Sets the compressor to compress the request. + * @param compressor + * @return This builder for method chaining. + */ + Builder compressor(Compressor compressor); + + /** + * Sets the chunk size. Default size is 128 * 1024. + * @param chunkSize + * @return This builder for method chaining. + */ + Builder chunkSize(Integer chunkSize); + } + + private static final class DefaultBuilder implements Builder { + + private AsyncRequestBody asyncRequestBody; + private Compressor compressor; + private Integer chunkSize; + + @Override + public CompressionAsyncRequestBody build() { + return new CompressionAsyncRequestBody(this); + } + + @Override + public Builder asyncRequestBody(AsyncRequestBody asyncRequestBody) { + this.asyncRequestBody = asyncRequestBody; + return this; + } + + @Override + public Builder compressor(Compressor compressor) { + this.compressor = compressor; + return this; + } + + @Override + public Builder chunkSize(Integer chunkSize) { + this.chunkSize = chunkSize; + return this; + } + } + + private static final class SplittingSubscriber extends DelegatingSubscriber> { + private final ChunkBuffer chunkBuffer; + private final AtomicBoolean upstreamDone = new AtomicBoolean(false); + private final AtomicLong downstreamDemand = new AtomicLong(); + private final Object lock = new Object(); + private volatile boolean sentFinalChunk = false; + + protected SplittingSubscriber(Subscriber> subscriber, int chunkSize) { + super(subscriber); + this.chunkBuffer = ChunkBuffer.builder() + .bufferSize(chunkSize) + .build(); + } + + @Override + public void onSubscribe(Subscription s) { + subscriber.onSubscribe(new Subscription() { + @Override + public void request(long n) { + if (n <= 0) { + throw new IllegalArgumentException("n > 0 required but it was " + n); + } + + downstreamDemand.getAndAdd(n); + + if (upstreamDone.get()) { + sendFinalChunk(); + } else { + s.request(n); + } + } + + @Override + public void cancel() { + s.cancel(); + } + }); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + downstreamDemand.decrementAndGet(); + Iterable buffers = chunkBuffer.split(byteBuffer); + subscriber.onNext(buffers); + } + + @Override + public void onComplete() { + upstreamDone.compareAndSet(false, true); + if (downstreamDemand.get() > 0) { + sendFinalChunk(); + } + } + + @Override + public void onError(Throwable t) { + upstreamDone.compareAndSet(false, true); + super.onError(t); + } + + private void sendFinalChunk() { + synchronized (lock) { + if (!sentFinalChunk) { + sentFinalChunk = true; + Optional byteBuffer = chunkBuffer.getBufferedData(); + byteBuffer.ifPresent(buffer -> subscriber.onNext(Collections.singletonList(buffer))); + subscriber.onComplete(); + } + } + } + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/CompressRequestStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/CompressRequestStage.java index 1eadb88d32db..e002697c5f15 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/CompressRequestStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/CompressRequestStage.java @@ -30,6 +30,7 @@ import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.async.CompressionAsyncRequestBody; import software.amazon.awssdk.core.internal.compression.Compressor; import software.amazon.awssdk.core.internal.compression.CompressorType; import software.amazon.awssdk.core.internal.http.HttpClientDependencies; @@ -63,7 +64,6 @@ public SdkHttpFullRequest.Builder execute(SdkHttpFullRequest.Builder input, Requ Compressor compressor = resolveCompressorType(context.executionAttributes()); - // non-streaming if (!isStreaming(context)) { compressEntirePayload(input, compressor); updateContentEncodingHeader(input, compressor); @@ -76,12 +76,14 @@ public SdkHttpFullRequest.Builder execute(SdkHttpFullRequest.Builder input, Requ } if (context.requestProvider() == null) { - // sync streaming input.contentStreamProvider(new CompressionContentStreamProvider(input.contentStreamProvider(), compressor)); + } else { + context.requestProvider(CompressionAsyncRequestBody.builder() + .asyncRequestBody(context.requestProvider()) + .compressor(compressor) + .build()); } - // TODO : streaming - async - updateContentEncodingHeader(input, compressor); return input; } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/ChunkBufferTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/ChunkBufferTest.java index a553a55a4536..41250225664a 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/ChunkBufferTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/ChunkBufferTest.java @@ -16,7 +16,6 @@ package software.amazon.awssdk.core.async; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -24,6 +23,7 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collection; +import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -41,19 +41,53 @@ class ChunkBufferTest { - @Test - void builderWithNoTotalSize() { - assertThatThrownBy(() -> ChunkBuffer.builder().build()).isInstanceOf(NullPointerException.class); + @ParameterizedTest + @ValueSource(ints = {1, 6, 10, 23, 25}) + void numberOfChunk_Not_MultipleOfTotalBytes_KnownLength(int totalBytes) { + int bufferSize = 5; + + String inputString = RandomStringUtils.randomAscii(totalBytes); + ChunkBuffer chunkBuffer = ChunkBuffer.builder() + .bufferSize(bufferSize) + .totalBytes(inputString.getBytes(StandardCharsets.UTF_8).length) + .build(); + Iterable byteBuffers = + chunkBuffer.split(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8))); + + AtomicInteger index = new AtomicInteger(0); + int count = (int) Math.ceil(totalBytes / (double) bufferSize); + int remainder = totalBytes % bufferSize; + + byteBuffers.forEach(r -> { + int i = index.get(); + + try (ByteArrayInputStream inputStream = new ByteArrayInputStream(inputString.getBytes(StandardCharsets.UTF_8))) { + byte[] expected; + if (i == count - 1 && remainder != 0) { + expected = new byte[remainder]; + } else { + expected = new byte[bufferSize]; + } + inputStream.skip(i * bufferSize); + inputStream.read(expected); + byte[] actualBytes = BinaryUtils.copyBytesFrom(r); + assertThat(actualBytes).isEqualTo(expected); + index.incrementAndGet(); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); } @ParameterizedTest @ValueSource(ints = {1, 6, 10, 23, 25}) - void numberOfChunk_Not_MultipleOfTotalBytes(int totalBytes) { + void numberOfChunk_Not_MultipleOfTotalBytes_UnknownLength(int totalBytes) { int bufferSize = 5; String inputString = RandomStringUtils.randomAscii(totalBytes); - ChunkBuffer chunkBuffer = - ChunkBuffer.builder().bufferSize(bufferSize).totalBytes(inputString.getBytes(StandardCharsets.UTF_8).length).build(); + ChunkBuffer chunkBuffer = ChunkBuffer.builder() + .bufferSize(bufferSize) + .build(); Iterable byteBuffers = chunkBuffer.split(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8))); @@ -83,10 +117,12 @@ void numberOfChunk_Not_MultipleOfTotalBytes(int totalBytes) { } @Test - void zeroTotalBytesAsInput_returnsZeroByte() { + void zeroTotalBytesAsInput_returnsZeroByte_KnownLength() { byte[] zeroByte = new byte[0]; - ChunkBuffer chunkBuffer = - ChunkBuffer.builder().bufferSize(5).totalBytes(zeroByte.length).build(); + ChunkBuffer chunkBuffer = ChunkBuffer.builder() + .bufferSize(5) + .totalBytes(zeroByte.length) + .build(); Iterable byteBuffers = chunkBuffer.split(ByteBuffer.wrap(zeroByte)); @@ -98,13 +134,30 @@ void zeroTotalBytesAsInput_returnsZeroByte() { } @Test - void emptyAllocatedBytes_returnSameNumberOfEmptyBytes() { + void zeroTotalBytesAsInput_returnsZeroByte_UnknownLength() { + byte[] zeroByte = new byte[0]; + ChunkBuffer chunkBuffer = ChunkBuffer.builder() + .bufferSize(5) + .build(); + Iterable byteBuffers = + chunkBuffer.split(ByteBuffer.wrap(zeroByte)); + + AtomicInteger iteratedCounts = new AtomicInteger(); + byteBuffers.forEach(r -> { + iteratedCounts.getAndIncrement(); + }); + assertThat(iteratedCounts.get()).isEqualTo(1); + } + @Test + void emptyAllocatedBytes_returnSameNumberOfEmptyBytes_knownLength() { int totalBytes = 17; int bufferSize = 5; ByteBuffer wrap = ByteBuffer.allocate(totalBytes); - ChunkBuffer chunkBuffer = - ChunkBuffer.builder().bufferSize(bufferSize).totalBytes(wrap.remaining()).build(); + ChunkBuffer chunkBuffer = ChunkBuffer.builder() + .bufferSize(bufferSize) + .totalBytes(wrap.remaining()) + .build(); Iterable byteBuffers = chunkBuffer.split(wrap); @@ -121,6 +174,34 @@ void emptyAllocatedBytes_returnSameNumberOfEmptyBytes() { assertThat(iteratedCounts.get()).isEqualTo(4); } + @Test + void emptyAllocatedBytes_returnSameNumberOfEmptyBytes_unknownLength() { + int totalBytes = 17; + int bufferSize = 5; + ByteBuffer wrap = ByteBuffer.allocate(totalBytes); + ChunkBuffer chunkBuffer = ChunkBuffer.builder() + .bufferSize(bufferSize) + .build(); + Iterable byteBuffers = + chunkBuffer.split(wrap); + + AtomicInteger iteratedCounts = new AtomicInteger(); + byteBuffers.forEach(r -> { + iteratedCounts.getAndIncrement(); + if (iteratedCounts.get() * bufferSize < totalBytes) { + // array of empty bytes + assertThat(BinaryUtils.copyBytesFrom(r)).isEqualTo(ByteBuffer.allocate(bufferSize).array()); + } else { + assertThat(BinaryUtils.copyBytesFrom(r)).isEqualTo(ByteBuffer.allocate(totalBytes % bufferSize).array()); + } + }); + assertThat(iteratedCounts.get()).isEqualTo(3); + + Optional lastBuffer = chunkBuffer.getBufferedData(); + assertThat(lastBuffer.isPresent()); + assertThat(lastBuffer.get().remaining()).isEqualTo(2); + } + /** * * Total bytes 11(ChunkSize) 3 (threads) @@ -152,14 +233,16 @@ void emptyAllocatedBytes_returnSameNumberOfEmptyBytes() { * 111 is given as output since we consumed all the total bytes* */ @Test - void concurrentTreads_calling_bufferAndCreateChunks() throws ExecutionException, InterruptedException { + void concurrentTreads_calling_bufferAndCreateChunks_knownLength() throws ExecutionException, InterruptedException { int totalBytes = 17; int bufferSize = 5; int threads = 8; ByteBuffer wrap = ByteBuffer.allocate(totalBytes); - ChunkBuffer chunkBuffer = - ChunkBuffer.builder().bufferSize(bufferSize).totalBytes(wrap.remaining() * threads).build(); + ChunkBuffer chunkBuffer = ChunkBuffer.builder() + .bufferSize(bufferSize) + .totalBytes(wrap.remaining() * threads) + .build(); ExecutorService service = Executors.newFixedThreadPool(threads); @@ -198,7 +281,4 @@ void concurrentTreads_calling_bufferAndCreateChunks() throws ExecutionException, assertThat(remainderBytesBuffers.get()).isOne(); assertThat(otherSizeBuffers.get()).isZero(); } - } - - diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/CompressionAsyncRequestBodyTckTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/CompressionAsyncRequestBodyTckTest.java new file mode 100644 index 000000000000..54c74e1e97e9 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/CompressionAsyncRequestBodyTckTest.java @@ -0,0 +1,111 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.async; + +import com.google.common.jimfs.Configuration; +import com.google.common.jimfs.Jimfs; +import io.reactivex.Flowable; +import java.io.IOException; +import java.io.OutputStream; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.nio.file.FileSystem; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Optional; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.tck.PublisherVerification; +import org.reactivestreams.tck.TestEnvironment; +import software.amazon.awssdk.core.internal.async.CompressionAsyncRequestBody; +import software.amazon.awssdk.core.internal.compression.Compressor; +import software.amazon.awssdk.core.internal.compression.GzipCompressor; + +public class CompressionAsyncRequestBodyTckTest extends PublisherVerification { + + private static final FileSystem fs = Jimfs.newFileSystem(Configuration.unix()); + private static final Path rootDir = fs.getRootDirectories().iterator().next(); + private static final int MAX_ELEMENTS = 1000; + private static final int CHUNK_SIZE = 128 * 1024; + private static final Compressor compressor = new GzipCompressor(); + + public CompressionAsyncRequestBodyTckTest() { + super(new TestEnvironment()); + } + + @Override + public long maxElementsFromPublisher() { + return MAX_ELEMENTS; + } + + @Override + public Publisher createPublisher(long n) { + return CompressionAsyncRequestBody.builder() + .asyncRequestBody(customAsyncRequestBodyFromFileWithoutContentLength(n)) + .compressor(compressor) + .build(); + } + + @Override + public Publisher createFailedPublisher() { + return null; + } + + private static AsyncRequestBody customAsyncRequestBodyFromFileWithoutContentLength(long nChunks) { + return new AsyncRequestBody() { + @Override + public Optional contentLength() { + return Optional.empty(); + } + + @Override + public void subscribe(Subscriber s) { + Flowable.fromPublisher(AsyncRequestBody.fromFile(fileOfNChunks(nChunks))).subscribe(s); + } + }; + } + + private static Path fileOfNChunks(long nChunks) { + String name = String.format("%d-chunks-file.dat", nChunks); + Path p = rootDir.resolve(name); + if (!Files.exists(p)) { + try (OutputStream os = Files.newOutputStream(p)) { + os.write(createCompressibleArrayOfNChunks(nChunks)); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + return p; + } + + private static byte[] createCompressibleArrayOfNChunks(long nChunks) { + int size = Math.toIntExact(nChunks * CHUNK_SIZE); + ByteBuffer data = ByteBuffer.allocate(size); + + byte[] a = new byte[size / 4]; + byte[] b = new byte[size / 4]; + Arrays.fill(a, (byte) 'a'); + Arrays.fill(b, (byte) 'b'); + + data.put(a); + data.put(b); + data.put(a); + data.put(b); + + return data.array(); + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/CompressionAsyncRequestBodyTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/CompressionAsyncRequestBodyTest.java new file mode 100644 index 000000000000..ffb15e282a13 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/CompressionAsyncRequestBodyTest.java @@ -0,0 +1,173 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.reactivex.Flowable; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.zip.GZIPInputStream; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Subscriber; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.internal.compression.Compressor; +import software.amazon.awssdk.core.internal.compression.GzipCompressor; +import software.amazon.awssdk.core.internal.util.Mimetype; +import software.amazon.awssdk.http.async.SimpleSubscriber; + +public final class CompressionAsyncRequestBodyTest { + private static final Compressor compressor = new GzipCompressor(); + + @ParameterizedTest + @ValueSource(ints = {80, 1000}) + public void hasCorrectContent(int bodySize) throws Exception { + String testString = createCompressibleStringOfGivenSize(bodySize); + byte[] testBytes = testString.getBytes(); + int chunkSize = 133; + AsyncRequestBody provider = CompressionAsyncRequestBody.builder() + .compressor(compressor) + .asyncRequestBody(customAsyncRequestBodyWithoutContentLength(testBytes)) + .chunkSize(chunkSize) + .build(); + + ByteBuffer byteBuffer = ByteBuffer.allocate(testString.length()); + CountDownLatch done = new CountDownLatch(1); + AtomicInteger pos = new AtomicInteger(); + + Subscriber subscriber = new SimpleSubscriber(buffer -> { + byte[] bytes = new byte[buffer.remaining()]; + buffer.get(bytes); + byteBuffer.put(bytes); + + // verify each chunk + byte[] chunkToVerify = new byte[chunkSize]; + System.arraycopy(testBytes, pos.get(), chunkToVerify, 0, chunkSize); + chunkToVerify = compressor.compress(chunkToVerify); + + assertThat(bytes).isEqualTo(chunkToVerify); + pos.addAndGet(chunkSize); + }) { + @Override + public void onError(Throwable t) { + super.onError(t); + done.countDown(); + } + + @Override + public void onComplete() { + super.onComplete(); + done.countDown(); + } + }; + + provider.subscribe(subscriber); + done.await(10, TimeUnit.SECONDS); + + byte[] retrieved = byteBuffer.array(); + byte[] uncompressed = decompress(retrieved); + assertThat(new String(uncompressed)).isEqualTo(testString); + } + + @Test + public void emptyBytesConstructor_hasEmptyContent() throws Exception { + AsyncRequestBody requestBody = CompressionAsyncRequestBody.builder() + .compressor(compressor) + .asyncRequestBody(AsyncRequestBody.empty()) + .build(); + + ByteBuffer byteBuffer = ByteBuffer.allocate(0); + CountDownLatch done = new CountDownLatch(1); + + Subscriber subscriber = new SimpleSubscriber(buffer -> { + byte[] bytes = new byte[buffer.remaining()]; + buffer.get(bytes); + byteBuffer.put(bytes); + }) { + @Override + public void onError(Throwable t) { + super.onError(t); + done.countDown(); + } + + @Override + public void onComplete() { + super.onComplete(); + done.countDown(); + } + }; + + requestBody.subscribe(subscriber); + done.await(10, TimeUnit.SECONDS); + assertThat(byteBuffer.array()).isEmpty(); + assertThat(byteBuffer.array()).isEqualTo(new byte[0]); + assertThat(requestBody.contentType()).isEqualTo(Mimetype.MIMETYPE_OCTET_STREAM); + } + + private static String createCompressibleStringOfGivenSize(int size) { + ByteBuffer data = ByteBuffer.allocate(size); + + byte[] a = new byte[size / 4]; + byte[] b = new byte[size / 4]; + Arrays.fill(a, (byte) 'a'); + Arrays.fill(b, (byte) 'b'); + + data.put(a); + data.put(b); + data.put(a); + data.put(b); + + return new String(data.array()); + } + + private static byte[] decompress(byte[] compressedData) throws IOException { + ByteArrayInputStream bais = new ByteArrayInputStream(compressedData); + GZIPInputStream gzipInputStream = new GZIPInputStream(bais); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + byte[] buffer = new byte[1024]; + int bytesRead; + while ((bytesRead = gzipInputStream.read(buffer)) != -1) { + baos.write(buffer, 0, bytesRead); + } + gzipInputStream.close(); + byte[] decompressedData = baos.toByteArray(); + return decompressedData; + } + + private static AsyncRequestBody customAsyncRequestBodyWithoutContentLength(byte[] content) { + return new AsyncRequestBody() { + @Override + public Optional contentLength() { + return Optional.empty(); + } + + @Override + public void subscribe(Subscriber s) { + Flowable.fromPublisher(AsyncRequestBody.fromBytes(content)) + .subscribe(s); + } + }; + } +} \ No newline at end of file diff --git a/services/mediastoredata/src/it/java/software/amazon/awssdk/services/mediastoredata/MediaStoreDataIntegrationTestBase.java b/services/mediastoredata/src/it/java/software/amazon/awssdk/services/mediastoredata/MediaStoreDataIntegrationTestBase.java index 3a0e7006ef8b..20688925bc86 100644 --- a/services/mediastoredata/src/it/java/software/amazon/awssdk/services/mediastoredata/MediaStoreDataIntegrationTestBase.java +++ b/services/mediastoredata/src/it/java/software/amazon/awssdk/services/mediastoredata/MediaStoreDataIntegrationTestBase.java @@ -83,7 +83,7 @@ private static DescribeContainerResponse waitContainerToBeActive() { .orFailAfter(Duration.ofMinutes(3)); } - protected AsyncRequestBody customAsyncRequestBodyWithoutContentLength() { + protected AsyncRequestBody customAsyncRequestBodyWithoutContentLength(byte[] body) { return new AsyncRequestBody() { @Override public Optional contentLength() { @@ -92,7 +92,7 @@ public Optional contentLength() { @Override public void subscribe(Subscriber s) { - Flowable.fromPublisher(AsyncRequestBody.fromBytes("Random text".getBytes())) + Flowable.fromPublisher(AsyncRequestBody.fromBytes(body)) .subscribe(s); } }; diff --git a/services/mediastoredata/src/it/java/software/amazon/awssdk/services/mediastoredata/RequestCompressionStreamingIntegrationTest.java b/services/mediastoredata/src/it/java/software/amazon/awssdk/services/mediastoredata/RequestCompressionStreamingIntegrationTest.java index 9530f2319b38..bb4a2a9bf0c3 100644 --- a/services/mediastoredata/src/it/java/software/amazon/awssdk/services/mediastoredata/RequestCompressionStreamingIntegrationTest.java +++ b/services/mediastoredata/src/it/java/software/amazon/awssdk/services/mediastoredata/RequestCompressionStreamingIntegrationTest.java @@ -27,6 +27,7 @@ import software.amazon.awssdk.core.RequestCompressionConfiguration; import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; @@ -83,7 +84,7 @@ public static void setup() { asyncClient = MediaStoreDataAsyncClient.builder() .endpointOverride(uri) - .credentialsProvider(getCredentialsProvider()) + .credentialsProvider(credentialsProvider) .httpClient(NettyNioAsyncHttpClient.create()) .overrideConfiguration(o -> o.addExecutionInterceptor(new CaptureTransferEncodingHeaderInterceptor()) .addExecutionInterceptor(new CaptureContentEncodingHeaderInterceptor()) @@ -108,11 +109,13 @@ public static void setup() { } @AfterAll - public static void tearDown() { + public static void tearDown() throws InterruptedException { syncClient.deleteObject(deleteObjectRequest); Waiter.run(() -> syncClient.describeObject(r -> r.path("/foo"))) .untilException(ObjectNotFoundException.class) .orFailAfter(Duration.ofMinutes(1)); + Thread.sleep(1000); + mediaStoreClient.deleteContainer(r -> r.containerName(CONTAINER_NAME)); } @AfterEach @@ -121,7 +124,7 @@ public void cleanUp() { } @Test - public void putObject_withRequestCompressionSyncStreaming_compressesPayloadAndSendsCorrectly() throws IOException { + public void putObject_withSyncStreamingRequestCompression_compressesPayloadAndSendsCorrectly() throws IOException { TestContentProvider provider = new TestContentProvider(UNCOMPRESSED_BODY.getBytes(StandardCharsets.UTF_8)); syncClient.putObject(putObjectRequest, RequestBody.fromContentProvider(provider, "binary/octet-stream")); @@ -129,29 +132,26 @@ public void putObject_withRequestCompressionSyncStreaming_compressesPayloadAndSe assertThat(CaptureContentEncodingHeaderInterceptor.isGzip).isTrue(); ResponseInputStream response = syncClient.getObject(getObjectRequest); - byte[] buffer = new byte[UNCOMPRESSED_BODY.getBytes(StandardCharsets.UTF_8).length]; + byte[] buffer = new byte[UNCOMPRESSED_BODY.getBytes().length]; response.read(buffer); String retrievedContent = new String(buffer); - assertThat(UNCOMPRESSED_BODY).isEqualTo(retrievedContent); + assertThat(retrievedContent).isEqualTo(UNCOMPRESSED_BODY); } - // TODO : uncomment once async streaming compression is implemented - /*@Test - public void nettyClientPutObject_withoutContentLength_sendsSuccessfully() throws IOException { - AsyncRequestBody asyncRequestBody = customAsyncRequestBodyWithoutContentLength(); + @Test + public void putObject_withAsyncStreamingRequestCompression_compressesPayloadAndSendsCorrectly() throws IOException { + AsyncRequestBody asyncRequestBody = customAsyncRequestBodyWithoutContentLength(UNCOMPRESSED_BODY.getBytes()); asyncClient.putObject(putObjectRequest, asyncRequestBody).join(); assertThat(CaptureTransferEncodingHeaderInterceptor.isChunked).isTrue(); assertThat(CaptureContentEncodingHeaderInterceptor.isGzip).isTrue(); - // verify stored content is correct ResponseInputStream response = syncClient.getObject(getObjectRequest); - byte[] buffer = new byte[UNCOMPRESSED_BODY.getBytes(StandardCharsets.UTF_8).length]; + byte[] buffer = new byte[UNCOMPRESSED_BODY.getBytes().length]; response.read(buffer); String retrievedContent = new String(buffer); - assertThat(UNCOMPRESSED_BODY).isEqualTo(retrievedContent); - assertThat(CaptureTransferEncodingHeaderInterceptor.isChunked).isTrue(); - }*/ + assertThat(retrievedContent).isEqualTo(UNCOMPRESSED_BODY); + } private static class CaptureContentEncodingHeaderInterceptor implements ExecutionInterceptor { public static boolean isGzip; diff --git a/services/mediastoredata/src/it/java/software/amazon/awssdk/services/mediastoredata/TransferEncodingChunkedIntegrationTest.java b/services/mediastoredata/src/it/java/software/amazon/awssdk/services/mediastoredata/TransferEncodingChunkedIntegrationTest.java index 80fb67dc6fab..34522618f759 100644 --- a/services/mediastoredata/src/it/java/software/amazon/awssdk/services/mediastoredata/TransferEncodingChunkedIntegrationTest.java +++ b/services/mediastoredata/src/it/java/software/amazon/awssdk/services/mediastoredata/TransferEncodingChunkedIntegrationTest.java @@ -76,11 +76,12 @@ public static void setup() { } @AfterAll - public static void tearDown() { + public static void tearDown() throws InterruptedException { syncClientWithApache.deleteObject(deleteObjectRequest); Waiter.run(() -> syncClientWithApache.describeObject(r -> r.path("/foo"))) .untilException(ObjectNotFoundException.class) .orFailAfter(Duration.ofMinutes(1)); + Thread.sleep(500); mediaStoreClient.deleteContainer(r -> r.containerName(CONTAINER_NAME)); } @@ -100,7 +101,7 @@ public void urlConnectionClientPutObject_withoutContentLength_sendsSuccessfully( @Test public void nettyClientPutObject_withoutContentLength_sendsSuccessfully() { - asyncClientWithNetty.putObject(putObjectRequest, customAsyncRequestBodyWithoutContentLength()).join(); + asyncClientWithNetty.putObject(putObjectRequest, customAsyncRequestBodyWithoutContentLength("TestBody".getBytes())).join(); assertThat(CaptureTransferEncodingHeaderInterceptor.isChunked).isTrue(); } } diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AsyncRequestCompressionTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AsyncRequestCompressionTest.java new file mode 100644 index 000000000000..5a8f1f50dbc8 --- /dev/null +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AsyncRequestCompressionTest.java @@ -0,0 +1,205 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.reactivex.Flowable; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.Optional; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscriber; +import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.internal.compression.Compressor; +import software.amazon.awssdk.core.internal.compression.GzipCompressor; +import software.amazon.awssdk.http.HttpExecuteResponse; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonAsyncClient; +import software.amazon.awssdk.services.protocolrestjson.model.PutOperationWithRequestCompressionRequest; +import software.amazon.awssdk.services.protocolrestjson.model.PutOperationWithStreamingRequestCompressionRequest; +import software.amazon.awssdk.testutils.service.http.MockAsyncHttpClient; + +public class AsyncRequestCompressionTest { + private static final String UNCOMPRESSED_BODY = + "RequestCompressionTest-RequestCompressionTest-RequestCompressionTest-RequestCompressionTest-RequestCompressionTest"; + private String compressedBody; + private int compressedLen; + private MockAsyncHttpClient mockAsyncHttpClient; + private ProtocolRestJsonAsyncClient asyncClient; + private Compressor compressor; + + @BeforeEach + public void setUp() { + mockAsyncHttpClient = new MockAsyncHttpClient(); + asyncClient = ProtocolRestJsonAsyncClient.builder() + .credentialsProvider(AnonymousCredentialsProvider.create()) + .region(Region.US_EAST_1) + .httpClient(mockAsyncHttpClient) + .build(); + compressor = new GzipCompressor(); + byte[] compressedBodyBytes = compressor.compress(UNCOMPRESSED_BODY.getBytes()); + compressedBody = new String(compressedBodyBytes); + compressedLen = compressedBodyBytes.length; + } + + @AfterEach + public void reset() { + mockAsyncHttpClient.reset(); + } + + @Test + public void asyncNonStreamingOperation_compressionEnabledThresholdOverridden_compressesCorrectly() { + mockAsyncHttpClient.stubNextResponse(mockResponse(), Duration.ofMillis(500)); + + PutOperationWithRequestCompressionRequest request = + PutOperationWithRequestCompressionRequest.builder() + .body(SdkBytes.fromUtf8String(UNCOMPRESSED_BODY)) + .overrideConfiguration(o -> o.requestCompressionConfiguration( + c -> c.minimumCompressionThresholdInBytes(1))) + .build(); + + asyncClient.putOperationWithRequestCompression(request); + + SdkHttpFullRequest loggedRequest = (SdkHttpFullRequest) mockAsyncHttpClient.getLastRequest(); + InputStream loggedStream = loggedRequest.contentStreamProvider().get().newStream(); + String loggedBody = new String(SdkBytes.fromInputStream(loggedStream).asByteArray()); + int loggedSize = Integer.valueOf(loggedRequest.firstMatchingHeader("Content-Length").get()); + + assertThat(loggedBody).isEqualTo(compressedBody); + assertThat(loggedSize).isEqualTo(compressedLen); + assertThat(loggedRequest.firstMatchingHeader("Content-encoding").get()).isEqualTo("gzip"); + } + + @Test + public void asyncNonStreamingOperation_payloadSizeLessThanCompressionThreshold_doesNotCompress() { + mockAsyncHttpClient.stubNextResponse(mockResponse(), Duration.ofMillis(500)); + + PutOperationWithRequestCompressionRequest request = + PutOperationWithRequestCompressionRequest.builder() + .body(SdkBytes.fromUtf8String(UNCOMPRESSED_BODY)) + .build(); + + asyncClient.putOperationWithRequestCompression(request); + + SdkHttpFullRequest loggedRequest = (SdkHttpFullRequest) mockAsyncHttpClient.getLastRequest(); + InputStream loggedStream = loggedRequest.contentStreamProvider().get().newStream(); + String loggedBody = new String(SdkBytes.fromInputStream(loggedStream).asByteArray()); + int loggedSize = Integer.valueOf(loggedRequest.firstMatchingHeader("Content-Length").get()); + + assertThat(loggedBody).isEqualTo(UNCOMPRESSED_BODY); + assertThat(loggedSize).isEqualTo(UNCOMPRESSED_BODY.length()); + assertThat(loggedRequest.firstMatchingHeader("Content-encoding")).isEmpty(); + } + + @Test + public void asyncStreamingOperation_compressionEnabled_compressesCorrectly() { + mockAsyncHttpClient.stubNextResponse(mockResponse(), Duration.ofMillis(500)); + + mockAsyncHttpClient.setAsyncRequestBodyLength(compressedBody.length()); + PutOperationWithStreamingRequestCompressionRequest request = + PutOperationWithStreamingRequestCompressionRequest.builder().build(); + asyncClient.putOperationWithStreamingRequestCompression(request, customAsyncRequestBodyWithoutContentLength(), + AsyncResponseTransformer.toBytes()).join(); + + SdkHttpFullRequest loggedRequest = (SdkHttpFullRequest) mockAsyncHttpClient.getLastRequest(); + String loggedBody = new String(mockAsyncHttpClient.getStreamingPayload()); + + assertThat(loggedBody).isEqualTo(compressedBody); + assertThat(loggedRequest.firstMatchingHeader("Content-encoding").get()).isEqualTo("gzip"); + assertThat(loggedRequest.matchingHeaders("Content-Length")).isEmpty(); + assertThat(loggedRequest.firstMatchingHeader("Transfer-Encoding").get()).isEqualTo("chunked"); + } + + @Test + public void asyncNonStreamingOperation_compressionEnabledThresholdOverriddenWithRetry_compressesCorrectly() { + mockAsyncHttpClient.stubNextResponse(mockErrorResponse(), Duration.ofMillis(500)); + mockAsyncHttpClient.stubNextResponse(mockResponse(), Duration.ofMillis(500)); + + PutOperationWithRequestCompressionRequest request = + PutOperationWithRequestCompressionRequest.builder() + .body(SdkBytes.fromUtf8String(UNCOMPRESSED_BODY)) + .overrideConfiguration(o -> o.requestCompressionConfiguration( + c -> c.minimumCompressionThresholdInBytes(1))) + .build(); + + asyncClient.putOperationWithRequestCompression(request); + + SdkHttpFullRequest loggedRequest = (SdkHttpFullRequest) mockAsyncHttpClient.getLastRequest(); + InputStream loggedStream = loggedRequest.contentStreamProvider().get().newStream(); + String loggedBody = new String(SdkBytes.fromInputStream(loggedStream).asByteArray()); + int loggedSize = Integer.valueOf(loggedRequest.firstMatchingHeader("Content-Length").get()); + + assertThat(loggedBody).isEqualTo(compressedBody); + assertThat(loggedSize).isEqualTo(compressedLen); + assertThat(loggedRequest.firstMatchingHeader("Content-encoding").get()).isEqualTo("gzip"); + } + + @Test + public void asyncStreamingOperation_compressionEnabledWithRetry_compressesCorrectly() { + mockAsyncHttpClient.stubNextResponse(mockResponse(), Duration.ofMillis(500)); + mockAsyncHttpClient.stubNextResponse(mockResponse(), Duration.ofMillis(500)); + + mockAsyncHttpClient.setAsyncRequestBodyLength(compressedBody.length()); + PutOperationWithStreamingRequestCompressionRequest request = + PutOperationWithStreamingRequestCompressionRequest.builder().build(); + asyncClient.putOperationWithStreamingRequestCompression(request, customAsyncRequestBodyWithoutContentLength(), + AsyncResponseTransformer.toBytes()).join(); + + SdkHttpFullRequest loggedRequest = (SdkHttpFullRequest) mockAsyncHttpClient.getLastRequest(); + String loggedBody = new String(mockAsyncHttpClient.getStreamingPayload()); + + assertThat(loggedBody).isEqualTo(compressedBody); + assertThat(loggedRequest.firstMatchingHeader("Content-encoding").get()).isEqualTo("gzip"); + assertThat(loggedRequest.matchingHeaders("Content-Length")).isEmpty(); + assertThat(loggedRequest.firstMatchingHeader("Transfer-Encoding").get()).isEqualTo("chunked"); + } + + private HttpExecuteResponse mockResponse() { + return HttpExecuteResponse.builder() + .response(SdkHttpResponse.builder().statusCode(200).build()) + .build(); + } + + private HttpExecuteResponse mockErrorResponse() { + return HttpExecuteResponse.builder() + .response(SdkHttpResponse.builder().statusCode(500).build()) + .build(); + } + + protected AsyncRequestBody customAsyncRequestBodyWithoutContentLength() { + return new AsyncRequestBody() { + @Override + public Optional contentLength() { + return Optional.empty(); + } + + @Override + public void subscribe(Subscriber s) { + Flowable.fromPublisher(AsyncRequestBody.fromBytes(UNCOMPRESSED_BODY.getBytes())) + .subscribe(s); + } + }; + } +} diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/RequestCompressionTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/RequestCompressionTest.java index 29664c5f53f3..14cb07125037 100644 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/RequestCompressionTest.java +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/RequestCompressionTest.java @@ -22,7 +22,6 @@ import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; -import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.ArrayList; import java.util.List; @@ -40,11 +39,9 @@ import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.SdkHttpResponse; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonAsyncClient; import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonClient; import software.amazon.awssdk.services.protocolrestjson.model.PutOperationWithRequestCompressionRequest; import software.amazon.awssdk.services.protocolrestjson.model.PutOperationWithStreamingRequestCompressionRequest; -import software.amazon.awssdk.testutils.service.http.MockAsyncHttpClient; import software.amazon.awssdk.testutils.service.http.MockSyncHttpClient; public class RequestCompressionTest { @@ -53,42 +50,33 @@ public class RequestCompressionTest { private String compressedBody; private int compressedLen; private MockSyncHttpClient mockHttpClient; - private MockAsyncHttpClient mockAsyncHttpClient; private ProtocolRestJsonClient syncClient; - private ProtocolRestJsonAsyncClient asyncClient; private Compressor compressor; private RequestBody requestBody; @BeforeEach public void setUp() { mockHttpClient = new MockSyncHttpClient(); - mockAsyncHttpClient = new MockAsyncHttpClient(); syncClient = ProtocolRestJsonClient.builder() .credentialsProvider(AnonymousCredentialsProvider.create()) .region(Region.US_EAST_1) .httpClient(mockHttpClient) .build(); - asyncClient = ProtocolRestJsonAsyncClient.builder() - .credentialsProvider(AnonymousCredentialsProvider.create()) - .region(Region.US_EAST_1) - .httpClient(mockAsyncHttpClient) - .build(); compressor = new GzipCompressor(); - byte[] compressedBodyBytes = compressor.compress(SdkBytes.fromUtf8String(UNCOMPRESSED_BODY)).asByteArray(); + byte[] compressedBodyBytes = compressor.compress(UNCOMPRESSED_BODY.getBytes()); compressedLen = compressedBodyBytes.length; compressedBody = new String(compressedBodyBytes); - TestContentProvider provider = new TestContentProvider(UNCOMPRESSED_BODY.getBytes(StandardCharsets.UTF_8)); + TestContentProvider provider = new TestContentProvider(UNCOMPRESSED_BODY.getBytes()); requestBody = RequestBody.fromContentProvider(provider, "binary/octet-stream"); } @AfterEach public void reset() { mockHttpClient.reset(); - mockAsyncHttpClient.reset(); } @Test - public void sync_nonStreaming_compression_compressesCorrectly() { + public void syncNonStreamingOperation_compressionEnabledThresholdOverridden_compressesCorrectly() { mockHttpClient.stubNextResponse(mockResponse(), Duration.ofMillis(500)); PutOperationWithRequestCompressionRequest request = @@ -110,30 +98,25 @@ public void sync_nonStreaming_compression_compressesCorrectly() { } @Test - public void async_nonStreaming_compression_compressesCorrectly() { - mockAsyncHttpClient.stubNextResponse(mockResponse(), Duration.ofMillis(500)); + public void syncNonStreamingOperation_payloadSizeLessThanCompressionThreshold_doesNotCompress() { + mockHttpClient.stubNextResponse(mockResponse(), Duration.ofMillis(500)); PutOperationWithRequestCompressionRequest request = PutOperationWithRequestCompressionRequest.builder() .body(SdkBytes.fromUtf8String(UNCOMPRESSED_BODY)) - .overrideConfiguration(o -> o.requestCompressionConfiguration( - c -> c.minimumCompressionThresholdInBytes(1))) .build(); + syncClient.putOperationWithRequestCompression(request); - asyncClient.putOperationWithRequestCompression(request); - - SdkHttpFullRequest loggedRequest = (SdkHttpFullRequest) mockAsyncHttpClient.getLastRequest(); + SdkHttpFullRequest loggedRequest = (SdkHttpFullRequest) mockHttpClient.getLastRequest(); InputStream loggedStream = loggedRequest.contentStreamProvider().get().newStream(); String loggedBody = new String(SdkBytes.fromInputStream(loggedStream).asByteArray()); - int loggedSize = Integer.valueOf(loggedRequest.firstMatchingHeader("Content-Length").get()); - assertThat(loggedBody).isEqualTo(compressedBody); - assertThat(loggedSize).isEqualTo(compressedLen); - assertThat(loggedRequest.firstMatchingHeader("Content-encoding").get()).isEqualTo("gzip"); + assertThat(loggedBody).isEqualTo(UNCOMPRESSED_BODY); + assertThat(loggedRequest.firstMatchingHeader("Content-encoding")).isEmpty(); } @Test - public void sync_streaming_compression_compressesCorrectly() { + public void syncStreamingOperation_compressionEnabled_compressesCorrectly() { mockHttpClient.stubNextResponse(mockResponse(), Duration.ofMillis(500)); PutOperationWithStreamingRequestCompressionRequest request = @@ -151,7 +134,7 @@ public void sync_streaming_compression_compressesCorrectly() { } @Test - public void sync_nonStreaming_compression_withRetry_compressesCorrectly() { + public void syncNonStreamingOperation_compressionEnabledThresholdOverriddenWithRetry_compressesCorrectly() { mockHttpClient.stubNextResponse(mockErrorResponse(), Duration.ofMillis(500)); mockHttpClient.stubNextResponse(mockResponse(), Duration.ofMillis(500)); @@ -174,31 +157,7 @@ public void sync_nonStreaming_compression_withRetry_compressesCorrectly() { } @Test - public void async_nonStreaming_compression_withRetry_compressesCorrectly() { - mockAsyncHttpClient.stubNextResponse(mockErrorResponse(), Duration.ofMillis(500)); - mockAsyncHttpClient.stubNextResponse(mockResponse(), Duration.ofMillis(500)); - - PutOperationWithRequestCompressionRequest request = - PutOperationWithRequestCompressionRequest.builder() - .body(SdkBytes.fromUtf8String(UNCOMPRESSED_BODY)) - .overrideConfiguration(o -> o.requestCompressionConfiguration( - c -> c.minimumCompressionThresholdInBytes(1))) - .build(); - - asyncClient.putOperationWithRequestCompression(request); - - SdkHttpFullRequest loggedRequest = (SdkHttpFullRequest) mockAsyncHttpClient.getLastRequest(); - InputStream loggedStream = loggedRequest.contentStreamProvider().get().newStream(); - String loggedBody = new String(SdkBytes.fromInputStream(loggedStream).asByteArray()); - int loggedSize = Integer.valueOf(loggedRequest.firstMatchingHeader("Content-Length").get()); - - assertThat(loggedBody).isEqualTo(compressedBody); - assertThat(loggedSize).isEqualTo(compressedLen); - assertThat(loggedRequest.firstMatchingHeader("Content-encoding").get()).isEqualTo("gzip"); - } - - @Test - public void sync_streaming_compression_withRetry_compressesCorrectly() { + public void syncStreamingOperation_compressionEnabledWithRetry_compressesCorrectly() { mockHttpClient.stubNextResponse(mockErrorResponse(), Duration.ofMillis(500)); mockHttpClient.stubNextResponse(mockResponse(), Duration.ofMillis(500)); diff --git a/test/service-test-utils/src/main/java/software/amazon/awssdk/testutils/service/http/MockAsyncHttpClient.java b/test/service-test-utils/src/main/java/software/amazon/awssdk/testutils/service/http/MockAsyncHttpClient.java index 8a3f62f7838e..16a7732cb186 100644 --- a/test/service-test-utils/src/main/java/software/amazon/awssdk/testutils/service/http/MockAsyncHttpClient.java +++ b/test/service-test-utils/src/main/java/software/amazon/awssdk/testutils/service/http/MockAsyncHttpClient.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicInteger; @@ -50,6 +51,8 @@ public final class MockAsyncHttpClient implements SdkAsyncHttpClient, MockHttpCl private final List> responses = new LinkedList<>(); private final AtomicInteger responseIndex = new AtomicInteger(0); private final ExecutorService executor; + private int asyncRequestBodyLength = -1; + private byte[] streamingPayload; public MockAsyncHttpClient() { this.executor = Executors.newFixedThreadPool(3); @@ -66,6 +69,11 @@ public CompletableFuture execute(AsyncExecuteRequest request) { request.responseHandler().onHeaders(nextResponse.httpResponse()); CompletableFuture.runAsync(() -> request.responseHandler().onStream(new ResponsePublisher(content, index)), executor); + + if (asyncRequestBodyLength > 0) { + captureStreamingPayload(request.requestContentPublisher()); + } + return CompletableFuture.completedFuture(null); } @@ -122,7 +130,28 @@ public void stubResponses(HttpExecuteResponse... responses) { this.responseIndex.set(0); } - private class ResponsePublisher implements SdkHttpContentPublisher { + /** + * Enable capturing the streaming payload by setting the length of the AsyncRequestBody. + */ + public void setAsyncRequestBodyLength(int asyncRequestBodyLength) { + this.asyncRequestBodyLength = asyncRequestBodyLength; + } + + private void captureStreamingPayload(SdkHttpContentPublisher publisher) { + ByteBuffer byteBuffer = ByteBuffer.allocate(asyncRequestBodyLength); + Subscriber subscriber = new CapturingSubscriber(byteBuffer); + publisher.subscribe(subscriber); + streamingPayload = byteBuffer.array(); + } + + /** + * Returns the streaming payload byte array, if the asyncRequestBodyLength was set correctly. Otherwise, returns null. + */ + public byte[] getStreamingPayload() { + return streamingPayload.clone(); + } + + private final class ResponsePublisher implements SdkHttpContentPublisher { private final byte[] content; private final int index; @@ -165,4 +194,35 @@ public void cancel() { }); } } + + private static class CapturingSubscriber implements Subscriber { + private ByteBuffer byteBuffer; + private CountDownLatch done = new CountDownLatch(1); + + CapturingSubscriber(ByteBuffer byteBuffer) { + this.byteBuffer = byteBuffer; + } + + @Override + public void onSubscribe(Subscription subscription) { + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(ByteBuffer buffer) { + byte[] bytes = new byte[buffer.remaining()]; + buffer.get(bytes); + byteBuffer.put(bytes); + } + + @Override + public void onError(Throwable t) { + done.countDown(); + } + + @Override + public void onComplete() { + done.countDown(); + } + } }