diff --git a/muted-tests.yml b/muted-tests.yml index 4a5feb4eb3662..9e32870c349d4 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -318,9 +318,6 @@ tests: - class: org.elasticsearch.xpack.esql.qa.multi_node.GenerativeIT method: test issue: https://github.com/elastic/elasticsearch/issues/144587 -- class: org.elasticsearch.action.search.TransportSearchIT - method: testCircuitBreakerReduceFail - issue: https://github.com/elastic/elasticsearch/issues/144598 - class: org.elasticsearch.xpack.esql.qa.single_node.GenerativeIT method: test issue: https://github.com/elastic/elasticsearch/issues/144587 diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java index a45237acd01eb..bb8426222f70f 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java @@ -40,6 +40,7 @@ import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.MockSearchService; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchService; import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; @@ -541,6 +542,7 @@ public void onFailure(Exception exc) { ); } assertBusy(() -> assertThat(requestBreakerUsed(), equalTo(0L))); + assertBusy(MockSearchService::assertNoInFlightContext); } finally { updateClusterSettings( Settings.builder().putNull("indices.breaker.request.limit").putNull(SearchService.BATCHED_QUERY_PHASE.getKey()) @@ -548,6 +550,20 @@ public void onFailure(Exception exc) { } } + public void testReaderContextFreedOnSerializationFailure() throws Exception { + String coordinatingNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + indexSomeDocs("test", 1, 3); + ensureGreen("test"); + + updateClusterSettings(Settings.builder().put("indices.breaker.request.limit", "1b")); + try { + expectThrows(Exception.class, client(coordinatingNode).prepareSearch("test")::get); + assertBusy(MockSearchService::assertNoInFlightContext); + } finally { + updateClusterSettings(Settings.builder().putNull("indices.breaker.request.limit")); + } + } + public void testCircuitBreakerFetchFail() throws Exception { int numShards = randomIntBetween(1, 10); int numDocs = numShards * 10; diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java index 9774ba54d6b90..0bbaced7da02c 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListenerResponseHandler; @@ -738,8 +739,8 @@ public void onResponse(T response) { response.writeTo(out); bytesRef = out.moveToBytesReference(); } catch (Exception e) { - channelListener.onFailure(e); - return; + // Propagate to caller so wrapFailureListener in SearchService can free the reader context. + throw ExceptionsHelper.convertToRuntime(e); } // respondAndRelease releases the bytes once the transport layer completes. ActionListener.respondAndRelease(channelListener, new BytesTransportResponse(bytesRef, transportVersion)); diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 1e21a39baff53..1135412362121 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -156,7 +156,9 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.Function; import java.util.function.LongSupplier; import java.util.function.Supplier; @@ -754,11 +756,31 @@ public void executeQueryPhase(ShardSearchRequest request, CancellableTask task, } // TODO: i think it makes sense to always do a canMatch here and // return an empty response (not null response) in case canMatch is false? - ensureAfterSeqNoRefreshed(shard, orig, () -> executeQueryPhase(orig, task), l); + executeQueryPhaseAsync(shard, orig, task, l); }) ); } + private void executeQueryPhaseAsync( + IndexShard shard, + ShardSearchRequest request, + CancellableTask task, + ActionListener listener + ) { + // wrapFailureListener requires readerContext and markAsUsed, but those are created inside the supplier + // lambda below. The ActionListener.wrap callbacks are constructed (before the supplier runs) and must + // therefore read the listener indirectly. completionListenerRef starts as the plain listener so that any + // failure before the supplier runs is still forwarded. Once the supplier sets it to the wrapped version, + // the ActionListener.wrap callbacks will invoke the one that handles readerContext cleanup on failure. + final var completionListenerRef = new AtomicReference<>(listener); + ensureAfterSeqNoRefreshed(shard, request, () -> { + final ReaderContext readerContext = createOrGetReaderContext(request); + final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(request)); + completionListenerRef.set(wrapFailureListener(listener, readerContext, markAsUsed)); + return executeQueryPhase(request, task, readerContext); + }, ActionListener.wrap(result -> completionListenerRef.get().onResponse(result), e -> completionListenerRef.get().onFailure(e))); + } + private void ensureAfterSeqNoRefreshed( IndexShard shard, ShardSearchRequest request, @@ -908,11 +930,10 @@ private static void runAsync( * It is the responsibility of the caller to ensure that the ref count is correctly decremented * when the object is no longer needed. */ - private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, CancellableTask task) throws Exception { - final ReaderContext readerContext = createOrGetReaderContext(request); + private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, CancellableTask task, ReaderContext readerContext) + throws Exception { try ( Releasable scope = tracer.withScope(task); - Releasable ignored = readerContext.markAsUsed(getKeepAlive(request)); SearchContext context = createContext(readerContext, request, task, ResultsType.QUERY, true) ) { tracer.startTrace("executeQueryPhase", Map.of()); @@ -967,7 +988,6 @@ private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, Cancella : new ElasticsearchException(e.getCause()); } logger.trace("Query phase failed", e); - processFailure(readerContext, e); throw e; } } @@ -1651,32 +1671,65 @@ private ActionListener releaseCircuitBreakerOnResponse( ActionListener listener, Function fetchResultExtractor ) { - return ActionListener.wrap(response -> { - try { - listener.onResponse(response); - } finally { - // Release bytes after the response handler completes - FetchSearchResult fetchResult = fetchResultExtractor.apply(response); - if (fetchResult != null) { - fetchResult.releaseCircuitBreakerBytes(circuitBreaker); + return new ActionListener<>() { + @Override + public void onResponse(T response) { + try { + listener.onResponse(response); + } finally { + // Release bytes after the response handler completes, even if it throws. + // Exceptions are intentionally allowed to propagate so that wrapFailureListener + // can observe them and free the reader context via processFailure. + FetchSearchResult fetchResult = fetchResultExtractor.apply(response); + if (fetchResult != null) { + fetchResult.releaseCircuitBreakerBytes(circuitBreaker); + } } } - }, listener::onFailure); + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }; } private ActionListener wrapFailureListener(ActionListener listener, ReaderContext context, Releasable releasable) { + return wrapFailureListener(listener, releasable, e -> processFailure(context, e)); + } + + /** + * Returns a listener that guarantees {@code releasable} is closed and {@code listener} + * is notified, regardless of whether the operation succeeds or fails. + * + * Visible for testing. + */ + static ActionListener wrapFailureListener( + ActionListener listener, + Releasable releasable, + Consumer onFailureCleanup + ) { return new ActionListener<>() { @Override - public void onResponse(T resp) { - Releasables.close(releasable); - listener.onResponse(resp); + public void onResponse(T response) { + try { + listener.onResponse(response); + } finally { + Releasables.close(releasable); + } } @Override - public void onFailure(Exception exc) { - processFailure(context, exc); - Releasables.close(releasable); - listener.onFailure(exc); + public void onFailure(Exception e) { + try { + onFailureCleanup.accept(e); + } finally { + try { + Releasables.close(releasable); + } finally { + listener.onFailure(e); + } + } } }; } diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchTransportServiceChannelListenerTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchTransportServiceChannelListenerTests.java index 297a48c0ace29..0dc0fd94aa885 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchTransportServiceChannelListenerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchTransportServiceChannelListenerTests.java @@ -90,11 +90,12 @@ public void testNetworkPathBytesResponseRoundTrip() throws Exception { } } - public void testNetworkPathSerializationFailureSendsFailure() { - var sentException = new AtomicReference(); - + public void testNetworkPathSerializationFailurePropagates() { var channel = new TestTransportChannel( - ActionListener.wrap(resp -> fail("should not succeed when serialization fails"), sentException::set) + ActionListener.wrap( + resp -> fail("should not succeed when serialization fails"), + e -> fail("should not send failure to channel; caller handles it") + ) ); ActionListener listener = SearchTransportService.channelListener( @@ -103,11 +104,9 @@ public void testNetworkPathSerializationFailureSendsFailure() { newLimitedBreaker(ByteSizeValue.ofMb(100)) ); - listener.onResponse(new FailingTestResponse()); - - assertThat(sentException.get(), notNullValue()); - assertThat(sentException.get(), instanceOf(IOException.class)); - assertThat(sentException.get().getMessage(), equalTo("simulated serialization failure")); + var ex = expectThrows(RuntimeException.class, () -> listener.onResponse(new FailingTestResponse())); + assertThat(ex.getCause(), instanceOf(IOException.class)); + assertThat(ex.getCause().getMessage(), equalTo("simulated serialization failure")); } public void testNetworkPathOnFailureForwardsFailure() { diff --git a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java index 75e4cb9fdfb5b..0a314ac84136e 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java @@ -82,12 +82,14 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; import java.util.function.Predicate; import static org.elasticsearch.common.Strings.format; import static org.elasticsearch.common.util.concurrent.EsExecutors.DIRECT_EXECUTOR_SERVICE; import static org.elasticsearch.search.SearchService.isExecutorQueuedBeyondPrewarmingFactor; +import static org.elasticsearch.search.SearchService.wrapFailureListener; import static org.elasticsearch.search.SearchService.wrapListenerForErrorHandling; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.not; @@ -319,6 +321,57 @@ public void onFailure(Exception e) { } } + public void testWrapFailureListenerOnResponse() { + var releasableClosed = new AtomicBoolean(false); + var response = new AtomicReference(); + + var wrapped = wrapFailureListener( + ActionListener.wrap(response::set, e -> fail("unexpected failure")), + () -> releasableClosed.set(true), + e -> fail("cleanup must not run on success") + ); + wrapped.onResponse("ok"); + + assertTrue("releasable must be closed after onResponse", releasableClosed.get()); + assertEquals("ok", response.get()); + } + + public void testWrapFailureListenerOnFailure() { + var releasableClosed = new AtomicBoolean(false); + var cleanupRan = new AtomicBoolean(false); + var failure = new AtomicReference(); + var cause = new RuntimeException("search failed"); + + var wrapped = wrapFailureListener( + ActionListener.wrap(r -> fail("unexpected response"), failure::set), + () -> releasableClosed.set(true), + e -> cleanupRan.set(true) + ); + wrapped.onFailure(cause); + + assertTrue("cleanup must run on failure", cleanupRan.get()); + assertTrue("releasable must be closed after onFailure", releasableClosed.get()); + assertSame("original exception must reach the listener", cause, failure.get()); + } + + public void testWrapFailureListenerCleanupThrows() { + var releasableClosed = new AtomicBoolean(false); + var failure = new AtomicReference(); + var cause = new RuntimeException("search failed"); + + var wrapped = wrapFailureListener( + ActionListener.wrap(r -> fail("unexpected response"), failure::set), + () -> releasableClosed.set(true), + e -> { + throw new RuntimeException("cleanup exploded"); + } + ); + + expectThrows(RuntimeException.class, () -> wrapped.onFailure(cause)); + assertTrue("releasable must be closed even when cleanup throws", releasableClosed.get()); + assertSame("listener.onFailure must be called even when cleanup throws", cause, failure.get()); + } + public void testIsExecutorQueuedBeyondPrewarmingFactor() throws InterruptedException { { final String threadPoolName = randomFrom( diff --git a/x-pack/plugin/ql/test-fixtures/src/main/java/org/elasticsearch/xpack/ql/TestUtils.java b/x-pack/plugin/ql/test-fixtures/src/main/java/org/elasticsearch/xpack/ql/TestUtils.java index 3159c6ea41547..1d3df8907ad31 100644 --- a/x-pack/plugin/ql/test-fixtures/src/main/java/org/elasticsearch/xpack/ql/TestUtils.java +++ b/x-pack/plugin/ql/test-fixtures/src/main/java/org/elasticsearch/xpack/ql/TestUtils.java @@ -75,6 +75,7 @@ import static java.util.Collections.emptyMap; import static org.elasticsearch.cluster.ClusterState.VERSION_INTRODUCING_TRANSPORT_VERSIONS; +import static org.elasticsearch.test.ESTestCase.assertBusy; import static org.elasticsearch.test.ESTestCase.between; import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength; import static org.elasticsearch.test.ESTestCase.randomBoolean; @@ -174,15 +175,17 @@ public static EsRelation relation() { // Common methods / assertions // - public static void assertNoSearchContexts(RestClient client) throws IOException { - Map stats = searchStats(client); - @SuppressWarnings("unchecked") - Map indicesStats = (Map) stats.get("indices"); - for (String index : indicesStats.keySet()) { - if (index.startsWith(".") == false) { // We are not interested in internal indices - assertEquals(index + " should have no search contexts", 0, getOpenContexts(stats, index)); + public static void assertNoSearchContexts(RestClient client) throws Exception { + assertBusy(() -> { + Map stats = searchStats(client); + @SuppressWarnings("unchecked") + Map indicesStats = (Map) stats.get("indices"); + for (String index : indicesStats.keySet()) { + if (index.startsWith(".") == false) { // We are not interested in internal indices + assertEquals(index + " should have no search contexts", 0, getOpenContexts(stats, index)); + } } - } + }); } public static int getNumberOfSearchContexts(RestClient client, String index) throws IOException {