diff --git a/CHANGELOG.md b/CHANGELOG.md index 63ddb54ac69fb..e1a6751dcaa57 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add last index request timestamp columns to the `_cat/indices` API. ([10766](https://github.com/opensearch-project/OpenSearch/issues/10766)) - Introduce a new pull-based ingestion plugin for file-based indexing (for local testing) ([#18591](https://github.com/opensearch-project/OpenSearch/pull/18591)) - Add support for search pipeline in search and msearch template ([#18564](https://github.com/opensearch-project/OpenSearch/pull/18564)) +- Add BooleanQuery rewrite moving constant-scoring must clauses to filter clauses ([#18510](https://github.com/opensearch-project/OpenSearch/issues/18510)) ### Changed - Update Subject interface to use CheckedRunnable ([#18570](https://github.com/opensearch-project/OpenSearch/issues/18570)) diff --git a/modules/percolator/src/internalClusterTest/java/org/opensearch/percolator/PercolatorQuerySearchIT.java b/modules/percolator/src/internalClusterTest/java/org/opensearch/percolator/PercolatorQuerySearchIT.java index d03173b6b37fe..b141e4865e04c 100644 --- a/modules/percolator/src/internalClusterTest/java/org/opensearch/percolator/PercolatorQuerySearchIT.java +++ b/modules/percolator/src/internalClusterTest/java/org/opensearch/percolator/PercolatorQuerySearchIT.java @@ -291,43 +291,40 @@ public void testPercolatorRangeQueries() throws Exception { .get(); logger.info("response={}", response); assertHitCount(response, 2); - assertThat(response.getHits().getAt(0).getId(), equalTo("3")); - assertThat(response.getHits().getAt(1).getId(), equalTo("1")); + assertSearchHits(response, "3", "1"); source = BytesReference.bytes(jsonBuilder().startObject().field("field1", 11).endObject()); response = client().prepareSearch().setQuery(new PercolateQueryBuilder("query", source, MediaTypeRegistry.JSON)).get(); assertHitCount(response, 1); - assertThat(response.getHits().getAt(0).getId(), equalTo("1")); + assertSearchHits(response, "1"); // Test double range: source = BytesReference.bytes(jsonBuilder().startObject().field("field2", 12).endObject()); response = client().prepareSearch().setQuery(new PercolateQueryBuilder("query", source, MediaTypeRegistry.JSON)).get(); assertHitCount(response, 2); - assertThat(response.getHits().getAt(0).getId(), equalTo("6")); - assertThat(response.getHits().getAt(1).getId(), equalTo("4")); + assertSearchHits(response, "6", "4"); source = BytesReference.bytes(jsonBuilder().startObject().field("field2", 11).endObject()); response = client().prepareSearch().setQuery(new PercolateQueryBuilder("query", source, MediaTypeRegistry.JSON)).get(); assertHitCount(response, 1); - assertThat(response.getHits().getAt(0).getId(), equalTo("4")); + assertSearchHits(response, "4"); // Test IP range: source = BytesReference.bytes(jsonBuilder().startObject().field("field3", "192.168.1.5").endObject()); response = client().prepareSearch().setQuery(new PercolateQueryBuilder("query", source, MediaTypeRegistry.JSON)).get(); assertHitCount(response, 2); - assertThat(response.getHits().getAt(0).getId(), equalTo("9")); - assertThat(response.getHits().getAt(1).getId(), equalTo("7")); + assertSearchHits(response, "9", "7"); source = BytesReference.bytes(jsonBuilder().startObject().field("field3", "192.168.1.4").endObject()); response = client().prepareSearch().setQuery(new PercolateQueryBuilder("query", source, MediaTypeRegistry.JSON)).get(); assertHitCount(response, 1); - assertThat(response.getHits().getAt(0).getId(), equalTo("7")); + assertSearchHits(response, "7"); // Test date range: source = BytesReference.bytes(jsonBuilder().startObject().field("field4", "2016-05-15").endObject()); response = client().prepareSearch().setQuery(new PercolateQueryBuilder("query", source, MediaTypeRegistry.JSON)).get(); assertHitCount(response, 1); - assertThat(response.getHits().getAt(0).getId(), equalTo("10")); + assertSearchHits(response, "10"); } public void testPercolatorGeoQueries() throws Exception { diff --git a/server/src/internalClusterTest/java/org/opensearch/search/query/BooleanQueryIT.java b/server/src/internalClusterTest/java/org/opensearch/search/query/BooleanQueryIT.java index ff55889e041d5..0a1255004499f 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/query/BooleanQueryIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/query/BooleanQueryIT.java @@ -229,6 +229,26 @@ public void testMustNotRangeRewriteWithMoreThanOneValue() throws Exception { assertHitCount(client().prepareSearch().setQuery(matchAllQuery()).get(), numDocs); } + public void testMustToFilterRewrite() throws Exception { + // Check we still get expected behavior after rewriting must clauses --> filter clauses. + String intField = "int_field"; + createIndex("test"); + int numDocs = 100; + + for (int i = 0; i < numDocs; i++) { + client().prepareIndex("test").setId(Integer.toString(i)).setSource(intField, i).get(); + } + ensureGreen(); + waitForRelocation(); + forceMerge(); + refresh(); + + int gt = 22; + int lt = 92; + int expectedHitCount = lt - gt - 1; + assertHitCount(client().prepareSearch().setQuery(boolQuery().must(rangeQuery(intField).lt(lt).gt(gt))).get(), expectedHitCount); + } + private String padZeros(int value, int length) { // String.format() not allowed String ret = Integer.toString(value); diff --git a/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java index 85b555cc612ab..82f24b6288cde 100644 --- a/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java @@ -49,6 +49,8 @@ import org.opensearch.core.xcontent.ObjectParser; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.NumberFieldMapper; import java.io.IOException; import java.util.ArrayList; @@ -401,6 +403,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws } changed |= rewriteMustNotRangeClausesToShould(newBuilder, queryRewriteContext); + changed |= rewriteMustClausesToFilter(newBuilder, queryRewriteContext); if (changed) { newBuilder.adjustPureNegative = adjustPureNegative; @@ -550,4 +553,53 @@ private boolean checkAllDocsHaveOneValue(List contexts, Strin } return true; } + + private boolean rewriteMustClausesToFilter(BoolQueryBuilder newBuilder, QueryRewriteContext queryRewriteContext) { + // If we have must clauses which return the same score for all matching documents, like numeric term queries or ranges, + // moving them from must clauses to filter clauses improves performance in some cases. + // This works because it can let Lucene use MaxScoreCache to skip non-competitive docs. + boolean changed = false; + Set mustClausesToMove = new HashSet<>(); + + QueryShardContext shardContext; + if (queryRewriteContext == null) { + shardContext = null; + } else { + shardContext = queryRewriteContext.convertToShardContext(); // can still be null + } + + for (QueryBuilder clause : mustClauses) { + if (isClauseIrrelevantToScoring(clause, shardContext)) { + mustClausesToMove.add(clause); + changed = true; + } + } + + newBuilder.mustClauses.removeAll(mustClausesToMove); + newBuilder.filterClauses.addAll(mustClausesToMove); + return changed; + } + + private boolean isClauseIrrelevantToScoring(QueryBuilder clause, QueryShardContext context) { + // This is an incomplete list of clauses this might apply for; it can be expanded in future. + + // If a clause is purely numeric, for example a date range, its score is unimportant as + // it'll be the same for all returned docs + if (clause instanceof RangeQueryBuilder) return true; + if (clause instanceof GeoBoundingBoxQueryBuilder) return true; + + // Further optimizations depend on knowing whether the field is numeric. + // QueryBuilder.doRewrite() is called several times in the search flow, and the shard context telling us this + // is only available the last time, when it's called from SearchService.executeQueryPhase(). + // Skip moving these clauses if we don't have the shard context. + if (context == null) return false; + if (!(clause instanceof WithFieldName wfn)) return false; + MappedFieldType fieldType = context.fieldMapper(wfn.fieldName()); + if (!(fieldType instanceof NumberFieldMapper.NumberFieldType)) return false; + + if (clause instanceof MatchQueryBuilder) return true; + if (clause instanceof TermQueryBuilder) return true; + if (clause instanceof TermsQueryBuilder) return true; + return false; + } } diff --git a/server/src/test/java/org/opensearch/index/query/BoolQueryBuilderTests.java b/server/src/test/java/org/opensearch/index/query/BoolQueryBuilderTests.java index f08b7786c22af..c9e988117ee9f 100644 --- a/server/src/test/java/org/opensearch/index/query/BoolQueryBuilderTests.java +++ b/server/src/test/java/org/opensearch/index/query/BoolQueryBuilderTests.java @@ -73,6 +73,7 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.instanceOf; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class BoolQueryBuilderTests extends AbstractQueryTestCase { @Override @@ -630,6 +631,65 @@ public void testMustNotRewriteDisabledWithoutExactlyOneValuePerDoc() throws Exce IOUtils.close(w, reader, dir); } + public void testMustClausesRewritten() throws Exception { + BoolQueryBuilder qb = new BoolQueryBuilder(); + + // Should be moved + QueryBuilder intTermQuery = new TermQueryBuilder(INT_FIELD_NAME, 200); + QueryBuilder rangeQuery = new RangeQueryBuilder(INT_FIELD_NAME).gt(10).lt(20); + // Should be moved to filter clause, the boost applies equally to all matched docs + QueryBuilder rangeQueryWithBoost = new RangeQueryBuilder(DATE_FIELD_NAME).gt(10).lt(20).boost(2); + QueryBuilder intTermsQuery = new TermsQueryBuilder(INT_FIELD_NAME, new int[] { 1, 4, 100 }); + QueryBuilder boundingBoxQuery = new GeoBoundingBoxQueryBuilder(GEO_POINT_FIELD_NAME); + QueryBuilder doubleMatchQuery = new MatchQueryBuilder(DOUBLE_FIELD_NAME, 5.5); + + // Should not be moved + QueryBuilder textTermQuery = new TermQueryBuilder(TEXT_FIELD_NAME, "bar"); + QueryBuilder textTermsQuery = new TermsQueryBuilder(TEXT_FIELD_NAME, "foo", "bar"); + QueryBuilder textMatchQuery = new MatchQueryBuilder(TEXT_FIELD_NAME, "baz"); + + qb.must(intTermQuery); + qb.must(rangeQuery); + qb.must(rangeQueryWithBoost); + qb.must(intTermsQuery); + qb.must(boundingBoxQuery); + qb.must(doubleMatchQuery); + + qb.must(textTermQuery); + qb.must(textTermsQuery); + qb.must(textMatchQuery); + + BoolQueryBuilder rewritten = (BoolQueryBuilder) Rewriteable.rewrite(qb, createShardContext()); + for (QueryBuilder clause : List.of( + intTermQuery, + rangeQuery, + rangeQueryWithBoost, + intTermsQuery, + boundingBoxQuery, + doubleMatchQuery + )) { + assertFalse(rewritten.must().contains(clause)); + assertTrue(rewritten.filter().contains(clause)); + } + for (QueryBuilder clause : List.of(textTermQuery, textTermsQuery, textMatchQuery)) { + assertTrue(rewritten.must().contains(clause)); + assertFalse(rewritten.filter().contains(clause)); + } + + // If we have null QueryShardContext, match/term/terms queries should not be moved as we can't determine if they're numeric. + QueryRewriteContext nullContext = mock(QueryRewriteContext.class); + when(nullContext.convertToShardContext()).thenReturn(null); + rewritten = (BoolQueryBuilder) Rewriteable.rewrite(qb, nullContext); + for (QueryBuilder clause : List.of(rangeQuery, rangeQueryWithBoost, boundingBoxQuery)) { + assertFalse(rewritten.must().contains(clause)); + assertTrue(rewritten.filter().contains(clause)); + } + for (QueryBuilder clause : List.of(textTermQuery, textTermsQuery, textMatchQuery, intTermQuery, intTermsQuery, doubleMatchQuery)) { + assertTrue(rewritten.must().contains(clause)); + assertFalse(rewritten.filter().contains(clause)); + } + } + private QueryBuilder getRangeQueryBuilder(String fieldName, Integer lower, Integer upper, boolean includeLower, boolean includeUpper) { RangeQueryBuilder rq = new RangeQueryBuilder(fieldName); if (lower != null) { diff --git a/server/src/test/java/org/opensearch/index/query/WrapperQueryBuilderTests.java b/server/src/test/java/org/opensearch/index/query/WrapperQueryBuilderTests.java index 1786517c1aa1d..286c1487b2cd8 100644 --- a/server/src/test/java/org/opensearch/index/query/WrapperQueryBuilderTests.java +++ b/server/src/test/java/org/opensearch/index/query/WrapperQueryBuilderTests.java @@ -93,7 +93,8 @@ protected WrapperQueryBuilder doCreateTestQueryBuilder() { @Override protected void doAssertLuceneQuery(WrapperQueryBuilder queryBuilder, Query query, QueryShardContext context) throws IOException { - QueryBuilder innerQuery = queryBuilder.rewrite(createShardContext()); + // Must rewrite recursively so innerQuery matches query + QueryBuilder innerQuery = Rewriteable.rewrite(queryBuilder, createShardContext()); Query expected = rewrite(innerQuery.toQuery(context)); assertEquals(rewrite(query), expected); }