diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java index a401d84b59aca..ffb5781aae795 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java @@ -110,9 +110,11 @@ public void testEmptyContent() throws Exception { assertTrue(recvChunk.isLast); assertEquals(0, recvChunk.chunk.length()); recvChunk.chunk.close(); + assertFalse(handler.streamClosed); // send response to process following request handler.sendResponse(new RestResponse(RestStatus.OK, "")); + assertBusy(() -> assertTrue(handler.streamClosed)); } assertBusy(() -> assertEquals("should receive all server responses", totalRequests, ctx.clientRespQueue.size())); } @@ -145,14 +147,16 @@ public void testReceiveAllChunks() throws Exception { } } + assertFalse(handler.streamClosed); assertEquals("sent and received payloads are not the same", sendData, recvData); handler.sendResponse(new RestResponse(RestStatus.OK, "")); + assertBusy(() -> assertTrue(handler.streamClosed)); } assertBusy(() -> assertEquals("should receive all server responses", totalRequests, ctx.clientRespQueue.size())); } } - // ensures that all received chunks are released when connection closed + // ensures that all received chunks are released when connection closed and handler notified public void testClientConnectionCloseMidStream() throws Exception { try (var ctx = setupClientCtx()) { var opaqueId = opaqueId(0); @@ -167,10 +171,14 @@ public void testClientConnectionCloseMidStream() throws Exception { // enable auto-read to receive channel close event handler.stream.channel().config().setAutoRead(true); + assertFalse(handler.streamClosed); // terminate connection and wait resources are released ctx.clientChannel.close(); - assertBusy(() -> assertNull(handler.stream.buf())); + assertBusy(() -> { + assertNull(handler.stream.buf()); + assertTrue(handler.streamClosed); + }); } } @@ -186,10 +194,14 @@ public void testServerCloseConnectionMidStream() throws Exception { // await stream handler is ready and request full content var handler = ctx.awaitRestChannelAccepted(opaqueId); assertBusy(() -> assertNotNull(handler.stream.buf())); + assertFalse(handler.streamClosed); // terminate connection on server and wait resources are released handler.channel.request().getHttpChannel().close(); - assertBusy(() -> assertNull(handler.stream.buf())); + assertBusy(() -> { + assertNull(handler.stream.buf()); + assertTrue(handler.streamClosed); + }); } } @@ -470,6 +482,7 @@ static class ServerRequestHandler implements BaseRestHandler.RequestBodyChunkCon final Netty4HttpRequestBodyStream stream; RestChannel channel; boolean recvLast = false; + volatile boolean streamClosed = false; ServerRequestHandler(String opaqueId, Netty4HttpRequestBodyStream stream) { this.opaqueId = opaqueId; @@ -487,6 +500,11 @@ public void accept(RestChannel channel) throws Exception { channelAccepted.onResponse(null); } + @Override + public void streamClose() { + streamClosed = true; + } + void sendResponse(RestResponse response) { channel.sendResponse(response); } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStream.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStream.java index 288a46c638dbb..31abf93557574 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStream.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStream.java @@ -131,6 +131,9 @@ public void close() { private void doClose() { closing = true; + if (handler != null) { + handler.close(); + } if (buf != null) { buf.release(); buf = null; diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/IncrementalBulkIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/IncrementalBulkIT.java index f11525fc6bff8..34915bb31651d 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/IncrementalBulkIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/IncrementalBulkIT.java @@ -90,6 +90,29 @@ public void testSingleBulkRequest() { assertFalse(refCounted.hasReferences()); } + public void testBufferedResourcesReleasedOnClose() { + String index = "test"; + createIndex(index); + + String nodeName = internalCluster().getRandomNodeName(); + IncrementalBulkService incrementalBulkService = internalCluster().getInstance(IncrementalBulkService.class, nodeName); + IndexingPressure indexingPressure = internalCluster().getInstance(IndexingPressure.class, nodeName); + + IncrementalBulkService.Handler handler = incrementalBulkService.newBulkRequest(); + IndexRequest indexRequest = indexRequest(index); + + AbstractRefCounted refCounted = AbstractRefCounted.of(() -> {}); + handler.addItems(List.of(indexRequest), refCounted::decRef, () -> {}); + + assertTrue(refCounted.hasReferences()); + assertThat(indexingPressure.stats().getCurrentCoordinatingBytes(), greaterThan(0L)); + + handler.close(); + + assertFalse(refCounted.hasReferences()); + assertThat(indexingPressure.stats().getCurrentCoordinatingBytes(), equalTo(0L)); + } + public void testIndexingPressureRejection() { String index = "test"; createIndex(index); @@ -303,14 +326,20 @@ public void testShortCircuitShardLevelFailure() throws Exception { String secondShardNode = findShard(resolveIndex(index), 1); IndexingPressure primaryPressure = internalCluster().getInstance(IndexingPressure.class, node); long memoryLimit = primaryPressure.stats().getMemoryLimit(); + long primaryRejections = primaryPressure.stats().getPrimaryRejections(); try (Releasable releasable = primaryPressure.markPrimaryOperationStarted(10, memoryLimit, false)) { - while (nextRequested.get()) { - nextRequested.set(false); - refCounted.incRef(); - handler.addItems(List.of(indexRequest(index)), refCounted::decRef, () -> nextRequested.set(true)); + while (primaryPressure.stats().getPrimaryRejections() == primaryRejections) { + while (nextRequested.get()) { + nextRequested.set(false); + refCounted.incRef(); + List> requests = new ArrayList<>(); + for (int i = 0; i < 20; ++i) { + requests.add(indexRequest(index)); + } + handler.addItems(requests, refCounted::decRef, () -> nextRequested.set(true)); + } + assertBusy(() -> assertTrue(nextRequested.get())); } - - assertBusy(() -> assertTrue(nextRequested.get())); } while (nextRequested.get()) { diff --git a/server/src/main/java/org/elasticsearch/action/bulk/IncrementalBulkService.java b/server/src/main/java/org/elasticsearch/action/bulk/IncrementalBulkService.java index 3e006bc960f84..bac87d563fcb5 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/IncrementalBulkService.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/IncrementalBulkService.java @@ -101,6 +101,7 @@ public static class Handler implements Releasable { private final ArrayList releasables = new ArrayList<>(4); private final ArrayList responses = new ArrayList<>(2); + private boolean closed = false; private boolean globalFailure = false; private boolean incrementalRequestSubmitted = false; private ThreadContext.StoredContext requestContext; @@ -126,6 +127,7 @@ protected Handler( } public void addItems(List> items, Releasable releasable, Runnable nextItems) { + assert closed == false; if (bulkActionLevelFailure != null) { shortCircuitDueToTopLevelFailure(items, releasable); nextItems.run(); @@ -137,12 +139,13 @@ public void addItems(List> items, Releasable releasable, Runn incrementalRequestSubmitted = true; try (ThreadContext.StoredContext ignored = threadContext.stashContext()) { requestContext.restore(); + final ArrayList toRelease = new ArrayList<>(releasables); + releasables.clear(); client.bulk(bulkRequest, ActionListener.runAfter(new ActionListener<>() { @Override public void onResponse(BulkResponse bulkResponse) { - responses.add(bulkResponse); - releaseCurrentReferences(); + handleBulkSuccess(bulkResponse); createNewBulkRequest( new BulkRequest.IncrementalState(bulkResponse.getIncrementalState().shardLevelFailures(), true) ); @@ -154,6 +157,7 @@ public void onFailure(Exception e) { } }, () -> { requestContext = threadContext.newStoredContext(); + toRelease.forEach(Releasable::close); nextItems.run(); })); } @@ -179,14 +183,15 @@ public void lastItems(List> items, Releasable releasable, Act if (internalAddItems(items, releasable)) { try (ThreadContext.StoredContext ignored = threadContext.stashContext()) { requestContext.restore(); - client.bulk(bulkRequest, new ActionListener<>() { + final ArrayList toRelease = new ArrayList<>(releasables); + releasables.clear(); + client.bulk(bulkRequest, ActionListener.runBefore(new ActionListener<>() { private final boolean isFirstRequest = incrementalRequestSubmitted == false; @Override public void onResponse(BulkResponse bulkResponse) { - responses.add(bulkResponse); - releaseCurrentReferences(); + handleBulkSuccess(bulkResponse); listener.onResponse(combineResponses()); } @@ -195,7 +200,7 @@ public void onFailure(Exception e) { handleBulkFailure(isFirstRequest, e); errorResponse(listener); } - }); + }, () -> toRelease.forEach(Releasable::close))); } } else { errorResponse(listener); @@ -203,6 +208,13 @@ public void onFailure(Exception e) { } } + @Override + public void close() { + closed = true; + releasables.forEach(Releasable::close); + releasables.clear(); + } + private void shortCircuitDueToTopLevelFailure(List> items, Releasable releasable) { assert releasables.isEmpty(); assert bulkRequest == null; @@ -220,12 +232,17 @@ private void errorResponse(ActionListener listener) { } } + private void handleBulkSuccess(BulkResponse bulkResponse) { + responses.add(bulkResponse); + bulkRequest = null; + } + private void handleBulkFailure(boolean isFirstRequest, Exception e) { assert bulkActionLevelFailure == null; globalFailure = isFirstRequest; bulkActionLevelFailure = e; addItemLevelFailures(bulkRequest.requests()); - releaseCurrentReferences(); + bulkRequest = null; } private void addItemLevelFailures(List> items) { @@ -253,6 +270,8 @@ private boolean internalAddItems(List> items, Releasable rele return true; } catch (EsRejectedExecutionException e) { handleBulkFailure(incrementalRequestSubmitted == false, e); + releasables.forEach(Releasable::close); + releasables.clear(); return false; } } @@ -297,10 +316,5 @@ private BulkResponse combineResponses() { return new BulkResponse(bulkItemResponses, tookInMillis, ingestTookInMillis); } - - @Override - public void close() { - // TODO: Implement - } } } diff --git a/server/src/main/java/org/elasticsearch/http/HttpBody.java b/server/src/main/java/org/elasticsearch/http/HttpBody.java index 9da1bc85b2a29..689119e63cafb 100644 --- a/server/src/main/java/org/elasticsearch/http/HttpBody.java +++ b/server/src/main/java/org/elasticsearch/http/HttpBody.java @@ -99,8 +99,11 @@ non-sealed interface Stream extends HttpBody { } @FunctionalInterface - interface ChunkHandler { + interface ChunkHandler extends Releasable { void onNext(ReleasableBytesReference chunk, boolean isLast); + + @Override + default void close() {} } record ByteRefHttpBody(BytesReference bytes) implements Full {} diff --git a/server/src/main/java/org/elasticsearch/rest/BaseRestHandler.java b/server/src/main/java/org/elasticsearch/rest/BaseRestHandler.java index 46a57a4529c49..9bca1cab3f5e4 100644 --- a/server/src/main/java/org/elasticsearch/rest/BaseRestHandler.java +++ b/server/src/main/java/org/elasticsearch/rest/BaseRestHandler.java @@ -19,6 +19,7 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.RestApiVersion; import org.elasticsearch.core.Tuple; +import org.elasticsearch.http.HttpBody; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.rest.action.admin.cluster.RestNodesUsageAction; @@ -126,7 +127,17 @@ public final void handleRequest(RestRequest request, RestChannel channel, NodeCl if (request.isStreamedContent()) { assert action instanceof RequestBodyChunkConsumer; var chunkConsumer = (RequestBodyChunkConsumer) action; - request.contentStream().setHandler((chunk, isLast) -> chunkConsumer.handleChunk(channel, chunk, isLast)); + request.contentStream().setHandler(new HttpBody.ChunkHandler() { + @Override + public void onNext(ReleasableBytesReference chunk, boolean isLast) { + chunkConsumer.handleChunk(channel, chunk, isLast); + } + + @Override + public void close() { + chunkConsumer.streamClose(); + } + }); } usageCount.increment(); @@ -188,6 +199,13 @@ default void close() {} public interface RequestBodyChunkConsumer extends RestChannelConsumer { void handleChunk(RestChannel channel, ReleasableBytesReference chunk, boolean isLast); + + /** + * Called when the stream closes. This could happen prior to the completion of the request if the underlying channel was closed. + * Implementors should do their best to clean up resources and early terminate request processing if it is triggered before a + * response is generated. + */ + default void streamClose() {} } /** diff --git a/server/src/main/java/org/elasticsearch/rest/action/document/RestBulkAction.java b/server/src/main/java/org/elasticsearch/rest/action/document/RestBulkAction.java index a2b5cf47a6efb..96214ec7ca2fe 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/document/RestBulkAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/document/RestBulkAction.java @@ -31,6 +31,7 @@ import org.elasticsearch.rest.action.RestRefCountedChunkedToXContentListener; import org.elasticsearch.rest.action.RestToXContentListener; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; +import org.elasticsearch.transport.Transports; import java.io.IOException; import java.util.ArrayDeque; @@ -147,7 +148,7 @@ static class ChunkHandler implements BaseRestHandler.RequestBodyChunkConsumer { private IncrementalBulkService.Handler handler; private volatile RestChannel restChannel; - private boolean isException; + private boolean shortCircuited; private final ArrayDeque unParsedChunks = new ArrayDeque<>(4); private final ArrayList> items = new ArrayList<>(4); @@ -177,7 +178,7 @@ public void accept(RestChannel restChannel) { public void handleChunk(RestChannel channel, ReleasableBytesReference chunk, boolean isLast) { assert handler != null; assert channel == restChannel; - if (isException) { + if (shortCircuited) { chunk.close(); return; } @@ -214,12 +215,8 @@ public void handleChunk(RestChannel channel, ReleasableBytesReference chunk, boo ); } catch (Exception e) { - // TODO: This needs to be better - Releasables.close(handler); - Releasables.close(unParsedChunks); - unParsedChunks.clear(); + shortCircuit(); new RestToXContentListener<>(channel).onFailure(e); - isException = true; return; } @@ -241,8 +238,16 @@ public void handleChunk(RestChannel channel, ReleasableBytesReference chunk, boo } @Override - public void close() { - RequestBodyChunkConsumer.super.close(); + public void streamClose() { + assert Transports.assertTransportThread(); + shortCircuit(); + } + + private void shortCircuit() { + shortCircuited = true; + Releasables.close(handler); + Releasables.close(unParsedChunks); + unParsedChunks.clear(); } private ArrayList accountParsing(int bytesConsumed) {