Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,11 @@ public void testEmptyContent() throws Exception {
assertTrue(recvChunk.isLast);
assertEquals(0, recvChunk.chunk.length());
recvChunk.chunk.close();
assertFalse(handler.streamClosed);

// send response to process following request
handler.sendResponse(new RestResponse(RestStatus.OK, ""));
assertBusy(() -> assertTrue(handler.streamClosed));
}
assertBusy(() -> assertEquals("should receive all server responses", totalRequests, ctx.clientRespQueue.size()));
}
Expand Down Expand Up @@ -145,14 +147,16 @@ public void testReceiveAllChunks() throws Exception {
}
}

assertFalse(handler.streamClosed);
assertEquals("sent and received payloads are not the same", sendData, recvData);
handler.sendResponse(new RestResponse(RestStatus.OK, ""));
assertBusy(() -> assertTrue(handler.streamClosed));
}
assertBusy(() -> assertEquals("should receive all server responses", totalRequests, ctx.clientRespQueue.size()));
}
}

// ensures that all received chunks are released when connection closed
// ensures that all received chunks are released when connection closed and handler notified
public void testClientConnectionCloseMidStream() throws Exception {
try (var ctx = setupClientCtx()) {
var opaqueId = opaqueId(0);
Expand All @@ -167,10 +171,14 @@ public void testClientConnectionCloseMidStream() throws Exception {

// enable auto-read to receive channel close event
handler.stream.channel().config().setAutoRead(true);
assertFalse(handler.streamClosed);

// terminate connection and wait resources are released
ctx.clientChannel.close();
assertBusy(() -> assertNull(handler.stream.buf()));
assertBusy(() -> {
assertNull(handler.stream.buf());
assertTrue(handler.streamClosed);
});
}
}

Expand All @@ -186,10 +194,14 @@ public void testServerCloseConnectionMidStream() throws Exception {
// await stream handler is ready and request full content
var handler = ctx.awaitRestChannelAccepted(opaqueId);
assertBusy(() -> assertNotNull(handler.stream.buf()));
assertFalse(handler.streamClosed);

// terminate connection on server and wait resources are released
handler.channel.request().getHttpChannel().close();
assertBusy(() -> assertNull(handler.stream.buf()));
assertBusy(() -> {
assertNull(handler.stream.buf());
assertTrue(handler.streamClosed);
});
}
}

Expand Down Expand Up @@ -470,6 +482,7 @@ static class ServerRequestHandler implements BaseRestHandler.RequestBodyChunkCon
final Netty4HttpRequestBodyStream stream;
RestChannel channel;
boolean recvLast = false;
volatile boolean streamClosed = false;

ServerRequestHandler(String opaqueId, Netty4HttpRequestBodyStream stream) {
this.opaqueId = opaqueId;
Expand All @@ -487,6 +500,11 @@ public void accept(RestChannel channel) throws Exception {
channelAccepted.onResponse(null);
}

@Override
public void streamClose() {
streamClosed = true;
}

void sendResponse(RestResponse response) {
channel.sendResponse(response);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ public void close() {

private void doClose() {
closing = true;
if (handler != null) {
handler.close();
}
if (buf != null) {
buf.release();
buf = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,29 @@ public void testSingleBulkRequest() {
assertFalse(refCounted.hasReferences());
}

public void testBufferedResourcesReleasedOnClose() {
String index = "test";
createIndex(index);

String nodeName = internalCluster().getRandomNodeName();
IncrementalBulkService incrementalBulkService = internalCluster().getInstance(IncrementalBulkService.class, nodeName);
IndexingPressure indexingPressure = internalCluster().getInstance(IndexingPressure.class, nodeName);

IncrementalBulkService.Handler handler = incrementalBulkService.newBulkRequest();
IndexRequest indexRequest = indexRequest(index);

AbstractRefCounted refCounted = AbstractRefCounted.of(() -> {});
handler.addItems(List.of(indexRequest), refCounted::decRef, () -> {});

assertTrue(refCounted.hasReferences());
assertThat(indexingPressure.stats().getCurrentCoordinatingBytes(), greaterThan(0L));

handler.close();

assertFalse(refCounted.hasReferences());
assertThat(indexingPressure.stats().getCurrentCoordinatingBytes(), equalTo(0L));
}

public void testIndexingPressureRejection() {
String index = "test";
createIndex(index);
Expand Down Expand Up @@ -303,14 +326,20 @@ public void testShortCircuitShardLevelFailure() throws Exception {
String secondShardNode = findShard(resolveIndex(index), 1);
IndexingPressure primaryPressure = internalCluster().getInstance(IndexingPressure.class, node);
long memoryLimit = primaryPressure.stats().getMemoryLimit();
long primaryRejections = primaryPressure.stats().getPrimaryRejections();
try (Releasable releasable = primaryPressure.markPrimaryOperationStarted(10, memoryLimit, false)) {
while (nextRequested.get()) {
nextRequested.set(false);
refCounted.incRef();
handler.addItems(List.of(indexRequest(index)), refCounted::decRef, () -> nextRequested.set(true));
while (primaryPressure.stats().getPrimaryRejections() == primaryRejections) {
while (nextRequested.get()) {
nextRequested.set(false);
refCounted.incRef();
List<DocWriteRequest<?>> requests = new ArrayList<>();
for (int i = 0; i < 20; ++i) {
requests.add(indexRequest(index));
}
handler.addItems(requests, refCounted::decRef, () -> nextRequested.set(true));
}
assertBusy(() -> assertTrue(nextRequested.get()));
}

assertBusy(() -> assertTrue(nextRequested.get()));
}

while (nextRequested.get()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ public static class Handler implements Releasable {

private final ArrayList<Releasable> releasables = new ArrayList<>(4);
private final ArrayList<BulkResponse> responses = new ArrayList<>(2);
private boolean closed = false;
private boolean globalFailure = false;
private boolean incrementalRequestSubmitted = false;
private ThreadContext.StoredContext requestContext;
Expand All @@ -126,6 +127,7 @@ protected Handler(
}

public void addItems(List<DocWriteRequest<?>> items, Releasable releasable, Runnable nextItems) {
assert closed == false;
if (bulkActionLevelFailure != null) {
shortCircuitDueToTopLevelFailure(items, releasable);
nextItems.run();
Expand All @@ -137,12 +139,13 @@ public void addItems(List<DocWriteRequest<?>> items, Releasable releasable, Runn
incrementalRequestSubmitted = true;
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
requestContext.restore();
final ArrayList<Releasable> toRelease = new ArrayList<>(releasables);
releasables.clear();
client.bulk(bulkRequest, ActionListener.runAfter(new ActionListener<>() {

@Override
public void onResponse(BulkResponse bulkResponse) {
responses.add(bulkResponse);
releaseCurrentReferences();
handleBulkSuccess(bulkResponse);
createNewBulkRequest(
new BulkRequest.IncrementalState(bulkResponse.getIncrementalState().shardLevelFailures(), true)
);
Expand All @@ -154,6 +157,7 @@ public void onFailure(Exception e) {
}
}, () -> {
requestContext = threadContext.newStoredContext();
toRelease.forEach(Releasable::close);
nextItems.run();
}));
}
Expand All @@ -179,14 +183,15 @@ public void lastItems(List<DocWriteRequest<?>> items, Releasable releasable, Act
if (internalAddItems(items, releasable)) {
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
requestContext.restore();
client.bulk(bulkRequest, new ActionListener<>() {
final ArrayList<Releasable> toRelease = new ArrayList<>(releasables);
releasables.clear();
client.bulk(bulkRequest, ActionListener.runBefore(new ActionListener<>() {

private final boolean isFirstRequest = incrementalRequestSubmitted == false;

@Override
public void onResponse(BulkResponse bulkResponse) {
responses.add(bulkResponse);
releaseCurrentReferences();
handleBulkSuccess(bulkResponse);
listener.onResponse(combineResponses());
}

Expand All @@ -195,14 +200,21 @@ public void onFailure(Exception e) {
handleBulkFailure(isFirstRequest, e);
errorResponse(listener);
}
});
}, () -> toRelease.forEach(Releasable::close)));
}
} else {
errorResponse(listener);
}
}
}

@Override
public void close() {
closed = true;
releasables.forEach(Releasable::close);
releasables.clear();
}

private void shortCircuitDueToTopLevelFailure(List<DocWriteRequest<?>> items, Releasable releasable) {
assert releasables.isEmpty();
assert bulkRequest == null;
Expand All @@ -220,12 +232,17 @@ private void errorResponse(ActionListener<BulkResponse> listener) {
}
}

private void handleBulkSuccess(BulkResponse bulkResponse) {
responses.add(bulkResponse);
bulkRequest = null;
}

private void handleBulkFailure(boolean isFirstRequest, Exception e) {
assert bulkActionLevelFailure == null;
globalFailure = isFirstRequest;
bulkActionLevelFailure = e;
addItemLevelFailures(bulkRequest.requests());
releaseCurrentReferences();
bulkRequest = null;
}

private void addItemLevelFailures(List<DocWriteRequest<?>> items) {
Expand Down Expand Up @@ -253,6 +270,8 @@ private boolean internalAddItems(List<DocWriteRequest<?>> items, Releasable rele
return true;
} catch (EsRejectedExecutionException e) {
handleBulkFailure(incrementalRequestSubmitted == false, e);
releasables.forEach(Releasable::close);
releasables.clear();
return false;
}
}
Expand Down Expand Up @@ -297,10 +316,5 @@ private BulkResponse combineResponses() {

return new BulkResponse(bulkItemResponses, tookInMillis, ingestTookInMillis);
}

@Override
public void close() {
// TODO: Implement
}
}
}
5 changes: 4 additions & 1 deletion server/src/main/java/org/elasticsearch/http/HttpBody.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,11 @@ non-sealed interface Stream extends HttpBody {
}

@FunctionalInterface
interface ChunkHandler {
interface ChunkHandler extends Releasable {
void onNext(ReleasableBytesReference chunk, boolean isLast);

@Override
default void close() {}
}

record ByteRefHttpBody(BytesReference bytes) implements Full {}
Expand Down
20 changes: 19 additions & 1 deletion server/src/main/java/org/elasticsearch/rest/BaseRestHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.RestApiVersion;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.http.HttpBody;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.rest.action.admin.cluster.RestNodesUsageAction;

Expand Down Expand Up @@ -126,7 +127,17 @@ public final void handleRequest(RestRequest request, RestChannel channel, NodeCl
if (request.isStreamedContent()) {
assert action instanceof RequestBodyChunkConsumer;
var chunkConsumer = (RequestBodyChunkConsumer) action;
request.contentStream().setHandler((chunk, isLast) -> chunkConsumer.handleChunk(channel, chunk, isLast));
request.contentStream().setHandler(new HttpBody.ChunkHandler() {
@Override
public void onNext(ReleasableBytesReference chunk, boolean isLast) {
chunkConsumer.handleChunk(channel, chunk, isLast);
}

@Override
public void close() {
chunkConsumer.streamClose();
}
});
}

usageCount.increment();
Expand Down Expand Up @@ -188,6 +199,13 @@ default void close() {}

public interface RequestBodyChunkConsumer extends RestChannelConsumer {
void handleChunk(RestChannel channel, ReleasableBytesReference chunk, boolean isLast);

/**
* Called when the stream closes. This could happen prior to the completion of the request if the underlying channel was closed.
* Implementors should do their best to clean up resources and early terminate request processing if it is triggered before a
* response is generated.
*/
default void streamClose() {}
Comment thread
mhl-b marked this conversation as resolved.
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.elasticsearch.rest.action.RestRefCountedChunkedToXContentListener;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.transport.Transports;

import java.io.IOException;
import java.util.ArrayDeque;
Expand Down Expand Up @@ -147,7 +148,7 @@ static class ChunkHandler implements BaseRestHandler.RequestBodyChunkConsumer {
private IncrementalBulkService.Handler handler;

private volatile RestChannel restChannel;
private boolean isException;
private boolean shortCircuited;
private final ArrayDeque<ReleasableBytesReference> unParsedChunks = new ArrayDeque<>(4);
private final ArrayList<DocWriteRequest<?>> items = new ArrayList<>(4);

Expand Down Expand Up @@ -177,7 +178,7 @@ public void accept(RestChannel restChannel) {
public void handleChunk(RestChannel channel, ReleasableBytesReference chunk, boolean isLast) {
assert handler != null;
assert channel == restChannel;
if (isException) {
if (shortCircuited) {
chunk.close();
return;
}
Expand Down Expand Up @@ -214,12 +215,8 @@ public void handleChunk(RestChannel channel, ReleasableBytesReference chunk, boo
);

} catch (Exception e) {
// TODO: This needs to be better
Releasables.close(handler);
Releasables.close(unParsedChunks);
unParsedChunks.clear();
shortCircuit();
new RestToXContentListener<>(channel).onFailure(e);
isException = true;
return;
}

Expand All @@ -241,8 +238,16 @@ public void handleChunk(RestChannel channel, ReleasableBytesReference chunk, boo
}

@Override
public void close() {
RequestBodyChunkConsumer.super.close();
public void streamClose() {
assert Transports.assertTransportThread();
shortCircuit();
}

private void shortCircuit() {
shortCircuited = true;
Releasables.close(handler);
Releasables.close(unParsedChunks);
unParsedChunks.clear();
}

private ArrayList<Releasable> accountParsing(int bytesConsumed) {
Expand Down