diff --git a/docs/changelog/139752.yaml b/docs/changelog/139752.yaml new file mode 100644 index 0000000000000..1f3cc23eb452e --- /dev/null +++ b/docs/changelog/139752.yaml @@ -0,0 +1,5 @@ +pr: 139752 +summary: Take control of max clause count verification in Lucene searcher +area: Search +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java b/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java index 6a74d6146788b..ae1a6ac0252b3 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java +++ b/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java @@ -13,6 +13,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; +import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BulkScorer; import org.apache.lucene.search.CollectionStatistics; import org.apache.lucene.search.CollectionTerminatedException; @@ -27,12 +28,14 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryCache; import org.apache.lucene.search.QueryCachingPolicy; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.TermStatistics; import org.apache.lucene.search.Weight; import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.automaton.ByteRunAutomaton; import org.elasticsearch.common.lucene.search.BitsIterator; import org.elasticsearch.core.Releasable; import org.elasticsearch.search.dfs.AggregatedDfs; @@ -51,6 +54,7 @@ import java.util.PriorityQueue; import java.util.concurrent.Callable; import java.util.concurrent.Executor; +import java.util.function.Supplier; import java.util.stream.Collectors; /** @@ -204,8 +208,19 @@ public Query rewrite(Query original) throws IOException { if (profiler != null) { rewriteTimer = profiler.startRewriteTime(); } + + /** + * We override rewrite because this is where the superclass checks the max clause count. + * Overriding allows us to customize this limit and take full control by using our own + * visitor to ensure the query does not exceed our allowed limits. + */ try { - return super.rewrite(original); + Query query = original; + for (Query rewrittenQuery = query.rewrite(this); rewrittenQuery != query; rewrittenQuery = query.rewrite(this)) { + query = rewrittenQuery; + } + verifyQueryLimit(query); + return query; } catch (TimeExceededException e) { timeExceeded = true; return REWRITE_TIMEOUT; @@ -594,4 +609,49 @@ public void clear() { runnables.clear(); } } + + /** + * Verifies that the given query does not exceed the maximum allowed clause count. + * Traverses the query upfront to estimate its total cost and fails fast by throwing + * {@link TooManyNestedClauses} if the limit is exceeded. + */ + private static void verifyQueryLimit(Query query) { + final int[] numClauses = new int[1]; + final int maxClauseCount = getMaxClauseCount(); + query.visit(new QueryVisitor() { + @Override + public QueryVisitor getSubVisitor(BooleanClause.Occur occur, Query parent) { + // Return this instance even for MUST_NOT and not an empty QueryVisitor + return this; + } + + @Override + public void visitLeaf(Query query) { + if (numClauses[0] > maxClauseCount) { + throw new TooManyNestedClauses(); + } + ++numClauses[0]; + } + + @Override + public void consumeTerms(Query query, Term... terms) { + if (numClauses[0] > maxClauseCount) { + throw new TooManyNestedClauses(); + } + numClauses[0] += terms.length; + } + + @Override + public void consumeTermsMatching(Query query, String field, Supplier automaton) { + if (numClauses[0] > maxClauseCount) { + throw new TooManyNestedClauses(); + } + ++numClauses[0]; + } + }); + + if (numClauses[0] > maxClauseCount) { + throw new TooManyNestedClauses(); + } + } } diff --git a/server/src/test/java/org/elasticsearch/search/internal/ContextIndexSearcherTests.java b/server/src/test/java/org/elasticsearch/search/internal/ContextIndexSearcherTests.java index f239ae3c273cc..ff844d9adf4af 100644 --- a/server/src/test/java/org/elasticsearch/search/internal/ContextIndexSearcherTests.java +++ b/server/src/test/java/org/elasticsearch/search/internal/ContextIndexSearcherTests.java @@ -15,6 +15,7 @@ import org.apache.lucene.document.Field; import org.apache.lucene.document.IntPoint; import org.apache.lucene.document.StringField; +import org.apache.lucene.document.TextField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.FilterDirectoryReader; import org.apache.lucene.index.FilterLeafReader; @@ -40,6 +41,7 @@ import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.PhraseQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.Scorable; @@ -49,6 +51,7 @@ import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHitCountCollectorManager; +import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; @@ -88,6 +91,7 @@ import static org.elasticsearch.search.internal.ExitableDirectoryReader.ExitablePointValues; import static org.elasticsearch.search.internal.ExitableDirectoryReader.ExitableTerms; import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.instanceOf; @@ -191,6 +195,7 @@ private int indexDocs(Directory directory) throws IOException { for (int i = 0; i < numDocs; i++) { Document document = new Document(); document.add(new StringField("field", "value", Field.Store.NO)); + document.add(new TextField("p_field", "value", Field.Store.NO)); iw.addDocument(document); if (rarely()) { iw.flush(); @@ -612,6 +617,39 @@ public Query rewrite(IndexSearcher indexSearcher) { } } + public void testMaxClause() throws Exception { + try (Directory dir = newDirectory()) { + indexDocs(dir); + ThreadPoolExecutor executor = null; + try (var directoryReader = DirectoryReader.open(dir)) { + if (randomBoolean()) { + executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(randomIntBetween(2, 5)); + } + var searcher = new ContextIndexSearcher( + directoryReader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + executor, + executor == null ? -1 : executor.getMaximumPoolSize(), + 1 + ); + var query = new PhraseQuery.Builder().add(new Term("p_field", "value1")) + .add(new Term("p_field", "value2")) + .add(new Term("p_field", "value")) + .build(); + IndexSearcher.setMaxClauseCount(2); + var exc = expectThrows(IllegalArgumentException.class, () -> searcher.search(query, 10)); + assertThat(exc.getMessage(), containsString("too many clauses")); + IndexSearcher.setMaxClauseCount(3); + var top = searcher.search(query, 10); + assertThat(top.totalHits.value(), equalTo(0L)); + assertThat(top.totalHits.relation(), equalTo(TotalHits.Relation.EQUAL_TO)); + } + } + } + private static class TestQuery extends Query { @Override public String toString(String field) {