diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java index f48a3143fd016..a309877e9aa83 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java @@ -53,6 +53,7 @@ import org.elasticsearch.http.HttpServerTransport; import org.elasticsearch.http.netty4.internal.HttpHeadersAuthenticatorUtils; import org.elasticsearch.http.netty4.internal.HttpValidator; +import org.elasticsearch.rest.ChunkedZipResponse; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.netty4.AcceptChannelHandler; @@ -382,7 +383,16 @@ protected boolean isContentAlwaysEmpty(HttpResponse msg) { }) .addLast("aggregator", aggregator); if (handlingSettings.compression()) { - ch.pipeline().addLast("encoder_compress", new HttpContentCompressor(handlingSettings.compressionLevel())); + ch.pipeline().addLast("encoder_compress", new HttpContentCompressor(handlingSettings.compressionLevel()) { + @Override + protected Result beginEncode(HttpResponse httpResponse, String acceptEncoding) throws Exception { + if (ChunkedZipResponse.ZIP_CONTENT_TYPE.equals(httpResponse.headers().get("content-type"))) { + return null; + } else { + return super.beginEncode(httpResponse, acceptEncoding); + } + } + }); } ch.pipeline() .addLast( diff --git a/server/src/internalClusterTest/java/org/elasticsearch/rest/ChunkedZipResponseIT.java b/server/src/internalClusterTest/java/org/elasticsearch/rest/ChunkedZipResponseIT.java new file mode 100644 index 0000000000000..0fbdcf7b59a7c --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/rest/ChunkedZipResponseIT.java @@ -0,0 +1,502 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.rest; + +import org.apache.http.ConnectionClosedException; +import org.apache.http.HttpResponse; +import org.apache.http.MalformedChunkCodingException; +import org.apache.http.nio.ContentDecoder; +import org.apache.http.nio.IOControl; +import org.apache.http.nio.protocol.HttpAsyncResponseConsumer; +import org.apache.http.protocol.HttpContext; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRunnable; +import org.elasticsearch.action.support.ActionTestUtils; +import org.elasticsearch.action.support.RefCountingRunnable; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.RequestOptions; +import org.elasticsearch.client.Response; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.recycler.Recycler; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.IndexScopedSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.settings.SettingsFilter; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.util.CollectionUtils; +import 
org.elasticsearch.common.util.Maps; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.ThrottledIterator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.plugins.ActionPlugin; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.PluginsService; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Predicate; +import java.util.function.Supplier; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import java.util.zip.ZipEntry; +import java.util.zip.ZipInputStream; + +import static org.elasticsearch.rest.ChunkedZipResponse.ZIP_CONTENT_TYPE; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.startsWith; + +@ESIntegTestCase.ClusterScope(numDataNodes = 1) +public class ChunkedZipResponseIT extends ESIntegTestCase { + + @Override + protected boolean addMockHttpTransport() { + return false; + } + + @Override + protected Collection> nodePlugins() { + return CollectionUtils.appendToCopyNoNullElements(super.nodePlugins(), RandomZipResponsePlugin.class); + } + + public static class RandomZipResponsePlugin extends Plugin implements ActionPlugin { + + public static final 
String ROUTE = "/_random_zip_response"; + public static final String RESPONSE_FILENAME = "test-response"; + + public static final String INFINITE_ROUTE = "/_infinite_zip_response"; + public static final String GET_NEXT_PART_COUNT_DOWN_PARAM = "getNextPartCountDown"; + + public final AtomicReference responseRef = new AtomicReference<>(); + + public record EntryPart(List chunks) { + public EntryPart { + Objects.requireNonNull(chunks); + } + } + + public record EntryBody(List parts) { + public EntryBody { + Objects.requireNonNull(parts); + } + } + + public record Response(Map entries, CountDownLatch completedLatch) {} + + @Override + public Collection getRestHandlers( + Settings settings, + NamedWriteableRegistry namedWriteableRegistry, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster, + Predicate clusterSupportsFeature + ) { + return List.of(new RestHandler() { + @Override + public List routes() { + return List.of(new Route(RestRequest.Method.GET, ROUTE)); + } + + @Override + public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) { + final var response = new Response(new HashMap<>(), new CountDownLatch(1)); + final var maxSize = between(1, ByteSizeUnit.MB.toIntBytes(1)); + final var entryCount = between(0, ByteSizeUnit.MB.toIntBytes(10) / maxSize); // limit total size to 10MiB + for (int i = 0; i < entryCount; i++) { + response.entries().put(randomIdentifier(), randomContent(between(1, 10), maxSize)); + } + assertTrue(responseRef.compareAndSet(null, response)); + handleZipRestRequest( + channel, + client.threadPool(), + response.completedLatch(), + () -> {}, + response.entries().entrySet().iterator() + ); + } + }, new RestHandler() { + + @Override + public List routes() { + return List.of(new Route(RestRequest.Method.GET, INFINITE_ROUTE)); + } + + @Override + 
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) { + final var response = new Response(null, new CountDownLatch(1)); + assertTrue(responseRef.compareAndSet(null, response)); + final var getNextPartCountDown = request.paramAsInt(GET_NEXT_PART_COUNT_DOWN_PARAM, -1); + final Runnable onGetNextPart; + final Supplier entryBodySupplier; + if (getNextPartCountDown <= 1) { + onGetNextPart = () -> {}; + entryBodySupplier = () -> randomContent(between(1, 10), ByteSizeUnit.MB.toIntBytes(1)); + } else { + final AtomicInteger remaining = new AtomicInteger(getNextPartCountDown); + entryBodySupplier = () -> randomContent(between(2, 10), ByteSizeUnit.KB.toIntBytes(1)); + if (randomBoolean()) { + onGetNextPart = () -> { + final var newRemaining = remaining.decrementAndGet(); + assertThat(newRemaining, greaterThanOrEqualTo(0)); + if (newRemaining <= 0) { + throw new ElasticsearchException("simulated failure"); + } + }; + } else { + onGetNextPart = () -> { + if (remaining.decrementAndGet() == 0) { + request.getHttpChannel().close(); + } + }; + } + } + handleZipRestRequest(channel, client.threadPool(), response.completedLatch(), onGetNextPart, new Iterator<>() { + + private long id; + + // carry on yielding content even after the channel closes + private final Semaphore trailingContentPermits = new Semaphore(between(0, 20)); + + @Override + public boolean hasNext() { + return request.getHttpChannel().isOpen() || trailingContentPermits.tryAcquire(); + } + + @Override + public Map.Entry next() { + return new Map.Entry<>() { + private final String key = Long.toString(id++); + private final EntryBody content = entryBodySupplier.get(); + + @Override + public String getKey() { + return key; + } + + @Override + public EntryBody getValue() { + return content; + } + + @Override + public EntryBody setValue(EntryBody value) { + return fail(null, "must not setValue"); + } + }; + } + }); + } + }); + } + + private static EntryBody randomContent(int partCount, 
int maxSize) { + if (randomBoolean()) { + return null; + } + + final var maxPartSize = maxSize / partCount; + return new EntryBody(randomList(partCount, partCount, () -> { + final var chunkCount = between(1, 10); + return randomEntryPart(chunkCount, maxPartSize / chunkCount); + })); + } + + private static EntryPart randomEntryPart(int chunkCount, int maxChunkSize) { + final var chunks = randomList(chunkCount, chunkCount, () -> randomBytesReference(between(0, maxChunkSize))); + Collections.shuffle(chunks, random()); + return new EntryPart(chunks); + } + + private static void handleZipRestRequest( + RestChannel channel, + ThreadPool threadPool, + CountDownLatch completionLatch, + Runnable onGetNextPart, + Iterator> entryIterator + ) { + try (var refs = new RefCountingRunnable(completionLatch::countDown)) { + final var chunkedZipResponse = new ChunkedZipResponse(RESPONSE_FILENAME, channel, refs.acquire()); + ThrottledIterator.run( + entryIterator, + (ref, entry) -> randomFrom(EsExecutors.DIRECT_EXECUTOR_SERVICE, threadPool.generic()).execute( + ActionRunnable.supply( + chunkedZipResponse.newEntryListener(entry.getKey(), Releasables.wrap(ref, refs.acquire())), + () -> entry.getValue() == null && randomBoolean() // randomBoolean() to allow some null entries to fail with NPE + ? 
null + : new TestBytesReferenceBodyPart( + entry.getKey(), + threadPool, + entry.getValue().parts().iterator(), + refs, + onGetNextPart + ) + ) + ), + between(1, 10), + () -> {}, + Releasables.wrap(refs.acquire(), chunkedZipResponse)::close + ); + } + } + } + + private static class TestBytesReferenceBodyPart implements ChunkedRestResponseBodyPart { + + private final String name; + private final ThreadPool threadPool; + private final Iterator chunksIterator; + private final Iterator partsIterator; + private final RefCountingRunnable refs; + private final Runnable onGetNextPart; + + TestBytesReferenceBodyPart( + String name, + ThreadPool threadPool, + Iterator partsIterator, + RefCountingRunnable refs, + Runnable onGetNextPart + ) { + this.onGetNextPart = onGetNextPart; + assert partsIterator.hasNext(); + this.name = name; + this.threadPool = threadPool; + this.partsIterator = partsIterator; + this.chunksIterator = partsIterator.next().chunks().iterator(); + this.refs = refs; + } + + @Override + public boolean isPartComplete() { + return chunksIterator.hasNext() == false; + } + + @Override + public boolean isLastPart() { + return partsIterator.hasNext() == false; + } + + @Override + public void getNextPart(ActionListener listener) { + threadPool.generic().execute(ActionRunnable.supply(listener, () -> { + onGetNextPart.run(); + return new TestBytesReferenceBodyPart(name, threadPool, partsIterator, refs, onGetNextPart); + })); + } + + @Override + public ReleasableBytesReference encodeChunk(int sizeHint, Recycler recycler) { + assert chunksIterator.hasNext(); + return new ReleasableBytesReference(chunksIterator.next(), refs.acquire()); + } + + @Override + public String getResponseContentTypeString() { + return "application/binary"; + } + } + + public void testRandomZipResponse() throws IOException { + final var request = new Request("GET", RandomZipResponsePlugin.ROUTE); + if (randomBoolean()) { + request.setOptions( + RequestOptions.DEFAULT.toBuilder() + 
.addHeader("accept-encoding", String.join(", ", randomSubsetOf(List.of("deflate", "gzip", "zstd", "br")))) + ); + } + final var response = getRestClient().performRequest(request); + assertEquals(ZIP_CONTENT_TYPE, response.getHeader("Content-Type")); + assertNull(response.getHeader("content-encoding")); // zip file is already compressed + assertEquals( + "attachment; filename=\"" + RandomZipResponsePlugin.RESPONSE_FILENAME + ".zip\"", + response.getHeader("Content-Disposition") + ); + final var pathPrefix = RandomZipResponsePlugin.RESPONSE_FILENAME + "/"; + + final var actualEntries = new HashMap(); + final var copyBuffer = new byte[PageCacheRecycler.BYTE_PAGE_SIZE]; + + try (var zipStream = new ZipInputStream(response.getEntity().getContent())) { + ZipEntry zipEntry; + while ((zipEntry = zipStream.getNextEntry()) != null) { + assertThat(zipEntry.getName(), startsWith(pathPrefix)); + final var name = zipEntry.getName().substring(pathPrefix.length()); + try (var bytesStream = new BytesStreamOutput()) { + while (true) { + final var readLength = zipStream.read(copyBuffer, 0, copyBuffer.length); + if (readLength < 0) { + break; + } + bytesStream.write(copyBuffer, 0, readLength); + } + actualEntries.put(name, bytesStream.bytes()); + } + } + } + + assertEquals(getExpectedEntries(), actualEntries); + } + + public void testAbort() throws IOException { + final var request = new Request("GET", RandomZipResponsePlugin.INFINITE_ROUTE); + final var responseStarted = new CountDownLatch(1); + final var bodyConsumed = new CountDownLatch(1); + request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> new HttpAsyncResponseConsumer<>() { + + final ByteBuffer readBuffer = ByteBuffer.allocate(ByteSizeUnit.KB.toIntBytes(4)); + int bytesToConsume = ByteSizeUnit.MB.toIntBytes(1); + + @Override + public void responseReceived(HttpResponse response) { + assertEquals("application/zip", response.getHeaders("Content-Type")[0].getValue()); + final var 
contentDispositionHeader = response.getHeaders("Content-Disposition")[0].getElements()[0]; + assertEquals("attachment", contentDispositionHeader.getName()); + assertEquals( + RandomZipResponsePlugin.RESPONSE_FILENAME + ".zip", + contentDispositionHeader.getParameterByName("filename").getValue() + ); + responseStarted.countDown(); + } + + @Override + public void consumeContent(ContentDecoder decoder, IOControl ioControl) throws IOException { + readBuffer.clear(); + final var bytesRead = decoder.read(readBuffer); + if (bytesRead > 0) { + bytesToConsume -= bytesRead; + } + + if (bytesToConsume <= 0) { + bodyConsumed.countDown(); + ioControl.shutdown(); + } + } + + @Override + public void responseCompleted(HttpContext context) {} + + @Override + public void failed(Exception ex) {} + + @Override + public Exception getException() { + return null; + } + + @Override + public HttpResponse getResult() { + return null; + } + + @Override + public boolean isDone() { + return false; + } + + @Override + public void close() {} + + @Override + public boolean cancel() { + return false; + } + })); + + try (var restClient = createRestClient(internalCluster().getRandomNodeName())) { + // one-node REST client to avoid retries + expectThrows(ConnectionClosedException.class, () -> restClient.performRequest(request)); + } + safeAwait(responseStarted); + safeAwait(bodyConsumed); + assertNull(getExpectedEntries()); // mainly just checking that all refs are released + } + + public void testGetNextPartFailure() throws IOException { + final var request = new Request("GET", RandomZipResponsePlugin.INFINITE_ROUTE); + request.addParameter(RandomZipResponsePlugin.GET_NEXT_PART_COUNT_DOWN_PARAM, Integer.toString(between(1, 100))); + + try (var restClient = createRestClient(internalCluster().getRandomNodeName())) { + // one-node REST client to avoid retries + assertThat( + safeAwaitFailure( + Response.class, + l -> restClient.performRequestAsync(request, ActionTestUtils.wrapAsRestResponseListener(l)) 
+ ), + anyOf(instanceOf(ConnectionClosedException.class), instanceOf(MalformedChunkCodingException.class)) + ); + } + assertNull(getExpectedEntries()); // mainly just checking that all refs are released + } + + private static Map getExpectedEntries() { + final List> nodeResponses = StreamSupport + // concatenate all the chunks in all the entries + .stream(internalCluster().getInstances(PluginsService.class).spliterator(), false) + .flatMap(p -> p.filterPlugins(RandomZipResponsePlugin.class)) + .flatMap(p -> { + final var maybeResponse = p.responseRef.getAndSet(null); + if (maybeResponse == null) { + return Stream.of(); + } else { + safeAwait(maybeResponse.completedLatch()); // ensures that all refs have been released + if (maybeResponse.entries() == null) { + return Stream.of((Map) null); + } else { + final var expectedEntries = Maps.newMapWithExpectedSize(maybeResponse.entries().size()); + maybeResponse.entries().forEach((entryName, entryBody) -> { + if (entryBody != null) { + try (var bytesStreamOutput = new BytesStreamOutput()) { + for (final var part : entryBody.parts()) { + for (final var chunk : part.chunks()) { + chunk.writeTo(bytesStreamOutput); + } + } + expectedEntries.put(entryName, bytesStreamOutput.bytes()); + } catch (IOException e) { + throw new AssertionError(e); + } + } + }); + return Stream.of(expectedEntries); + } + } + }) + .toList(); + assertThat(nodeResponses, hasSize(1)); + return nodeResponses.get(0); + } +} diff --git a/server/src/main/java/org/elasticsearch/rest/ChunkedZipResponse.java b/server/src/main/java/org/elasticsearch/rest/ChunkedZipResponse.java new file mode 100644 index 0000000000000..2d4ada2d3341a --- /dev/null +++ b/server/src/main/java/org/elasticsearch/rest/ChunkedZipResponse.java @@ -0,0 +1,553 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.rest; + +import org.apache.lucene.store.AlreadyClosedException; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.io.stream.BytesStream; +import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput; +import org.elasticsearch.common.recycler.Recycler; +import org.elasticsearch.core.AbstractRefCounted; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.RefCounted; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.transport.Transports; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Queue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.zip.ZipEntry; +import java.util.zip.ZipOutputStream; + +/** + * A REST response with {@code Content-type: application/zip} to which the caller can write entries in an asynchronous and streaming + * fashion. + *
<p>
+ * Callers obtain a listener for individual entries using {@link #newEntryListener} and complete these listeners to submit the corresponding + * entries for transmission. Internally, the output entries are held in a queue in the order in which the entry listeners are completed. + * If the queue becomes empty then the response transmission is paused until the next entry becomes available. + *
<p>
+ * The internal queue is unbounded. It is the caller's responsibility to ensure that the response does not consume an excess of resources + * while it's being sent. + *
<p>
+ * The caller must eventually call {@link ChunkedZipResponse#close} to finish the transmission of the response. + *
<p>
+ * Note that individual entries can also pause themselves mid-transmission, since listeners returned by {@link #newEntryListener} accept a + * pauseable {@link ChunkedRestResponseBodyPart}. Zip files do not have any mechanism which supports the multiplexing of outputs, so if the + * entry at the head of the queue is paused then that will hold up the transmission of all subsequent entries too. + */ +public final class ChunkedZipResponse implements Releasable { + + public static final String ZIP_CONTENT_TYPE = "application/zip"; + + /** + * The underlying stream that collects the raw bytes to be transmitted. Mutable, because we collect the contents of each chunk in a + * distinct stream that is held in this field while that chunk is under construction. + */ + @Nullable // if there's no chunk under construction + private BytesStream targetStream; + + private final ZipOutputStream zipOutputStream = new ZipOutputStream(new OutputStream() { + @Override + public void write(int b) throws IOException { + assert targetStream != null; + targetStream.write(b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + assert targetStream != null; + targetStream.write(b, off, len); + } + }, StandardCharsets.UTF_8); + + private final String filename; + private final RestChannel restChannel; + + /** + * A listener for the first part (i.e. sequence of chunks of zipped data) of the next entry to become available for transmission after a + * pause. Completed with the newly-created unique active {@link AvailableChunksZipResponseBodyPart} within {@link #enqueueEntry}, and + * subscribed to via {@link AvailableChunksZipResponseBodyPart#getNextPart} when the current {@link AvailableChunksZipResponseBodyPart} + * becomes inactive because of a transmission pause. 
+ */ + @Nullable // if the first part hasn't been sent yet + private SubscribableListener nextAvailableChunksListener; + + /** + * A resource to be released when the transmission of the current entry is complete. Note that we may complete the transmission of + * multiple entries at the same time, if they are all processed by one call to {@link AvailableChunksZipResponseBodyPart#encodeChunk} + * and transmitted together. + */ + @Nullable // if not currently sending an entry + private Releasable currentEntryReleasable; + + /** + * @param filename The name of the zip file, which appears in the {@code Content-Disposition} HTTP header of the response, and also + * is used as a directory prefix for all entries. + * @param restChannel The {@link RestChannel} on which to send the response. + * @param onCompletion A resource which is released when the transmission is complete. + */ + public ChunkedZipResponse(String filename, RestChannel restChannel, Releasable onCompletion) { + this.filename = filename; + this.restChannel = restChannel; + this.listenersRefs = AbstractRefCounted.of(() -> enqueueEntry(null, NO_MORE_ENTRIES, onCompletion)); + this.rootListenerRef = Releasables.releaseOnce(listenersRefs::decRef); + } + + private final RefCounted listenersRefs; + private final Releasable rootListenerRef; + + /** + * Close this {@link ChunkedZipResponse}. Once closed, when there are no more pending listeners the zip file footer is sent. + */ + @Override + public void close() { + rootListenerRef.close(); + } + + /** + * Create a listener which, when completed, will write the result {@link ChunkedRestResponseBodyPart}, and any following parts, as an + * entry in the response stream with the given name. If the listener is completed successfully with {@code null}, or exceptionally, then + * no entry is sent. When all listeners created by this method have been completed, the zip file footer is sent. + *
<p>
+ * This method may be called as long as this {@link ChunkedZipResponse} is not closed, or there is at least one other incomplete entry + * listener. + * + * @param entryName The name of the entry in the response zip file. + * @param releasable A resource which is released when the entry has been completely processed, i.e. when + *

+ *                   <ul>
+ *                   <li>the sequence of {@link ChunkedRestResponseBodyPart} instances have been fully sent, or</li>
+ *                   <li>the listener was completed with {@code null}, or an exception, indicating that no entry is to be sent, or</li>
+ *                   <li>the overall response was cancelled before completion and all resources related to the partial transmission of
+ *                   this entry have been released.</li>
+ *                   </ul>
+ */ + public ActionListener newEntryListener(String entryName, Releasable releasable) { + if (listenersRefs.tryIncRef()) { + final var zipEntry = new ZipEntry(filename + "/" + entryName); + return ActionListener.assertOnce(ActionListener.releaseAfter(new ActionListener<>() { + @Override + public void onResponse(ChunkedRestResponseBodyPart chunkedRestResponseBodyPart) { + if (chunkedRestResponseBodyPart == null) { + Releasables.closeExpectNoException(releasable); + } else { + enqueueEntry(zipEntry, chunkedRestResponseBodyPart, releasable); + } + } + + @Override + public void onFailure(Exception e) { + Releasables.closeExpectNoException(releasable); + } + + @Override + public String toString() { + return "ZipEntry[" + zipEntry.getName() + "]"; + } + }, listenersRefs::decRef)); + } else { + assert false : "already closed"; + throw new AlreadyClosedException("response already closed"); + } + } + + /** + * A zip file entry which is ready for transmission, to be stored in {@link #entryQueue}. + * + * @param zipEntry The entry metadata, to be written in its header. + * @param firstBodyPart The first part of the entry body. Subsequent parts, if present, come from + * {@link ChunkedRestResponseBodyPart#getNextPart}. + * @param releasable A resource to release when this entry has been fully transmitted, or is no longer required because the + * transmission was cancelled. + */ + private record ChunkedZipEntry(ZipEntry zipEntry, ChunkedRestResponseBodyPart firstBodyPart, Releasable releasable) {} + + /** + * Queue of entries that are ready for transmission. + */ + private final Queue entryQueue = new LinkedBlockingQueue<>(); + + /** + * Upper bound on the number of entries in the queue, atomically modified to ensure there's only one thread processing queue entries at + * once. + */ + private final AtomicInteger queueLength = new AtomicInteger(); + + /** + * Ref-counting for access to the queue, to avoid clearing the queue on abort concurrently with an entry being sent. 
+ */ + private final RefCounted queueRefs = AbstractRefCounted.of(this::drainQueue); + + /** + * Flag to indicate if the request has been aborted, at which point we should stop enqueueing more entries and promptly clean up the + * ones being sent. It's safe to ignore this, but without it in theory a constant stream of calls to {@link #enqueueEntry} could prevent + * {@link #drainQueue} from running for arbitrarily long. + */ + private final AtomicBoolean isRestResponseFinished = new AtomicBoolean(); + + private boolean tryAcquireQueueRef() { + return isRestResponseFinished.get() == false && queueRefs.tryIncRef(); + } + + /** + * Called when an entry is ready for its transmission to start. Adds the entry to {@link #entryQueue} and spawns a new + * {@link AvailableChunksZipResponseBodyPart} if none is currently active. + * + * @param zipEntry The entry metadata. + * @param firstBodyPart The first part of the entry. Entries may comprise multiple parts, with transmission pauses in between. + * @param releasable Released when the entry has been fully transmitted. + */ + private void enqueueEntry(ZipEntry zipEntry, ChunkedRestResponseBodyPart firstBodyPart, Releasable releasable) { + if (tryAcquireQueueRef()) { + try { + entryQueue.add(new ChunkedZipEntry(zipEntry, firstBodyPart, releasable)); + if (queueLength.getAndIncrement() == 0) { + // There is no active AvailableChunksZipResponseBodyPart, but there is now an entry in the queue, so we must create a + // AvailableChunksZipResponseBodyPart to process it (along with any other entries that are concurrently added to the + // queue). It's safe to mutate releasable and continuationListener here because they are only otherwise accessed by an + // active AvailableChunksZipResponseBodyPart (which does not exist) or when all queueRefs have been released (which they + // have not here). 
+ final var nextEntry = entryQueue.poll(); + assert nextEntry != null; + final var availableChunks = new AvailableChunksZipResponseBodyPart(nextEntry.zipEntry(), nextEntry.firstBodyPart()); + assert currentEntryReleasable == null; + currentEntryReleasable = nextEntry.releasable(); + final var currentAvailableChunksListener = nextAvailableChunksListener; + nextAvailableChunksListener = new SubscribableListener<>(); + if (currentAvailableChunksListener == null) { + // We are not resuming after a pause, this is the first entry to be sent, so we start the response transmission. + final var restResponse = RestResponse.chunked(RestStatus.OK, availableChunks, this::restResponseFinished); + restResponse.addHeader("content-disposition", Strings.format("attachment; filename=\"%s.zip\"", filename)); + restChannel.sendResponse(restResponse); + } else { + // We are resuming transmission after a pause, so just carry on sending the response body. + currentAvailableChunksListener.onResponse(availableChunks); + } + } + } finally { + queueRefs.decRef(); + } + } else { + Releasables.closeExpectNoException(releasable); + } + } + + private void restResponseFinished() { + assert Transports.assertTransportThread(); + if (isRestResponseFinished.compareAndSet(false, true)) { + queueRefs.decRef(); + } + } + + private void drainQueue() { + assert isRestResponseFinished.get(); + assert queueRefs.hasReferences() == false; + final var taskCount = queueLength.get() + 1; + final var releasables = new ArrayList(taskCount); + try { + releasables.add(currentEntryReleasable); + currentEntryReleasable = null; + ChunkedZipEntry entry; + while ((entry = entryQueue.poll()) != null) { + releasables.add(entry.releasable()); + } + assert entryQueue.isEmpty() : entryQueue.size(); // no concurrent adds + assert releasables.size() == taskCount || releasables.size() == taskCount - 1 : taskCount + " vs " + releasables.size(); + } finally { + Releasables.closeExpectNoException(Releasables.wrap(releasables)); + } 
+ } + + /** + * A {@link ChunkedRestResponseBodyPart} which will yield all currently-available chunks by consuming entries from {@link #entryQueue}. + * There is only ever at most one active instance of this class at any time, in the sense that one such instance becoming inactive + * happens-before the creation of the next instance. One of these parts may send chunks for more than one entry. + */ + private final class AvailableChunksZipResponseBodyPart implements ChunkedRestResponseBodyPart { + + /** + * The next {@link ZipEntry} header whose transmission to start. + */ + @Nullable // if no entry is available, or we've already sent the header for the current entry and are now sending its body. + private ZipEntry zipEntry; + + /** + * The body part which is currently being transmitted, or {@link #NO_MORE_ENTRIES} if we're transmitting the zip file footer. + */ + private ChunkedRestResponseBodyPart bodyPart; + + /** + * True when we have run out of compressed chunks ready for immediate transmission, so the response is paused, but we expect to send + * more data later. + */ + private boolean isResponsePaused; + + /** + * True when we have sent the zip file footer, or the response was cancelled. + */ + private boolean isResponseComplete; + + /** + * A listener which is created when there are no more available chunks, so transmission is paused, subscribed to in + * {@link #getNextPart}, and then completed with the next body part (sequence of zipped chunks, i.e. a new (unique) active + * {@link AvailableChunksZipResponseBodyPart}). + */ + private SubscribableListener getNextPartListener; + + /** + * A cache for an empty list to be used to collect the {@code Releasable} instances to be released when the next chunk has been + * fully transmitted. It's a list because a call to {@link #encodeChunk} may yield a chunk that completes several entries, each of + * which has its own resources to release. 
We cache this value across chunks because most chunks won't release anything, so we can + * keep the empty list around for later to save on allocations. + */ + private ArrayList nextReleasablesCache = new ArrayList<>(); + + AvailableChunksZipResponseBodyPart(ZipEntry zipEntry, ChunkedRestResponseBodyPart bodyPart) { + this.zipEntry = zipEntry; + this.bodyPart = bodyPart; + } + + /** + * @return whether this part of the compressed response is complete + */ + @Override + public boolean isPartComplete() { + return isResponsePaused || isResponseComplete; + } + + @Override + public boolean isLastPart() { + return isResponseComplete; + } + + @Override + public void getNextPart(ActionListener listener) { + assert getNextPartListener != null; + getNextPartListener.addListener(listener); + } + + /** + * Transfer {@link #currentEntryReleasable} into the supplied collection (i.e. add it to {@code releasables} and then clear + * {@link #currentEntryReleasable}). Called when the last chunk of the last part of the current entry is serialized, so that we can + * start serializing chunks of the next entry straight away whilst delaying the release of the current entry's resources until the + * transmission of the chunk that is currently under construction. 
+ */ + private void transferCurrentEntryReleasable(ArrayList releasables) { + assert queueRefs.hasReferences(); + + if (currentEntryReleasable == null) { + return; + } + + if (releasables == nextReleasablesCache) { + // adding the first value, so we must line up a new cached value for the next caller + nextReleasablesCache = new ArrayList<>(); + } + + releasables.add(currentEntryReleasable); + currentEntryReleasable = null; + } + + @Override + public ReleasableBytesReference encodeChunk(int sizeHint, Recycler recycler) throws IOException { + assert Transports.isTransportThread(Thread.currentThread()); + + final ArrayList releasables = nextReleasablesCache; + assert releasables.isEmpty(); + try { + if (tryAcquireQueueRef()) { + try { + assert queueLength.get() > 0; + // This is the current unique active AvailableChunksZipResponseBodyPart (i.e. queueLength is strictly positive and + // we hold a queueRef), so any concurrent calls to enqueueEntry() at this point will just add to the queue and won't + // spawn a new AvailableChunksZipResponseBodyPart or mutate any fields. + + final RecyclerBytesStreamOutput chunkStream = new RecyclerBytesStreamOutput(recycler); + assert targetStream == null; + targetStream = chunkStream; + + do { + writeNextBytes(sizeHint, recycler, releasables); + } while (isResponseComplete == false && isResponsePaused == false && chunkStream.size() < sizeHint); + + assert (releasables == nextReleasablesCache) == releasables.isEmpty(); + assert nextReleasablesCache.isEmpty(); + + final Releasable chunkStreamReleasable = () -> Releasables.closeExpectNoException(chunkStream); + final var result = new ReleasableBytesReference( + chunkStream.bytes(), + releasables.isEmpty() + ? 
chunkStreamReleasable + : Releasables.wrap(Iterators.concat(Iterators.single(chunkStreamReleasable), releasables.iterator())) + ); + + targetStream = null; + return result; + } finally { + queueRefs.decRef(); + } + } else { + // request aborted, nothing more to send (queue is being cleared by queueRefs#closeInternal) + isResponseComplete = true; + return new ReleasableBytesReference(BytesArray.EMPTY, () -> {}); + } + } catch (Exception e) { + logger.error("failure encoding chunk", e); + throw e; + } finally { + if (targetStream != null) { + assert false : "failure encoding chunk"; + IOUtils.closeWhileHandlingException(targetStream, Releasables.wrap(releasables)); + targetStream = null; + } + } + } + + private void writeNextBytes(int sizeHint, Recycler recycler, ArrayList releasables) throws IOException { + try { + if (bodyPart == NO_MORE_ENTRIES) { + // When the last ref from listenersRefs is completed we enqueue a final sentinel entry to trigger the transmission of + // the zip file footer, which happens here: + finishResponse(releasables); + return; + } + + if (zipEntry != null) { + // This is the start of a new entry, so write the entry header: + zipOutputStream.putNextEntry(zipEntry); + zipEntry = null; + } + + // Write the next chunk of the current entry to the zip stream + if (bodyPart.isPartComplete() == false) { + try (var innerChunk = bodyPart.encodeChunk(sizeHint, recycler)) { + final var iterator = innerChunk.iterator(); + BytesRef bytesRef; + while ((bytesRef = iterator.next()) != null) { + zipOutputStream.write(bytesRef.bytes, bytesRef.offset, bytesRef.length); + } + } + } + if (bodyPart.isPartComplete()) { + // Complete the current part: if the current entry is incomplete then set up a listener for its next part, otherwise + // move on to the next available entry and start sending its content. + finishCurrentPart(releasables); + } + } finally { + // Flush any buffered data (but not the compressor) to chunkStream so that its size is accurate. 
+ zipOutputStream.flush(); + } + } + + private void finishCurrentPart(ArrayList releasables) throws IOException { + if (bodyPart.isLastPart()) { + zipOutputStream.closeEntry(); + transferCurrentEntryReleasable(releasables); + final var newQueueLength = queueLength.decrementAndGet(); + if (newQueueLength == 0) { + // The current entry is complete, but the next entry isn't available yet, so we pause transmission. This means we are no + // longer an active AvailableChunksZipResponseBodyPart, so any concurrent calls to enqueueEntry() at this point will now + // spawn a new AvailableChunksZipResponseBodyPart to take our place. + isResponsePaused = true; + assert getNextPartListener == null; + assert nextAvailableChunksListener != null; + // Calling our getNextPart() will eventually yield the next body part supplied to enqueueEntry(): + getNextPartListener = nextAvailableChunksListener; + } else { + // The current entry is complete, and the first part of the next entry is already available, so we start sending its + // chunks too. This means we're still the unique active AvailableChunksZipResponseBodyPart. We re-use this + // AvailableChunksZipResponseBodyPart instance rather than creating a new one to avoid unnecessary allocations. + final var nextEntry = entryQueue.poll(); + assert nextEntry != null; + zipEntry = nextEntry.zipEntry(); + bodyPart = nextEntry.firstBodyPart(); + currentEntryReleasable = nextEntry.releasable(); + } + } else { + // The current entry has more parts to come, but we have reached the end of the current part, so we assume that the next + // part is not yet available and therefore must pause transmission. 
This means we are no longer an active + // AvailableChunksZipResponseBodyPart, but also another call to enqueueEntry() won't create a new + // AvailableChunksZipResponseBodyPart because the current entry is still counted in queueLength: + assert queueLength.get() > 0; + // Instead, we create a new active AvailableChunksZipResponseBodyPart when the next part of the current entry becomes + // available. It doesn't affect correctness if the next part is already available, it's just a little less efficient to make + // a new AvailableChunksZipResponseBodyPart in that case. That's ok, entries can coalesce all the available parts together + // themselves if efficiency really matters. + isResponsePaused = true; + assert getNextPartListener == null; + // Calling our getNextPart() will eventually yield the next body part from the current entry: + getNextPartListener = SubscribableListener.newForked( + l -> bodyPart.getNextPart(l.map(p -> new AvailableChunksZipResponseBodyPart(null, p))) + ); + } + } + + private void finishResponse(ArrayList releasables) throws IOException { + assert zipEntry == null; + assert entryQueue.isEmpty() : entryQueue.size(); + zipOutputStream.finish(); + isResponseComplete = true; + transferCurrentEntryReleasable(releasables); + assert getNextPartListener == null; + } + + @Override + public String getResponseContentTypeString() { + return ZIP_CONTENT_TYPE; + } + } + + /** + * Sentinel body part indicating the end of the zip file. 
+ */ + private static final ChunkedRestResponseBodyPart NO_MORE_ENTRIES = new ChunkedRestResponseBodyPart() { + @Override + public boolean isPartComplete() { + assert false : "never called"; + return true; + } + + @Override + public boolean isLastPart() { + assert false : "never called"; + return true; + } + + @Override + public void getNextPart(ActionListener listener) { + assert false : "never called"; + listener.onFailure(new IllegalStateException("impossible")); + } + + @Override + public ReleasableBytesReference encodeChunk(int sizeHint, Recycler recycler) { + assert false : "never called"; + return ReleasableBytesReference.empty(); + } + + @Override + public String getResponseContentTypeString() { + assert false : "never called"; + return ZIP_CONTENT_TYPE; + } + }; +}