diff --git a/docs/changelog/138002.yaml b/docs/changelog/138002.yaml new file mode 100644 index 0000000000000..6fad9db993d9f --- /dev/null +++ b/docs/changelog/138002.yaml @@ -0,0 +1,5 @@ +pr: 138002 +summary: Fix `SearchContext` CB memory accounting +area: Aggregations +type: bug +issues: [] diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/metrics/LargeTopHitsIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/metrics/LargeTopHitsIT.java new file mode 100644 index 0000000000000..9d446104fea6b --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/metrics/LargeTopHitsIT.java @@ -0,0 +1,172 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.search.aggregations.metrics; + +import org.apache.logging.log4j.util.Strings; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregatorFactory.ExecutionMode; +import org.elasticsearch.search.sort.SortBuilders; +import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.test.ESIntegTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.elasticsearch.search.aggregations.AggregationBuilders.terms; +import static org.elasticsearch.search.aggregations.AggregationBuilders.topHits; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertFailures; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse; +import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.notNullValue; + +@ESIntegTestCase.SuiteScopeTestCase() +public class LargeTopHitsIT extends ESIntegTestCase { + + private static final String TERMS_AGGS_FIELD_1 = "terms1"; + private static final String TERMS_AGGS_FIELD_2 = "terms2"; + private static final String TERMS_AGGS_FIELD_3 = "terms3"; + private static final String SORT_FIELD = "sort"; + + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder().put(super.nodeSettings(nodeOrdinal, otherSettings)).put("indices.breaker.request.type", "memory").build(); + } + + public static String randomExecutionHint() { + return randomBoolean() ? null : randomFrom(ExecutionMode.values()).toString(); + } + + @Override + public void setupSuiteScopeCluster() throws Exception { + initSmallIdx(); + ensureSearchable(); + } + + private void initSmallIdx() throws IOException { + createIndex("small_idx"); + ensureGreen("small_idx"); + populateIndex("small_idx", 5, 40_000); + } + + private void initLargeIdx() throws IOException { + createIndex("large_idx"); + ensureGreen("large_idx"); + populateIndex("large_idx", 70, 50_000); + } + + public void testSimple() { + assertNoFailuresAndResponse(query("small_idx"), response -> { + Terms terms = response.getAggregations().get("terms"); + assertThat(terms, notNullValue()); + }); + } + + public void test500Queries() { + for (int i = 0; i < 500; i++) { + // make sure we are not leaking memory over multiple queries + assertNoFailuresAndResponse(query("small_idx"), response -> { + Terms terms = response.getAggregations().get("terms"); + assertThat(terms, notNullValue()); + }); + } + } + + // This works most of the time, but it's not consistent: it still triggers OOM sometimes. + // The test env is too small and non-deterministic to hold all these data and results. + @AwaitsFix(bugUrl = "see comment above") + public void testBreakAndRecover() throws IOException { + initLargeIdx(); + assertNoFailuresAndResponse(query("small_idx"), response -> { + Terms terms = response.getAggregations().get("terms"); + assertThat(terms, notNullValue()); + }); + + assertFailures(query("large_idx"), RestStatus.TOO_MANY_REQUESTS, containsString("Data too large")); + + assertNoFailuresAndResponse(query("small_idx"), response -> { + Terms terms = response.getAggregations().get("terms"); + assertThat(terms, notNullValue()); + }); + } + + private void createIndex(String idxName) { + assertAcked( + prepareCreate(idxName).setMapping( + TERMS_AGGS_FIELD_1, + "type=keyword", + TERMS_AGGS_FIELD_2, + "type=keyword", + TERMS_AGGS_FIELD_3, + "type=keyword", + "text", + "type=text,store=true", + "large_text_1", + "type=text,store=false", + "large_text_2", + "type=text,store=false", + "large_text_3", + "type=text,store=false", + "large_text_4", + "type=text,store=false", + "large_text_5", + "type=text,store=false" + ) + ); + } + + private void populateIndex(String idxName, int nDocs, int size) throws IOException { + for (int i = 0; i < nDocs; i++) { + List builders = new ArrayList<>(); + builders.add( + prepareIndex(idxName).setId(Integer.toString(i)) + .setSource( + jsonBuilder().startObject() + .field(TERMS_AGGS_FIELD_1, "val" + i % 53) + .field(TERMS_AGGS_FIELD_2, "val" + i % 23) + .field(TERMS_AGGS_FIELD_3, "val" + i % 10) + .field(SORT_FIELD, i) + .field("text", "some text to entertain") + .field("large_text_1", Strings.repeat("this is a text field 1 ", size)) + .field("large_text_2", Strings.repeat("this is a text field 2 ", size)) + .field("large_text_3", Strings.repeat("this is a text field 3 ", size)) + .field("large_text_4", Strings.repeat("this is a text field 4 ", size)) + .field("large_text_5", Strings.repeat("this is a text field 5 ", size)) + .field("field1", 5) + .field("field2", 2.71) + .endObject() + ) + ); + + indexRandom(true, builders); + } + } + + private static SearchRequestBuilder query(String indexName) { + return prepareSearch(indexName).addAggregation( + terms("terms").executionHint(randomExecutionHint()) + .field(TERMS_AGGS_FIELD_1) + .subAggregation( + terms("terms").executionHint(randomExecutionHint()) + .field(TERMS_AGGS_FIELD_2) + .subAggregation( + terms("terms").executionHint(randomExecutionHint()) + .field(TERMS_AGGS_FIELD_2) + .subAggregation(topHits("hits").sort(SortBuilders.fieldSort(SORT_FIELD).order(SortOrder.DESC))) + ) + ) + ); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TopHitsAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TopHitsAggregator.java index 0d21f09e699b5..8ae4555479317 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TopHitsAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TopHitsAggregator.java @@ -54,6 +54,7 @@ import java.util.List; import java.util.Map; import java.util.function.BiConsumer; +import java.util.function.IntConsumer; class TopHitsAggregator extends MetricsAggregator { @@ -198,7 +199,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) throws IOE for (int i = 0; i < topDocs.scoreDocs.length; i++) { docIdsToLoad[i] = topDocs.scoreDocs[i].doc; } - FetchSearchResult fetchResult = runFetchPhase(subSearchContext, docIdsToLoad); + FetchSearchResult fetchResult = runFetchPhase(subSearchContext, docIdsToLoad, this::addRequestCircuitBreakerBytes); if (fetchProfiles != null) { fetchProfiles.add(fetchResult.profileResult()); } @@ -222,7 +223,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) throws IOE ); } - private static FetchSearchResult runFetchPhase(SubSearchContext subSearchContext, int[] docIdsToLoad) { + private static FetchSearchResult runFetchPhase(SubSearchContext subSearchContext, int[] docIdsToLoad, IntConsumer memoryChecker) { // Fork the search execution context for each slice, because the fetch phase does not support concurrent execution yet. SearchExecutionContext searchExecutionContext = new SearchExecutionContext(subSearchContext.getSearchExecutionContext()); // InnerHitSubContext is not thread-safe, so we fork it as well to support concurrent execution @@ -242,7 +243,7 @@ public InnerHitsContext innerHits() { } }; - fetchSubSearchContext.fetchPhase().execute(fetchSubSearchContext, docIdsToLoad, null); + fetchSubSearchContext.fetchPhase().execute(fetchSubSearchContext, docIdsToLoad, null, memoryChecker); return fetchSubSearchContext.fetchResult(); } diff --git a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java index aa27e7d2f0c82..09323f3a2ca08 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java @@ -14,6 +14,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.TotalHits; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.core.Nullable; import org.elasticsearch.index.fieldvisitor.LeafStoredFieldLoader; import org.elasticsearch.index.fieldvisitor.StoredFieldLoader; import org.elasticsearch.index.mapper.IdLoader; @@ -47,6 +48,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.function.IntConsumer; import java.util.function.Supplier; import static org.elasticsearch.index.get.ShardGetService.maybeExcludeVectorFields; @@ -67,6 +69,17 @@ public FetchPhase(List fetchSubPhases) { } public void execute(SearchContext context, int[] docIdsToLoad, RankDocShardInfo rankDocs) { + execute(context, docIdsToLoad, rankDocs, null); + } + + /** + * + * @param context + * @param docIdsToLoad + * @param rankDocs + * @param memoryChecker if not provided, the fetch phase will use the circuit breaker to check memory usage + */ + public void execute(SearchContext context, int[] docIdsToLoad, RankDocShardInfo rankDocs, @Nullable IntConsumer memoryChecker) { if (LOGGER.isTraceEnabled()) { LOGGER.trace("{}", new SearchContextSourcePrinter(context)); } @@ -88,7 +101,7 @@ public void execute(SearchContext context, int[] docIdsToLoad, RankDocShardInfo : Profilers.startProfilingFetchPhase(); SearchHits hits = null; try { - hits = buildSearchHits(context, docIdsToLoad, profiler, rankDocs); + hits = buildSearchHits(context, docIdsToLoad, profiler, rankDocs, memoryChecker); } finally { try { // Always finish profiling @@ -116,7 +129,13 @@ public Source getSource(LeafReaderContext ctx, int doc) { } } - private SearchHits buildSearchHits(SearchContext context, int[] docIdsToLoad, Profiler profiler, RankDocShardInfo rankDocs) { + private SearchHits buildSearchHits( + SearchContext context, + int[] docIdsToLoad, + Profiler profiler, + RankDocShardInfo rankDocs, + IntConsumer memoryChecker + ) { var lookup = context.getSearchExecutionContext().getMappingLookup(); // Optionally remove sparse and dense vector fields early to: @@ -180,6 +199,14 @@ private SearchHits buildSearchHits(SearchContext context, int[] docIdsToLoad, Pr SourceLoader.Leaf leafSourceLoader; IdLoader.Leaf leafIdLoader; + IntConsumer memChecker = memoryChecker != null ? memoryChecker : bytes -> { + locallyAccumulatedBytes[0] += bytes; + if (context.checkCircuitBreaker(locallyAccumulatedBytes[0], "fetch source")) { + addRequestBreakerBytes(locallyAccumulatedBytes[0]); + locallyAccumulatedBytes[0] = 0; + } + }; + @Override protected void setNextReader(LeafReaderContext ctx, int[] docsInLeaf) throws IOException { Timer timer = profiler.startNextReader(); @@ -206,10 +233,6 @@ protected SearchHit nextDoc(int doc) throws IOException { if (context.isCancelled()) { throw new TaskCancelledException("cancelled"); } - if (context.checkRealMemoryCB(locallyAccumulatedBytes[0], "fetch source")) { - // if we checked the real memory breaker, we restart our local accounting - locallyAccumulatedBytes[0] = 0; - } HitContext hit = prepareHitContext( context, @@ -233,7 +256,9 @@ protected SearchHit nextDoc(int doc) throws IOException { BytesReference sourceRef = hit.hit().getSourceRef(); if (sourceRef != null) { - locallyAccumulatedBytes[0] += sourceRef.length(); + // This is an empirical value that seems to work well. + // Deserializing a large source would also mean serializing it to HTTP response later on, so x2 seems reasonable + memChecker.accept(sourceRef.length() * 2); } success = true; return hit.hit(); @@ -245,24 +270,31 @@ protected SearchHit nextDoc(int doc) throws IOException { } }; - SearchHit[] hits = docsIterator.iterate( - context.shardTarget(), - context.searcher().getIndexReader(), - docIdsToLoad, - context.request().allowPartialSearchResults(), - context.queryResult() - ); + try { + SearchHit[] hits = docsIterator.iterate( + context.shardTarget(), + context.searcher().getIndexReader(), + docIdsToLoad, + context.request().allowPartialSearchResults(), + context.queryResult() + ); - if (context.isCancelled()) { - for (SearchHit hit : hits) { - // release all hits that would otherwise become owned and eventually released by SearchHits below - hit.decRef(); + if (context.isCancelled()) { + for (SearchHit hit : hits) { + // release all hits that would otherwise become owned and eventually released by SearchHits below + hit.decRef(); + } + throw new TaskCancelledException("cancelled"); } - throw new TaskCancelledException("cancelled"); - } - TotalHits totalHits = context.getTotalHits(); - return new SearchHits(hits, totalHits, context.getMaxScore()); + TotalHits totalHits = context.getTotalHits(); + return new SearchHits(hits, totalHits, context.getMaxScore()); + } finally { + long bytes = docsIterator.getRequestBreakerBytes(); + if (bytes > 0L) { + context.circuitBreaker().addWithoutBreaking(-bytes); + } + } } List getProcessors(SearchShardTarget target, FetchContext context, Profiler profiler) { diff --git a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java index 552f9e339bd7c..df29b4d1fad88 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java @@ -32,6 +32,20 @@ */ abstract class FetchPhaseDocsIterator { + /** + * Accounts for FetchPhase memory usage. + * It gets cleaned up after each fetch phase and should not be accessed/modified by subclasses. + */ + private long requestBreakerBytes; + + public void addRequestBreakerBytes(long delta) { + requestBreakerBytes += delta; + } + + public long getRequestBreakerBytes() { + return requestBreakerBytes; + } + /** * Called when a new leaf reader is reached * @param ctx the leaf reader for this set of doc ids diff --git a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java index 7d018a7ef4ba9..16cfc177aefa2 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java @@ -386,15 +386,15 @@ public Query rewrittenQuery() { public abstract long memAccountingBufferSize(); /** - * Checks if the accumulated bytes are greater than the buffer size and if so, checks the available memory in the parent breaker - * (the real memory breaker). + * Checks if the accumulated bytes are greater than the buffer size and if so, checks the circuit breaker. + * IMPORTANT: the caller is responsible for cleaning up the circuit breaker. * @param locallyAccumulatedBytes the number of bytes accumulated locally * @param label the label to use in the breaker - * @return true if the real memory breaker is called and false otherwise + * @return true if the circuit breaker is called and false otherwise */ - public final boolean checkRealMemoryCB(int locallyAccumulatedBytes, String label) { + public final boolean checkCircuitBreaker(int locallyAccumulatedBytes, String label) { if (locallyAccumulatedBytes >= memAccountingBufferSize()) { - circuitBreaker().addEstimateBytesAndMaybeBreak(0, label); + circuitBreaker().addEstimateBytesAndMaybeBreak(locallyAccumulatedBytes, label); return true; } return false; diff --git a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java index b6ca12368f762..2c74b5e6dd40a 100644 --- a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java @@ -867,8 +867,13 @@ public StoredFieldsSpec storedFieldsSpec() { return StoredFieldsSpec.NEEDS_SOURCE; } })); - fetchPhase.execute(searchContext, IntStream.range(0, 100).toArray(), null); - assertThat(breakerCalledCount.get(), is(4)); + fetchPhase.execute( + searchContext, + IntStream.range(0, 100).toArray(), + null, + i -> breakingCircuitBreaker.addEstimateBytesAndMaybeBreak(i, "test") + ); + assertThat(breakerCalledCount.get(), is(100)); } finally { r.close(); dir.close(); diff --git a/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java b/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java index 316a5395ce1c4..3bc2617f6d2d5 100644 --- a/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java +++ b/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java @@ -976,14 +976,14 @@ public void testGetFieldCardinalityRuntimeField() { assertEquals(-1, DefaultSearchContext.getFieldCardinality("field", indexService, null)); } - public void testCheckRealMemoryCB() throws Exception { + public void testCheckCircuitBreaker() throws Exception { IndexShard indexShard = null; try (DefaultSearchContext context = createDefaultSearchContext(Settings.EMPTY)) { indexShard = context.indexShard(); // allocated more than the 1MiB buffer - assertThat(context.checkRealMemoryCB(1024 * 1800, "test"), is(true)); + assertThat(context.checkCircuitBreaker(1024 * 1800, "test"), is(true)); // allocated less than the 1MiB buffer - assertThat(context.checkRealMemoryCB(1024 * 5, "test"), is(false)); + assertThat(context.checkCircuitBreaker(1024 * 5, "test"), is(false)); } finally { if (indexShard != null) { indexShard.getThreadPool().shutdown();