diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java index 2127a19a9af99..a45237acd01eb 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java @@ -11,6 +11,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.ScoreMode; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.DocWriteResponse; @@ -23,6 +24,7 @@ import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.Strings; import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; @@ -503,7 +505,10 @@ public void onFailure(Exception e) { Exception.class, client.prepareSearch("test").addAggregation(new TestAggregationBuilder("test")) ); - assertThat(exc.getCause().getMessage(), containsString("")); + assertNotNull( + "root cause must be a CircuitBreakingException", + ExceptionsHelper.unwrap(exc, CircuitBreakingException.class) + ); }); final AtomicArray exceptions = new AtomicArray<>(10); @@ -530,7 +535,10 @@ public void onFailure(Exception exc) { latch.await(); assertThat(exceptions.asList().size(), equalTo(10)); for (Exception exc : exceptions.asList()) { - assertThat(exc.getCause().getMessage(), containsString("")); + assertNotNull( + "root cause must be a CircuitBreakingException", + ExceptionsHelper.unwrap(exc, CircuitBreakingException.class) + ); } assertBusy(() -> assertThat(requestBreakerUsed(), equalTo(0L))); } finally { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java index 74ca0c9080ee6..9774ba54d6b90 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.action.ActionResponse; @@ -21,6 +22,8 @@ import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -50,9 +53,12 @@ import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.AbstractTransportRequest; +import org.elasticsearch.transport.BytesTransportResponse; import org.elasticsearch.transport.RemoteClusterService; +import org.elasticsearch.transport.TaskTransportChannel; import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportActionProxy; +import org.elasticsearch.transport.TransportChannel; import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportRequestHandler; import org.elasticsearch.transport.TransportRequestOptions; @@ -457,7 +463,7 @@ public static void registerRequestHandler( (request, channel, task) -> searchService.executeQueryPhase( request, (SearchShardTask) task, - new ChannelActionListener<>(channel) + channelListener(transportService, channel, searchService.getCircuitBreaker()) ) ); TransportActionProxy.registerProxyActionWithDynamicResponseType( @@ -513,7 +519,7 @@ public static void registerRequestHandler( (request, channel, task) -> searchService.executeFetchPhase( request, (SearchShardTask) task, - new ChannelActionListener<>(channel) + channelListener(transportService, channel, searchService.getCircuitBreaker()) ) ); TransportActionProxy.registerProxyAction( @@ -541,7 +547,11 @@ public static void registerRequestHandler( ); final TransportRequestHandler shardFetchRequestHandler = (request, channel, task) -> searchService - .executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel)); + .executeFetchPhase( + request, + (SearchShardTask) task, + channelListener(transportService, channel, searchService.getCircuitBreaker()) + ); transportService.registerRequestHandler( FETCH_ID_SCROLL_ACTION_NAME, EsExecutors.DIRECT_EXECUTOR_SERVICE, @@ -672,6 +682,75 @@ private boolean assertNodePresent() { } } + /** + * Returns a listener that serializes responses to bytes on the network path. + * + *

On the network path, the response is serialized into bytes using a + * circuit-breaker-aware stream and sent as a {@link BytesTransportResponse}. + * + *

On the direct (same-node) path the response is forwarded as-is. + * + *

Circuit-breaker accounting for response objects is handled by the caller. + */ + static ActionListener channelListener( + TransportService transportService, + TransportChannel channel, + @Nullable CircuitBreaker circuitBreaker + ) { + if (isDirectResponseChannel(channel)) { + return new ChannelActionListener<>(channel); + } + return new NetworkPathListener<>(transportService, channel, circuitBreaker); + } + + private static boolean isDirectResponseChannel(TransportChannel channel) { + if (channel instanceof TaskTransportChannel ttc) { + channel = ttc.getChannel(); + } + return TransportService.isDirectResponseChannel(channel); + } + + /** + * Serializes the response into a {@link BytesTransportResponse} while keeping the breaker-accounted + * bytes alive for the response lifecycle. Captures the transport version from the channel at + * construction time and reuses it for serialization and the response metadata. + */ + private static class NetworkPathListener implements ActionListener { + private final TransportService transportService; + private final TransportVersion transportVersion; + private final ChannelActionListener channelListener; + @Nullable + private final CircuitBreaker circuitBreaker; + + NetworkPathListener(TransportService transportService, TransportChannel channel, @Nullable CircuitBreaker circuitBreaker) { + this.transportService = transportService; + this.transportVersion = channel.getVersion(); + this.channelListener = new ChannelActionListener<>(channel); + this.circuitBreaker = circuitBreaker; + } + + @Override + public void onResponse(T response) { + // The bytes reference keeps breaker-accounted bytes; the stream output closes after serialization. + final ReleasableBytesReference bytesRef; + try (var out = transportService.newNetworkBytesStream(circuitBreaker)) { + out.setTransportVersion(transportVersion); + response.writeTo(out); + bytesRef = out.moveToBytesReference(); + } catch (Exception e) { + channelListener.onFailure(e); + return; + } + // respondAndRelease releases the bytes once the transport layer completes. + ActionListener.respondAndRelease(channelListener, new BytesTransportResponse(bytesRef, transportVersion)); + } + + @Override + public void onFailure(Exception e) { + channelListener.onFailure(e); + } + } + public void cancelSearchTask(SearchTask task, String reason) { CancelTasksRequest req = new CancelTasksRequest().setTargetTaskId(new TaskId(client.getLocalNodeId(), task.getId())) .setReason("Fatal failure during search: " + reason); diff --git a/server/src/main/java/org/elasticsearch/search/fetch/FetchSearchResult.java b/server/src/main/java/org/elasticsearch/search/fetch/FetchSearchResult.java index d9477a06f9bfe..6703bb5bd1920 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/FetchSearchResult.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/FetchSearchResult.java @@ -28,7 +28,7 @@ public final class FetchSearchResult extends SearchPhaseResult { private SearchHits hits; - private transient long searchHitsSizeBytes = 0L; + private long searchHitsSizeBytes = 0L; // client side counter private transient int counter; diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchTransportServiceChannelListenerTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchTransportServiceChannelListenerTests.java new file mode 100644 index 0000000000000..297a48c0ace29 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/action/search/SearchTransportServiceChannelListenerTests.java @@ -0,0 +1,232 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.action.search; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.transport.MockTransport; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.BytesTransportResponse; +import org.elasticsearch.transport.TestDirectResponseChannel; +import org.elasticsearch.transport.TestTransportChannel; +import org.elasticsearch.transport.TestTransportChannels; +import org.elasticsearch.transport.TransportResponse; +import org.elasticsearch.transport.TransportService; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.Collections; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.cluster.node.DiscoveryNodeUtils.builder; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.sameInstance; + +public class SearchTransportServiceChannelListenerTests extends ESTestCase { + + private ThreadPool threadPool; + private TransportService transportService; + + @Before + public void setUpResources() throws Exception { + threadPool = new TestThreadPool(getTestName()); + var mockTransport = new MockTransport(); + transportService = mockTransport.createTransportService( + Settings.EMPTY, + threadPool, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + address -> builder("test-node").build(), + null, + Collections.emptySet() + ); + } + + @After + public void tearDownResources() throws Exception { + transportService.close(); + ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); + } + + public void testNetworkPathBytesResponseRoundTrip() throws Exception { + var sentResponse = new AtomicReference(); + + var channel = new TestTransportChannel(ActionListener.wrap(resp -> { + resp.mustIncRef(); + sentResponse.set(resp); + }, e -> fail("unexpected failure: " + e))); + + var original = new SimpleTestResponse("test"); + ActionListener listener = SearchTransportService.channelListener( + transportService, + channel, + newLimitedBreaker(ByteSizeValue.ofMb(100)) + ); + + listener.onResponse(original); + + assertThat(sentResponse.get(), instanceOf(BytesTransportResponse.class)); + var bytesResp = (BytesTransportResponse) sentResponse.get(); + try (StreamInput in = bytesResp.streamInput()) { + var deserialized = new SimpleTestResponse(in); + assertThat(deserialized.value, equalTo("test")); + } finally { + bytesResp.decRef(); + } + } + + public void testNetworkPathSerializationFailureSendsFailure() { + var sentException = new AtomicReference(); + + var channel = new TestTransportChannel( + ActionListener.wrap(resp -> fail("should not succeed when serialization fails"), sentException::set) + ); + + ActionListener listener = SearchTransportService.channelListener( + transportService, + channel, + newLimitedBreaker(ByteSizeValue.ofMb(100)) + ); + + listener.onResponse(new FailingTestResponse()); + + assertThat(sentException.get(), notNullValue()); + assertThat(sentException.get(), instanceOf(IOException.class)); + assertThat(sentException.get().getMessage(), equalTo("simulated serialization failure")); + } + + public void testNetworkPathOnFailureForwardsFailure() { + var sentException = new AtomicReference(); + + var channel = new TestTransportChannel(ActionListener.wrap(resp -> fail("should not succeed"), sentException::set)); + + ActionListener listener = SearchTransportService.channelListener( + transportService, + channel, + newLimitedBreaker(ByteSizeValue.ofMb(100)) + ); + + RuntimeException e = new RuntimeException("upstream failure"); + listener.onFailure(e); + + assertThat(sentException.get(), notNullValue()); + assertThat(sentException.get(), sameInstance(e)); + } + + public void testDirectPathForwardsOriginalResponse() { + var sentResponse = new AtomicReference(); + + var channel = new TestDirectResponseChannel(ActionListener.wrap(sentResponse::set, e -> fail("unexpected failure: " + e))); + + ActionListener listener = SearchTransportService.channelListener( + transportService, + channel, + newLimitedBreaker(ByteSizeValue.ofMb(100)) + ); + + var original = new SimpleTestResponse("direct-test"); + listener.onResponse(original); + + assertSame(original, sentResponse.get()); + assertThat(sentResponse.get(), not(instanceOf(BytesTransportResponse.class))); + } + + public void testDirectPathOnFailureForwardsFailure() { + var sentException = new AtomicReference(); + + var channel = new TestDirectResponseChannel(ActionListener.wrap(resp -> fail("should not succeed"), sentException::set)); + + ActionListener listener = SearchTransportService.channelListener( + transportService, + channel, + newLimitedBreaker(ByteSizeValue.ofMb(100)) + ); + + RuntimeException e = new RuntimeException("upstream failure"); + listener.onFailure(e); + + assertThat(sentException.get(), notNullValue()); + assertThat(sentException.get(), sameInstance(e)); + } + + public void testTaskTransportChannelUnwrapsToDirectPath() { + var sentResponse = new AtomicReference(); + + var directChannel = new TestDirectResponseChannel(ActionListener.wrap(sentResponse::set, e -> fail("unexpected failure: " + e))); + var taskChannel = TestTransportChannels.newTaskTransportChannel(directChannel, () -> {}); + + ActionListener listener = SearchTransportService.channelListener( + transportService, + taskChannel, + newLimitedBreaker(ByteSizeValue.ofMb(100)) + ); + + var original = new SimpleTestResponse("task-wrapped-test"); + listener.onResponse(original); + + assertSame(original, sentResponse.get()); + assertThat(sentResponse.get(), not(instanceOf(BytesTransportResponse.class))); + sentResponse.get().decRef(); + } + + public void testTaskTransportChannelUnwrapsToNetworkPath() { + var sentResponse = new AtomicReference(); + + var networkChannel = new TestTransportChannel(ActionListener.wrap(resp -> { + resp.mustIncRef(); + sentResponse.set(resp); + }, e -> fail("unexpected failure: " + e))); + var taskChannel = TestTransportChannels.newTaskTransportChannel(networkChannel, () -> {}); + + ActionListener listener = SearchTransportService.channelListener( + transportService, + taskChannel, + newLimitedBreaker(ByteSizeValue.ofMb(100)) + ); + + listener.onResponse(new SimpleTestResponse("task-network-test")); + + assertThat(sentResponse.get(), instanceOf(BytesTransportResponse.class)); + sentResponse.get().decRef(); + } + + static class SimpleTestResponse extends TransportResponse { + final String value; + + SimpleTestResponse(String value) { + this.value = value; + } + + SimpleTestResponse(StreamInput in) throws IOException { + this.value = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(value); + } + } + + static class FailingTestResponse extends TransportResponse { + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("simulated serialization failure"); + } + } + +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/TestDirectResponseChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/TestDirectResponseChannel.java new file mode 100644 index 0000000000000..505feafd41c3b --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/TestDirectResponseChannel.java @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.transport; + +import org.elasticsearch.action.ActionListener; + +public class TestDirectResponseChannel extends TransportService.DirectResponseChannel { + + private final ActionListener listener; + + public TestDirectResponseChannel(ActionListener listener) { + super(null, "test", 0, null); + this.listener = listener; + } + + @Override + public void sendResponse(TransportResponse response) { + listener.onResponse(response); + } + + @Override + public void sendResponse(Exception exception) { + listener.onFailure(exception); + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/TestTransportChannels.java b/test/framework/src/main/java/org/elasticsearch/transport/TestTransportChannels.java index effabd85591f9..ca627f96ee2dd 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/TestTransportChannels.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/TestTransportChannels.java @@ -12,10 +12,15 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.common.network.HandlingTimeTracker; import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.core.Releasable; import org.elasticsearch.threadpool.ThreadPool; public class TestTransportChannels { + public static TaskTransportChannel newTaskTransportChannel(TransportChannel channel, Releasable onTaskFinished) { + return new TaskTransportChannel(1, channel, onTaskFinished); + } + public static TcpTransportChannel newFakeTcpTransportChannel( String nodeName, TcpChannel channel,