diff --git a/sdk/core/azure-core-http-jdk-httpclient/src/main/java/com/azure/core/http/jdk/httpclient/BufferedJdkHttpResponse.java b/sdk/core/azure-core-http-jdk-httpclient/src/main/java/com/azure/core/http/jdk/httpclient/BufferedJdkHttpResponse.java index ab2abb3e8c54..5ba0d7827966 100644 --- a/sdk/core/azure-core-http-jdk-httpclient/src/main/java/com/azure/core/http/jdk/httpclient/BufferedJdkHttpResponse.java +++ b/sdk/core/azure-core-http-jdk-httpclient/src/main/java/com/azure/core/http/jdk/httpclient/BufferedJdkHttpResponse.java @@ -5,6 +5,7 @@ import com.azure.core.http.HttpHeaders; import com.azure.core.http.HttpRequest; +import com.azure.core.http.HttpResponse; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -30,4 +31,9 @@ public Flux getBody() { public Mono getBodyAsByteArray() { return Mono.defer(() -> Mono.just(body)); } + + @Override + public HttpResponse buffer() { + return this; // This response is already buffered. + } } diff --git a/sdk/core/azure-core-http-jdk-httpclient/src/main/java/com/azure/core/http/jdk/httpclient/JdkAsyncHttpClient.java b/sdk/core/azure-core-http-jdk-httpclient/src/main/java/com/azure/core/http/jdk/httpclient/JdkAsyncHttpClient.java index 2d5a4100d57a..2a68332625c2 100644 --- a/sdk/core/azure-core-http-jdk-httpclient/src/main/java/com/azure/core/http/jdk/httpclient/JdkAsyncHttpClient.java +++ b/sdk/core/azure-core-http-jdk-httpclient/src/main/java/com/azure/core/http/jdk/httpclient/JdkAsyncHttpClient.java @@ -66,9 +66,9 @@ public Mono send(HttpRequest request, Context context) { int statusCode = innerResponse.statusCode(); HttpHeaders headers = fromJdkHttpHeaders(innerResponse.headers()); - return FluxUtil.collectBytesInByteBufferStream(JdkFlowAdapter + return FluxUtil.collectBytesFromNetworkResponse(JdkFlowAdapter .flowPublisherToFlux(innerResponse.body()) - .flatMapSequential(Flux::fromIterable)) + .flatMapSequential(Flux::fromIterable), headers) .map(bytes -> new BufferedJdkHttpResponse(request, statusCode, headers, bytes)); } else { return Mono.just(new JdkHttpResponse(request, innerResponse)); diff --git a/sdk/core/azure-core-http-jdk-httpclient/src/main/java/com/azure/core/http/jdk/httpclient/JdkHttpResponse.java b/sdk/core/azure-core-http-jdk-httpclient/src/main/java/com/azure/core/http/jdk/httpclient/JdkHttpResponse.java index 49de1f564ef8..76d4a091d922 100644 --- a/sdk/core/azure-core-http-jdk-httpclient/src/main/java/com/azure/core/http/jdk/httpclient/JdkHttpResponse.java +++ b/sdk/core/azure-core-http-jdk-httpclient/src/main/java/com/azure/core/http/jdk/httpclient/JdkHttpResponse.java @@ -33,7 +33,7 @@ public Flux getBody() { @Override public Mono getBodyAsByteArray() { - return FluxUtil.collectBytesInByteBufferStream(getBody()); + return FluxUtil.collectBytesFromNetworkResponse(getBody(), getHeaders()); } @Override diff --git a/sdk/core/azure-core-http-netty/src/main/java/com/azure/core/http/netty/NettyAsyncHttpClient.java b/sdk/core/azure-core-http-netty/src/main/java/com/azure/core/http/netty/NettyAsyncHttpClient.java index db5378f78af8..e0fea44bf661 100644 --- a/sdk/core/azure-core-http-netty/src/main/java/com/azure/core/http/netty/NettyAsyncHttpClient.java +++ b/sdk/core/azure-core-http-netty/src/main/java/com/azure/core/http/netty/NettyAsyncHttpClient.java @@ -8,8 +8,9 @@ import com.azure.core.http.HttpRequest; import com.azure.core.http.HttpResponse; import com.azure.core.http.ProxyOptions; -import com.azure.core.http.netty.implementation.NettyAsyncHttpResponse; import com.azure.core.http.netty.implementation.NettyAsyncHttpBufferedResponse; +import com.azure.core.http.netty.implementation.NettyAsyncHttpResponse; +import com.azure.core.http.netty.implementation.NettyToAzureCoreHttpHeadersWrapper; import com.azure.core.util.Context; import com.azure.core.util.FluxUtil; import io.netty.buffer.ByteBuf; @@ -144,7 +145,8 @@ private static BiFunction body = reactorNettyConnection.inbound().receive().asByteBuffer() .doFinally(ignored -> closeConnection(reactorNettyConnection)); - return FluxUtil.collectBytesInByteBufferStream(body) + return FluxUtil.collectBytesFromNetworkResponse(body, + new NettyToAzureCoreHttpHeadersWrapper(reactorNettyResponse.responseHeaders())) .map(bytes -> new NettyAsyncHttpBufferedResponse(reactorNettyResponse, restRequest, bytes)); } else { diff --git a/sdk/core/azure-core-http-netty/src/main/java/com/azure/core/http/netty/implementation/NettyAsyncHttpBufferedResponse.java b/sdk/core/azure-core-http-netty/src/main/java/com/azure/core/http/netty/implementation/NettyAsyncHttpBufferedResponse.java index fe8fd053ff84..102b40e02e8c 100644 --- a/sdk/core/azure-core-http-netty/src/main/java/com/azure/core/http/netty/implementation/NettyAsyncHttpBufferedResponse.java +++ b/sdk/core/azure-core-http-netty/src/main/java/com/azure/core/http/netty/implementation/NettyAsyncHttpBufferedResponse.java @@ -4,6 +4,7 @@ package com.azure.core.http.netty.implementation; import com.azure.core.http.HttpRequest; +import com.azure.core.http.HttpResponse; import com.azure.core.util.CoreUtils; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -42,4 +43,9 @@ public Mono getBodyAsString() { public Mono getBodyAsString(Charset charset) { return Mono.defer(() -> Mono.just(new String(body, charset))); } + + @Override + public HttpResponse buffer() { + return this; // This response is already buffered. + } } diff --git a/sdk/core/azure-core-http-okhttp/src/main/java/com/azure/core/http/okhttp/implementation/OkHttpAsyncBufferedResponse.java b/sdk/core/azure-core-http-okhttp/src/main/java/com/azure/core/http/okhttp/implementation/OkHttpAsyncBufferedResponse.java index f28c07916069..9f3d4096a2e5 100644 --- a/sdk/core/azure-core-http-okhttp/src/main/java/com/azure/core/http/okhttp/implementation/OkHttpAsyncBufferedResponse.java +++ b/sdk/core/azure-core-http-okhttp/src/main/java/com/azure/core/http/okhttp/implementation/OkHttpAsyncBufferedResponse.java @@ -4,6 +4,7 @@ package com.azure.core.http.okhttp.implementation; import com.azure.core.http.HttpRequest; +import com.azure.core.http.HttpResponse; import okhttp3.Response; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -30,4 +31,9 @@ public Flux getBody() { public Mono getBodyAsByteArray() { return Mono.defer(() -> Mono.just(body)); } + + @Override + public HttpResponse buffer() { + return this; // This response is already buffered. + } } diff --git a/sdk/core/azure-core-test/src/main/java/com/azure/core/test/http/MockHttpResponse.java b/sdk/core/azure-core-test/src/main/java/com/azure/core/test/http/MockHttpResponse.java index 172d7491574e..a85e3e95305c 100644 --- a/sdk/core/azure-core-test/src/main/java/com/azure/core/test/http/MockHttpResponse.java +++ b/sdk/core/azure-core-test/src/main/java/com/azure/core/test/http/MockHttpResponse.java @@ -197,4 +197,9 @@ public MockHttpResponse addHeader(String name, String value) { headers.set(name, value); return this; } + + @Override + public HttpResponse buffer() { + return this; // This response is already buffered. + } } diff --git a/sdk/core/azure-core/pom.xml b/sdk/core/azure-core/pom.xml index 86d3b907c4f6..9d2f203509cb 100644 --- a/sdk/core/azure-core/pom.xml +++ b/sdk/core/azure-core/pom.xml @@ -239,6 +239,7 @@ --add-opens com.azure.core/com.azure.core.http=ALL-UNNAMED --add-opens com.azure.core/com.azure.core.http.policy=ALL-UNNAMED --add-opens com.azure.core/com.azure.core.http.rest=ALL-UNNAMED + --add-opens com.azure.core/com.azure.core.implementation=ALL-UNNAMED --add-opens com.azure.core/com.azure.core.implementation.entities=com.fasterxml.jackson.databind --add-opens com.azure.core/com.azure.core.implementation.entities=ALL-UNNAMED --add-opens com.azure.core/com.azure.core.implementation.http=ALL-UNNAMED diff --git a/sdk/core/azure-core/src/main/java/com/azure/core/http/rest/ResponseConstructorsCache.java b/sdk/core/azure-core/src/main/java/com/azure/core/http/rest/ResponseConstructorsCache.java index 57b8ef512572..9a99221566cf 100644 --- a/sdk/core/azure-core/src/main/java/com/azure/core/http/rest/ResponseConstructorsCache.java +++ b/sdk/core/azure-core/src/main/java/com/azure/core/http/rest/ResponseConstructorsCache.java @@ -79,8 +79,8 @@ private Constructor> locateResponseConstructor(Class re * @return an instance of a {@link Response} implementation */ Mono> invoke(final Constructor> constructor, - final HttpResponseDecoder.HttpDecodedResponse decodedResponse, - final Object bodyAsObject) { + final HttpResponseDecoder.HttpDecodedResponse decodedResponse, + final Object bodyAsObject) { final HttpResponse httpResponse = decodedResponse.getSourceResponse(); final HttpRequest httpRequest = httpResponse.getRequest(); final int responseStatusCode = httpResponse.getStatusCode(); diff --git a/sdk/core/azure-core/src/main/java/com/azure/core/implementation/ByteBufferCollector.java b/sdk/core/azure-core/src/main/java/com/azure/core/implementation/ByteBufferCollector.java new file mode 100644 index 000000000000..5fbd59c74188 --- /dev/null +++ b/sdk/core/azure-core/src/main/java/com/azure/core/implementation/ByteBufferCollector.java @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.core.implementation; + +import com.azure.core.util.logging.ClientLogger; + +import java.io.ByteArrayOutputStream; +import java.nio.ByteBuffer; +import java.util.Arrays; + +/** + * This class offers functionality similar to {@link ByteArrayOutputStream} but instead of consuming byte arrays it + * consumes ByteBuffers. This class is optimized to reduce the number of memory copies by directly writing a passed + * ByteBuffers data directly into its backing byte array, this differs from handling for {@link ByteArrayOutputStream} + * where ByteBuffer data may need to be first copied into a temporary buffer resulting in an extra memory copy. + */ +public final class ByteBufferCollector { + /* + * Start with a default size of 1 KB as this is small enough to be performant while covering most small response + * sizes. + */ + private static final int DEFAULT_INITIAL_SIZE = 1024; + + private static final String INVALID_INITIAL_SIZE = "'initialSize' cannot be equal to or less than 0."; + private static final String REQUESTED_BUFFER_INVALID = "Required capacity is greater than Integer.MAX_VALUE."; + + private final ClientLogger logger = new ClientLogger(ByteBufferCollector.class); + + private byte[] buffer; + private int position; + + /** + * Constructs a new ByteBufferCollector instance with a default sized backing array. + */ + public ByteBufferCollector() { + this(DEFAULT_INITIAL_SIZE); + } + + /** + * Constructs a new ByteBufferCollector instance with a specified initial size. + * + * @param initialSize The initial size for the backing array. + * @throws IllegalArgumentException If {@code initialSize} is equal to or less than {@code 0}. + */ + public ByteBufferCollector(int initialSize) { + if (initialSize <= 0) { + throw logger.logExceptionAsError(new IllegalArgumentException(INVALID_INITIAL_SIZE)); + } + + this.buffer = new byte[initialSize]; + this.position = 0; + } + + /** + * Writes a ByteBuffers content into the backing array. + * + * @param byteBuffer The ByteBuffer to concatenate into the collector. + * @throws IllegalStateException If the size of the backing array would be larger than {@link Integer#MAX_VALUE} + * when the passed buffer is written. + */ + public synchronized void write(ByteBuffer byteBuffer) { + // Null buffer. + if (byteBuffer == null) { + return; + } + + int remaining = byteBuffer.remaining(); + + // Nothing to write. + if (remaining == 0) { + return; + } + + ensureCapacity(remaining); + byteBuffer.get(buffer, position, remaining); + position += remaining; + } + + /** + * Creates a copy of the backing array resized to the number of bytes written into the collector. + * + * @return A copy of the backing array. + */ + public synchronized byte[] toByteArray() { + return Arrays.copyOf(buffer, position); + } + + /* + * This method ensures that the backing buffer has sufficient space to write the data from the passed ByteBuffer. + */ + private void ensureCapacity(int byteBufferRemaining) throws OutOfMemoryError { + int currentCapacity = buffer.length; + int requiredCapacity = position + byteBufferRemaining; + + /* + * This validates that adding the current capacity and ByteBuffer remaining doesn't result in an integer + * overflow response by checking that the result uses the same sign as both of the addition arguments. + */ + if (((position ^ requiredCapacity) & (byteBufferRemaining ^ requiredCapacity)) < 0) { + throw logger.logExceptionAsError(new IllegalStateException(REQUESTED_BUFFER_INVALID)); + } + + // Buffer is already large enough to accept the data being written. + if (currentCapacity >= requiredCapacity) { + return; + } + + // Propose a new capacity that is double the size of the current capacity. + int proposedNewCapacity = currentCapacity << 1; + + // If the proposed capacity is less than the required capacity use the required capacity. + // Subtraction is used instead of a direct comparison as the bit shift could overflow into a negative int. + if ((proposedNewCapacity - requiredCapacity) < 0) { + proposedNewCapacity = requiredCapacity; + } + + // If the proposed capacity doubling overflowed integer use a slightly smaller size than max value. + if (proposedNewCapacity < 0) { + proposedNewCapacity = Integer.MAX_VALUE - 8; + } + + buffer = Arrays.copyOf(buffer, proposedNewCapacity); + } +} diff --git a/sdk/core/azure-core/src/main/java/com/azure/core/implementation/http/BufferedHttpResponse.java b/sdk/core/azure-core/src/main/java/com/azure/core/implementation/http/BufferedHttpResponse.java index aca006f5a7a2..3236362ea337 100644 --- a/sdk/core/azure-core/src/main/java/com/azure/core/implementation/http/BufferedHttpResponse.java +++ b/sdk/core/azure-core/src/main/java/com/azure/core/implementation/http/BufferedHttpResponse.java @@ -28,7 +28,8 @@ public final class BufferedHttpResponse extends HttpResponse { public BufferedHttpResponse(HttpResponse innerHttpResponse) { super(innerHttpResponse.getRequest()); this.innerHttpResponse = innerHttpResponse; - this.cachedBody = FluxUtil.collectBytesInByteBufferStream(innerHttpResponse.getBody()) + this.cachedBody = FluxUtil.collectBytesFromNetworkResponse(innerHttpResponse.getBody(), + innerHttpResponse.getHeaders()) .map(ByteBuffer::wrap) .flux() .cache() diff --git a/sdk/core/azure-core/src/main/java/com/azure/core/implementation/serializer/HttpResponseHeaderDecoder.java b/sdk/core/azure-core/src/main/java/com/azure/core/implementation/serializer/HttpResponseHeaderDecoder.java index b0f44d481d5d..6b211343bca1 100644 --- a/sdk/core/azure-core/src/main/java/com/azure/core/implementation/serializer/HttpResponseHeaderDecoder.java +++ b/sdk/core/azure-core/src/main/java/com/azure/core/implementation/serializer/HttpResponseHeaderDecoder.java @@ -34,8 +34,8 @@ static Mono decode(HttpResponse httpResponse, SerializerAdapter serializ return Mono.empty(); } else { return Mono.fromCallable(() -> serializer.deserialize(httpResponse.getHeaders(), headerType)) - .onErrorResume(IOException.class, e -> Mono.error(new HttpResponseException( - "HTTP response has malformed headers", httpResponse, e))); + .onErrorMap(IOException.class, e -> new HttpResponseException("HTTP response has malformed headers", + httpResponse, e)); } } } diff --git a/sdk/core/azure-core/src/main/java/com/azure/core/util/FluxUtil.java b/sdk/core/azure-core/src/main/java/com/azure/core/util/FluxUtil.java index 62dda42aad99..2aa9627f33b2 100644 --- a/sdk/core/azure-core/src/main/java/com/azure/core/util/FluxUtil.java +++ b/sdk/core/azure-core/src/main/java/com/azure/core/util/FluxUtil.java @@ -3,8 +3,10 @@ package com.azure.core.util; +import com.azure.core.http.HttpHeaders; import com.azure.core.http.rest.PagedFlux; import com.azure.core.http.rest.Response; +import com.azure.core.implementation.ByteBufferCollector; import com.azure.core.implementation.TypeUtil; import com.azure.core.util.logging.ClientLogger; import org.reactivestreams.Subscriber; @@ -15,7 +17,6 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; -import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.lang.reflect.Type; @@ -25,6 +26,7 @@ import java.util.Collections; import java.util.Map; import java.util.Map.Entry; +import java.util.Objects; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import java.util.function.Function; @@ -34,6 +36,8 @@ * Utility type exposing methods to deal with {@link Flux}. */ public final class FluxUtil { + private static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; + /** * Checks if a type is Flux<ByteBuffer>. * @@ -49,22 +53,68 @@ public static boolean isFluxByteBuffer(Type entityType) { } /** - * Collects ByteBuffer emitted by a Flux into a byte array. + * Collects ByteBuffers emitted by a Flux into a byte array. * * @param stream A stream which emits ByteBuffer instances. * @return A Mono which emits the concatenation of all the ByteBuffer instances given by the source Flux. + * @throws IllegalStateException If the combined size of the emitted ByteBuffers is greater than {@link + * Integer#MAX_VALUE}. */ public static Mono collectBytesInByteBufferStream(Flux stream) { - return stream - .collect(ByteArrayOutputStream::new, FluxUtil::accept) - .map(ByteArrayOutputStream::toByteArray); + return stream.collect(ByteBufferCollector::new, ByteBufferCollector::write) + .map(ByteBufferCollector::toByteArray); } - private static void accept(ByteArrayOutputStream byteOutputStream, ByteBuffer byteBuffer) { - try { - byteOutputStream.write(byteBufferToArray(byteBuffer)); - } catch (IOException e) { - throw new RuntimeException("Error occurred writing ByteBuffer to ByteArrayOutputStream.", e); + /** + * Collects ByteBuffers emitted by a Flux into a byte array. + *

+ * Unlike {@link #collectBytesInByteBufferStream(Flux)}, this method accepts a second parameter {@code sizeHint}. + * This size hint allows for optimizations when creating the initial buffer to reduce the number of times it needs + * to be resized while concatenating emitted ByteBuffers. + * + * @param stream A stream which emits ByteBuffer instances. + * @param sizeHint A hint about the expected stream size. + * @return A Mono which emits the concatenation of all the ByteBuffer instances given by the source Flux. + * @throws IllegalArgumentException If {@code sizeHint} is equal to or less than {@code 0}. + * @throws IllegalStateException If the combined size of the emitted ByteBuffers is greater than {@link + * Integer#MAX_VALUE}. + */ + public static Mono collectBytesInByteBufferStream(Flux stream, int sizeHint) { + return stream.collect(() -> new ByteBufferCollector(sizeHint), ByteBufferCollector::write) + .map(ByteBufferCollector::toByteArray); + } + + /** + * Collects ByteBuffers returned in a network response into a byte array. + *

+ * The {@code headers} are inspected for containing an {@code Content-Length} which determines if a size hinted + * collection, {@link #collectBytesInByteBufferStream(Flux, int)}, or default collection, + * {@link #collectBytesInByteBufferStream(Flux)}, will be used. + * + * @param stream A network response ByteBuffer stream. + * @param headers The HTTP headers of the response. + * @return A Mono which emits the collected network response ByteBuffers. + * @throws NullPointerException If {@code headers} is null. + * @throws IllegalStateException If the size of the network response is greater than {@link Integer#MAX_VALUE}. + */ + public static Mono collectBytesFromNetworkResponse(Flux stream, HttpHeaders headers) { + Objects.requireNonNull(headers, "'headers' cannot be null."); + + String contentLengthHeader = headers.getValue("Content-Length"); + + if (contentLengthHeader == null) { + return FluxUtil.collectBytesInByteBufferStream(stream); + } else { + try { + int contentLength = Integer.parseInt(contentLengthHeader); + if (contentLength > 0) { + return FluxUtil.collectBytesInByteBufferStream(stream, contentLength); + } else { + return Mono.just(EMPTY_BYTE_ARRAY); + } + } catch (NumberFormatException ex) { + return FluxUtil.collectBytesInByteBufferStream(stream); + } } } @@ -187,7 +237,7 @@ public static Mono withContext(Function> serviceCall, Map contextAttributes) { return Mono.subscriberContext() .map(context -> { - final Context[] azureContext = new Context[] { Context.NONE }; + final Context[] azureContext = new Context[]{Context.NONE}; if (!CoreUtils.isNullOrEmpty(contextAttributes)) { contextAttributes.forEach((key, value) -> azureContext[0] = azureContext[0].addData(key, value)); @@ -279,7 +329,7 @@ public static Flux fluxContext(Function> serviceCall) { * @return The azure context */ private static Context toAzureContext(reactor.util.context.Context context) { - final Context[] azureContext = new Context[] { Context.NONE }; + final Context[] azureContext = new Context[]{Context.NONE}; if (!context.isEmpty()) { context.stream().forEach(entry -> diff --git a/sdk/core/azure-core/src/test/java/com/azure/core/implementation/ByteBufferCollectorTests.java b/sdk/core/azure-core/src/test/java/com/azure/core/implementation/ByteBufferCollectorTests.java new file mode 100644 index 000000000000..416c0265d956 --- /dev/null +++ b/sdk/core/azure-core/src/test/java/com/azure/core/implementation/ByteBufferCollectorTests.java @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.core.implementation; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +/** + * Tests {@link ByteBufferCollector}. + */ +public class ByteBufferCollectorTests { + @Test + public void throwsOnZeroInitialSize() { + assertThrows(IllegalArgumentException.class, () -> new ByteBufferCollector(0)); + } + + @Test + public void throwsOnNegativeInitialSize() { + assertThrows(IllegalArgumentException.class, () -> new ByteBufferCollector(-1)); + } + + @Test + public void throwsIllegalStateExceptionOnBufferRequirementTooLarge() { + /* + * This assumption validates that the JVM running this test has a maximum heap size large enough for the test + * to run without triggering an OutOfMemoryError. + */ + assumeTrue(Runtime.getRuntime().maxMemory() > (Integer.MAX_VALUE * 1.5), + "JVM doesn't have the requisite max heap size to support running this test."); + + ByteBuffer buffer = ByteBuffer.allocate((Integer.MAX_VALUE / 2) + 1); + + ByteBufferCollector collector = new ByteBufferCollector(); + collector.write(buffer.duplicate()); + + assertThrows(IllegalStateException.class, () -> collector.write(buffer.duplicate())); + } + + @ParameterizedTest + @MethodSource("combineBuffersSupplier") + public void combineBuffers(List buffers, byte[] expected) { + ByteBufferCollector collector = new ByteBufferCollector(); + + buffers.forEach(collector::write); + + assertArrayEquals(expected, collector.toByteArray()); + } + + private static Stream combineBuffersSupplier() { + byte[] helloWorldBytes = "Hello world!".getBytes(StandardCharsets.UTF_8); + ByteBuffer helloBuffer = ByteBuffer.wrap("Hello".getBytes(StandardCharsets.UTF_8)); + ByteBuffer worldBuffer = ByteBuffer.wrap(" world!".getBytes(StandardCharsets.UTF_8)); + + int helloWorldLength = helloWorldBytes.length; + byte[] manyHelloWorldsBytes = new byte[helloWorldLength * 100]; + List manyHelloWorldsByteBuffers = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + System.arraycopy(helloWorldBytes, 0, manyHelloWorldsBytes, i * helloWorldLength, helloWorldLength); + manyHelloWorldsByteBuffers.add(helloBuffer.duplicate()); + manyHelloWorldsByteBuffers.add(worldBuffer.duplicate()); + } + + return Stream.of( + // All buffers are null. + Arguments.of(Arrays.asList(null, null), new byte[0]), + + // All buffers are empty. + Arguments.of(Arrays.asList(ByteBuffer.allocate(0), ByteBuffer.allocate(0)), new byte[0]), + + // Hello world buffers. + Arguments.of(Arrays.asList(helloBuffer.duplicate(), worldBuffer.duplicate()), helloWorldBytes), + + // Many hello world buffers. + Arguments.of(manyHelloWorldsByteBuffers, manyHelloWorldsBytes) + ); + } +}