diff --git a/docs/changelog/144385.yaml b/docs/changelog/144385.yaml new file mode 100644 index 0000000000000..3111337a163b6 --- /dev/null +++ b/docs/changelog/144385.yaml @@ -0,0 +1,6 @@ +area: Search +issues: + - 140495 +pr: 144385 +summary: Fix `ArrayIndexOutOfBoundsException` in fetch phase with partial results +type: bug diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index 0be281ff3568f..ac839bc573a58 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -263,8 +263,10 @@ private static void mergeSuggest( } FetchSearchResult fetchResult = searchResultProvider.fetchResult(); final int index = fetchResult.counterGetAndIncrement(); - assert index < fetchResult.hits().getHits().length - : "not enough hits fetched. index [" + index + "] length: " + fetchResult.hits().getHits().length; + if (index >= fetchResult.hits().getHits().length) { + // the fetch phase on this shard timed out and returned partial results + continue; + } SearchHit hit = fetchResult.hits().getHits()[index]; CompletionSuggestion.Entry.Option suggestOption = suggestionOptions.get(scoreDocIndex - currentOffset); hit.score(shardDoc.score); @@ -316,8 +318,10 @@ private static SearchHits getHits( } FetchSearchResult fetchResult = fetchResultProvider.fetchResult(); final int index = fetchResult.counterGetAndIncrement(); - assert index < fetchResult.hits().getHits().length - : "not enough hits fetched. index [" + index + "] length: " + fetchResult.hits().getHits().length; + if (index >= fetchResult.hits().getHits().length) { + // the fetch phase on this shard timed out and returned partial results + continue; + } SearchHit searchHit = fetchResult.hits().getHits()[index]; searchHit.shard(fetchResult.getSearchShardTarget()); if (shardDoc instanceof RankDoc) { diff --git a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java index 4a242f70e8d02..40be05910bf37 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Objects; /** * Given a set of doc ids and an index reader, sorts the docs by id, splits the sorted @@ -93,9 +94,7 @@ public final SearchHit[] iterate( } SearchTimeoutException.handleTimeout(allowPartialResults, shardTarget, querySearchResult); assert allowPartialResults; - SearchHit[] partialSearchHits = new SearchHit[i]; - System.arraycopy(searchHits, 0, partialSearchHits, 0, i); - return partialSearchHits; + return stripNulls(searchHits); } } } catch (SearchTimeoutException e) { @@ -107,6 +106,15 @@ public final SearchHit[] iterate( return searchHits; } + private static SearchHit[] stripNulls(SearchHit[] searchHits) { + for (SearchHit hit : searchHits) { + if (hit == null) { + return Arrays.stream(searchHits).filter(Objects::nonNull).toArray(SearchHit[]::new); + } + } + return searchHits; + } + private static void purgeSearchHits(SearchHit[] searchHits) { for (SearchHit searchHit : searchHits) { if (searchHit != null) { diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java index f40d63b641e2c..ecb1367ff5505 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java @@ -1417,6 +1417,81 @@ public void testFailConsumeAggs() throws Exception { } } + public void testMergeWithPartialFetchResults() { + int nShards = 3; + int hitsPerShard = 5; + AtomicArray queryResults = new AtomicArray<>(nShards); + for (int shardIndex = 0; shardIndex < nShards; shardIndex++) { + SearchShardTarget target = new SearchShardTarget("", new ShardId("", "", shardIndex), null); + QuerySearchResult qsr = new QuerySearchResult(new ShardSearchContextId("", shardIndex), target, null); + ScoreDoc[] scoreDocs = new ScoreDoc[hitsPerShard]; + for (int i = 0; i < hitsPerShard; i++) { + scoreDocs[i] = new ScoreDoc(i, hitsPerShard - i); + } + qsr.topDocs(new TopDocsAndMaxScore(new TopDocs(new TotalHits(hitsPerShard, Relation.EQUAL_TO), scoreDocs), hitsPerShard), null); + qsr.size(hitsPerShard * nShards); + qsr.setShardIndex(shardIndex); + queryResults.set(shardIndex, qsr); + } + try { + TopDocsStats topDocsStats = new TopDocsStats(SearchContext.TRACK_TOTAL_HITS_ACCURATE); + List bufferedTopDocs = new ArrayList<>(); + for (SearchPhaseResult result : queryResults.asList()) { + QuerySearchResult qsr = result.queryResult(); + TopDocsAndMaxScore td = qsr.consumeTopDocs(); + topDocsStats.add(td, qsr.searchTimedOut(), qsr.terminatedEarly()); + SearchPhaseController.setShardIndex(td.topDocs, qsr.getShardIndex()); + bufferedTopDocs.add(td.topDocs); + } + SearchPhaseController.ReducedQueryPhase reducedQueryPhase = SearchPhaseController.reducedQueryPhase( + queryResults.asList(), + InternalAggregations.EMPTY, + bufferedTopDocs, + topDocsStats, + 0, + false, + null + ); + ScoreDoc[] scoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); + assertThat(scoreDocs.length, greaterThan(0)); + + AtomicArray fetchResults = new AtomicArray<>(nShards); + for (int shardIndex = 0; shardIndex < nShards; shardIndex++) { + SearchShardTarget target = new SearchShardTarget("", new ShardId("", "", shardIndex), null); + FetchSearchResult fsr = new FetchSearchResult(new ShardSearchContextId("", shardIndex), target); + int shardHitCount = 0; + for (ScoreDoc sd : scoreDocs) { + if (sd.shardIndex == shardIndex) { + shardHitCount++; + } + } + // simulate a fetch timeout: shard 0 returns fewer hits than expected + int fetchedCount = (shardIndex == 0 && shardHitCount > 0) ? shardHitCount - 1 : shardHitCount; + SearchHit[] hits = new SearchHit[fetchedCount]; + int idx = 0; + for (ScoreDoc sd : scoreDocs) { + if (sd.shardIndex == shardIndex && idx < fetchedCount) { + hits[idx++] = SearchHit.unpooled(sd.doc, ""); + } + } + fsr.shardResult(SearchHits.unpooled(hits, new TotalHits(fetchedCount, Relation.EQUAL_TO), Float.NaN), null); + fetchResults.set(shardIndex, fsr); + } + try (SearchResponseSections mergedResponse = SearchPhaseController.merge(false, reducedQueryPhase, fetchResults)) { + // the merged response should not contain more hits than available fetch results + assertThat(mergedResponse.hits().getHits().length, lessThan(scoreDocs.length)); + for (SearchHit hit : mergedResponse.hits().getHits()) { + assertNotNull(hit); + assertNotNull(hit.getShard()); + } + } finally { + fetchResults.asList().forEach(RefCounted::decRef); + } + } finally { + queryResults.asList().forEach(RefCounted::decRef); + } + } + private static class AssertingCircuitBreaker extends NoopCircuitBreaker { private final AtomicBoolean shouldBreak = new AtomicBoolean(false); diff --git a/server/src/test/java/org/elasticsearch/search/fetch/FetchPhaseDocsIteratorTests.java b/server/src/test/java/org/elasticsearch/search/fetch/FetchPhaseDocsIteratorTests.java index c8d1b6721c64b..d8cb5d69f4e08 100644 --- a/server/src/test/java/org/elasticsearch/search/fetch/FetchPhaseDocsIteratorTests.java +++ b/server/src/test/java/org/elasticsearch/search/fetch/FetchPhaseDocsIteratorTests.java @@ -16,7 +16,9 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; +import org.elasticsearch.index.cache.query.TrivialQueryCachingPolicy; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.internal.ContextIndexSearcher; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.test.ESTestCase; @@ -28,6 +30,7 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.lessThan; @@ -137,6 +140,58 @@ protected SearchHit nextDoc(int doc) { directory.close(); } + public void testTimeoutReturnsCompactPartialResults() throws IOException { + int docCount = 400; + Directory directory = newDirectory(); + RandomIndexWriter writer = new RandomIndexWriter(random(), directory); + for (int i = 0; i < docCount; i++) { + Document doc = new Document(); + doc.add(new StringField("field", "foo", Field.Store.NO)); + writer.addDocument(doc); + if (i % 50 == 0) { + writer.commit(); + } + } + writer.commit(); + IndexReader reader = writer.getReader(); + writer.close(); + + ContextIndexSearcher searcher = new ContextIndexSearcher(reader, null, null, TrivialQueryCachingPolicy.NEVER, randomBoolean()); + + // deliberately unsorted doc ids so that the doc-id-sorted iteration order + // differs from the original order + int[] docs = new int[] { 250, 10, 150, 50, 300, 100, 200, 350 }; + // in doc-id order: 10, 50, 100, 150, 200, ... timeout at doc 200 + final int timeoutAfterDocId = 200; + + FetchPhaseDocsIterator it = new FetchPhaseDocsIterator() { + @Override + protected void setNextReader(LeafReaderContext ctx, int[] docsInLeaf) {} + + @Override + protected SearchHit nextDoc(int doc) { + if (doc == timeoutAfterDocId) { + searcher.throwTimeExceededException(); + } + return new SearchHit(doc); + } + }; + + SearchHit[] hits = it.iterate(null, reader, docs, true, new QuerySearchResult()); + + // the returned array is compact — no null entries, shorter than input + assertThat(hits.length, greaterThan(0)); + assertThat(hits.length, lessThan(docs.length)); + for (SearchHit hit : hits) { + assertNotNull(hit); + assertThat(hit.docId(), greaterThanOrEqualTo(0)); + hit.decRef(); + } + + reader.close(); + directory.close(); + } + private static int[] randomDocIds(int maxDoc) { List integers = new ArrayList<>(); int v = 0;