Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 44 additions & 17 deletions server/src/main/java/org/elasticsearch/rest/RestController.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.common.path.PathTrie;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.common.util.concurrent.RunOnce;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
Expand All @@ -41,6 +40,7 @@
import org.elasticsearch.rest.RestHandler.Route;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.telemetry.tracing.Tracer;
import org.elasticsearch.transport.Transports;
import org.elasticsearch.usage.SearchUsageHolder;
import org.elasticsearch.usage.UsageService;
import org.elasticsearch.xcontent.XContentBuilder;
Expand All @@ -60,6 +60,7 @@
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

import static org.elasticsearch.indices.SystemIndices.EXTERNAL_SYSTEM_INDEX_ACCESS_CONTROL_HEADER_KEY;
Expand Down Expand Up @@ -826,9 +827,13 @@ public void sendResponse(RestResponse response) {
if (response.isChunked() == false) {
methodHandlers.addResponseStats(response.content().length());
} else {
final var wrapped = new EncodedLengthTrackingChunkedRestResponseBody(response.chunkedContent(), methodHandlers);
final var responseLengthRecorder = new ResponseLengthRecorder(methodHandlers);
final var headers = response.getHeaders();
response = RestResponse.chunked(response.status(), wrapped, Releasables.wrap(wrapped, response));
response = RestResponse.chunked(
response.status(),
new EncodedLengthTrackingChunkedRestResponseBody(response.chunkedContent(), responseLengthRecorder),
Releasables.wrap(responseLengthRecorder, response)
);
for (final var header : headers.entrySet()) {
for (final var value : header.getValue()) {
response.addHeader(header.getKey(), value);
Expand Down Expand Up @@ -857,15 +862,44 @@ private void close() {
}
}

private static class EncodedLengthTrackingChunkedRestResponseBody implements ChunkedRestResponseBody, Releasable {
private static class ResponseLengthRecorder extends AtomicReference<MethodHandlers> implements Releasable {
private long responseLength;

private ResponseLengthRecorder(MethodHandlers methodHandlers) {
super(methodHandlers);
}

@Override
public void close() {
// closed just before sending the last chunk, and also when the whole RestResponse is closed since the client might abort the
// connection before we send the last chunk, in which case we won't have recorded the response in the
// stats yet; thus we need run-once semantics here:
final var methodHandlers = getAndSet(null);
if (methodHandlers != null) {
// if we started sending chunks then we're closed on the transport worker, no need for sync
assert responseLength == 0L || Transports.assertTransportThread();
methodHandlers.addResponseStats(responseLength);
}
}

void addChunkLength(long chunkLength) {
assert chunkLength >= 0L : chunkLength;
assert Transports.assertTransportThread(); // always called on the transport worker, no need for sync
responseLength += chunkLength;
}
}

private static class EncodedLengthTrackingChunkedRestResponseBody implements ChunkedRestResponseBody {

private final ChunkedRestResponseBody delegate;
private final RunOnce onCompletion;
private long encodedLength = 0;
private final ResponseLengthRecorder responseLengthRecorder;

private EncodedLengthTrackingChunkedRestResponseBody(ChunkedRestResponseBody delegate, MethodHandlers methodHandlers) {
private EncodedLengthTrackingChunkedRestResponseBody(
ChunkedRestResponseBody delegate,
ResponseLengthRecorder responseLengthRecorder
) {
this.delegate = delegate;
this.onCompletion = new RunOnce(() -> methodHandlers.addResponseStats(encodedLength));
this.responseLengthRecorder = responseLengthRecorder;
}

@Override
Expand All @@ -876,9 +910,9 @@ public boolean isDone() {
@Override
public ReleasableBytesReference encodeChunk(int sizeHint, Recycler<BytesRef> recycler) throws IOException {
final ReleasableBytesReference bytesReference = delegate.encodeChunk(sizeHint, recycler);
encodedLength += bytesReference.length();
responseLengthRecorder.addChunkLength(bytesReference.length());
if (isDone()) {
onCompletion.run();
responseLengthRecorder.close();
}
return bytesReference;
}
Expand All @@ -887,13 +921,6 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler<BytesRef> rec
public String getResponseContentTypeString() {
return delegate.getResponseContentTypeString();
}

@Override
public void close() {
// the client might close the connection before we send the last chunk, in which case we won't have recorded the response in the
// stats yet, so we do it now:
onCompletion.run();
}
}

private static CircuitBreaker inFlightRequestsBreaker(CircuitBreakerService circuitBreakerService) {
Expand Down