diff --git a/docs/changelog/132680.yaml b/docs/changelog/132680.yaml new file mode 100644 index 0000000000000..4611fc3ad9e0a --- /dev/null +++ b/docs/changelog/132680.yaml @@ -0,0 +1,5 @@ +pr: 132680 +summary: Add support for per-field weights in simplified RRF retriever syntax +area: Search +type: enhancement +issues: [] diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java index 58e047e149309..bf6d1e6fd0693 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java @@ -39,6 +39,7 @@ public Set getTestFeatures() { LinearRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT, RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT, RRFRetrieverBuilder.WEIGHTED_SUPPORT, + RRFRetrieverBuilder.SIMPLIFIED_WEIGHTED_SUPPORT, LINEAR_RETRIEVER_TOP_LEVEL_NORMALIZER, LinearRetrieverBuilder.MULTI_INDEX_SIMPLIFIED_FORMAT_SUPPORT, RRFRetrieverBuilder.MULTI_INDEX_SIMPLIFIED_FORMAT_SUPPORT diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index 7faad9917661d..0374cc809c8f6 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -30,6 +30,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.XPackPlugin; import org.elasticsearch.xpack.rank.MultiFieldsInnerRetrieverUtils; +import org.elasticsearch.xpack.rank.MultiFieldsInnerRetrieverUtils.WeightedRetrieverSource; import java.io.IOException; import java.util.ArrayList; @@ -46,10 +47,14 @@ * meaning it has a set of child retrievers that each return a set of * top docs that will then be combined and ranked according to the rrf * formula. + * + * Supports both explicit retriever configuration and simplified field-based + * syntax with optional per-field weights (e.g., "field^2.0"). */ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder { public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature("rrf_retriever.multi_fields_query_format_support"); public static final NodeFeature WEIGHTED_SUPPORT = new NodeFeature("rrf_retriever.weighted_support"); + public static final NodeFeature SIMPLIFIED_WEIGHTED_SUPPORT = new NodeFeature("rrf_retriever.simplified_weighted_support"); public static final NodeFeature MULTI_INDEX_SIMPLIFIED_FORMAT_SUPPORT = new NodeFeature( "rrf_retriever.multi_index_simplified_format_support" ); @@ -265,23 +270,8 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) { fields, query, localIndicesMetadata.values(), - r -> { - List retrievers = new ArrayList<>(r.size()); - float[] weights = new float[r.size()]; - for (int i = 0; i < r.size(); i++) { - var retriever = r.get(i); - retrievers.add(retriever.retrieverSource()); - weights[i] = retriever.weight(); - } - return new RRFRetrieverBuilder(retrievers, null, null, rankWindowSize, rankConstant, weights); - }, - w -> { - if (w != 1.0f) { - throw new IllegalArgumentException( - "[" + NAME + "] does not support per-field weights in [" + FIELDS_FIELD.getPreferredName() + "]" - ); - } - } + r -> createRRFFromWeightedRetrievers(r, rankWindowSize, rankConstant), + w -> validateNonNegativeWeight(w) ).stream().map(RetrieverSource::from).toList(); if (fieldsInnerRetrievers.isEmpty() == false) { @@ -295,7 +285,6 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) { rewritten = new StandardRetrieverBuilder(new MatchNoneQueryBuilder()); } } - return rewritten; } @@ -340,4 +329,26 @@ public boolean doEquals(Object o) { public int doHashCode() { return Objects.hash(super.doHashCode(), fields, query, rankConstant, Arrays.hashCode(weights)); } + + private static RRFRetrieverBuilder createRRFFromWeightedRetrievers( + List r, + int rankWindowSize, + int rankConstant + ) { + int size = r.size(); + List retrievers = new ArrayList<>(size); + float[] weights = new float[size]; + for (int i = 0; i < size; i++) { + var retriever = r.get(i); + retrievers.add(retriever.retrieverSource()); + weights[i] = retriever.weight(); + } + return new RRFRetrieverBuilder(retrievers, null, null, rankWindowSize, rankConstant, weights); + } + + private static void validateNonNegativeWeight(float w) { + if (w < 0) { + throw new IllegalArgumentException("[" + NAME + "] per-field weights must be non-negative"); + } + } } diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java index f518377c5c636..cb88f8572f432 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java @@ -51,7 +51,14 @@ public static RRFRetrieverBuilder createRandomRRFRetrieverBuilder() { List fields = null; String query = null; if (randomBoolean()) { - fields = randomList(1, 10, () -> randomAlphaOfLengthBetween(1, 10)); + fields = randomList(1, 10, () -> { + String field = randomAlphaOfLengthBetween(1, 10); + if (randomBoolean()) { + float weight = randomFloatBetween(0.0f, 10.1f, true); + field = field + "^" + weight; + } + return field; + }); query = randomAlphaOfLengthBetween(1, 10); } @@ -359,6 +366,36 @@ public void testRRFRetrieverComponentErrorCases() throws IOException { expectParsingException(retrieverAsStringContent, "retriever must be an object"); } + public void testSimplifiedWeightedFieldsParsing() throws IOException { + String restContent = """ + { + "retriever": { + "rrf": { + "retrievers": [ + { + "test": { + "value": "foo" + } + }, + { + "test": { + "value": "bar" + } + } + ], + "fields": ["name^2.0", "description^0.5"], + "query": "test", + "rank_window_size": 100, + "rank_constant": 10, + "min_score": 20.0, + "_name": "foo_rrf" + } + } + } + """; + checkRRFRetrieverParsing(restContent); + } + private void expectParsingException(String restContent, String expectedMessageFragment) throws IOException { SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder(); try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) { diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java index 6a0ad75cc721a..3670c252e0ce9 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java @@ -200,11 +200,29 @@ public void testMultiFieldsParamsRewrite() { Map.of("semantic_field_1", 1.0f, "semantic_field_2", 1.0f), "foo2" ); + } - // Glob matching on inference and non-inference fields - rrfRetrieverBuilder = new RRFRetrieverBuilder( + public void testMultiFieldsParamsRewriteWithWeights() { + final String indexName = "test-index"; + final List testInferenceFields = List.of("semantic_field_1", "semantic_field_2"); + final ResolvedIndices resolvedIndices = createMockResolvedIndices(Map.of(indexName, testInferenceFields), null, Map.of()); + final QueryRewriteContext queryRewriteContext = new QueryRewriteContext( + parserConfig(), + null, null, - List.of("field_*", "*_field_1"), + TransportVersion.current(), + RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, + resolvedIndices, + new PointInTimeBuilder(new BytesArray("pitid")), + null, + null, + false + ); + + // Simple per-field boosting + RRFRetrieverBuilder rrfRetrieverBuilder = new RRFRetrieverBuilder( + null, + List.of("field_1", "field_2^1.5", "semantic_field_1", "semantic_field_2^2"), "bar", DEFAULT_RANK_WINDOW_SIZE, RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT, @@ -213,15 +231,16 @@ public void testMultiFieldsParamsRewrite() { assertMultiFieldsParamsRewrite( rrfRetrieverBuilder, queryRewriteContext, - Map.of("field_*", 1.0f, "*_field_1", 1.0f), - Map.of("semantic_field_1", 1.0f), - "bar" + Map.of("field_1", 1.0f, "field_2", 1.5f), + Map.of("semantic_field_1", 1.0f, "semantic_field_2", 2.0f), + "bar", + null ); - // All-fields wildcard + // Glob matching on inference and non-inference fields with per-field boosting rrfRetrieverBuilder = new RRFRetrieverBuilder( null, - List.of("*"), + List.of("field_*^1.5", "*_field_1^2.5"), "baz", DEFAULT_RANK_WINDOW_SIZE, RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT, @@ -230,10 +249,117 @@ public void testMultiFieldsParamsRewrite() { assertMultiFieldsParamsRewrite( rrfRetrieverBuilder, queryRewriteContext, - Map.of("*", 1.0f), - Map.of("semantic_field_1", 1.0f, "semantic_field_2", 1.0f), - "baz" + Map.of("field_*", 1.5f, "*_field_1", 2.5f), + Map.of("semantic_field_1", 2.5f), + "baz", + null + ); + + // Multiple boosts defined on the same field + rrfRetrieverBuilder = new RRFRetrieverBuilder( + null, + List.of("field_*^1.5", "field_1^3.0", "*_field_1^2.5", "semantic_*^1.5"), + "baz2", + DEFAULT_RANK_WINDOW_SIZE, + RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT, + new float[0] + ); + assertMultiFieldsParamsRewrite( + rrfRetrieverBuilder, + queryRewriteContext, + Map.of("field_*", 1.5f, "field_1", 3.0f, "*_field_1", 2.5f, "semantic_*", 1.5f), + Map.of("semantic_field_1", 3.75f, "semantic_field_2", 1.5f), + "baz2", + null + ); + + // All-fields wildcard with weights + rrfRetrieverBuilder = new RRFRetrieverBuilder( + null, + List.of("*^2.0"), + "qux", + DEFAULT_RANK_WINDOW_SIZE, + RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT, + new float[0] + ); + assertMultiFieldsParamsRewrite( + rrfRetrieverBuilder, + queryRewriteContext, + Map.of("*", 2.0f), + Map.of("semantic_field_1", 2.0f, "semantic_field_2", 2.0f), + "qux", + null + ); + + // Zero weights (testing that zero is allowed as non-negative) + rrfRetrieverBuilder = new RRFRetrieverBuilder( + null, + List.of("field_1^0", "field_2^1.0"), + "zero_test", + DEFAULT_RANK_WINDOW_SIZE, + RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT, + new float[0] + ); + assertMultiFieldsParamsRewrite( + rrfRetrieverBuilder, + queryRewriteContext, + Map.of("field_1", 0.0f, "field_2", 1.0f), + Map.of(), + "zero_test", + null + ); + + // Mixed weighted and unweighted fields in simplified syntax + rrfRetrieverBuilder = new RRFRetrieverBuilder( + null, + List.of("title^2.5", "content", "tags^1.5", "description"), + "test query", + DEFAULT_RANK_WINDOW_SIZE, + RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT, + new float[0] + ); + assertMultiFieldsParamsRewrite( + rrfRetrieverBuilder, + queryRewriteContext, + Map.of("title", 2.5f, "content", 1.0f, "tags", 1.5f, "description", 1.0f), + Map.of(), + "test query", + null + ); + + // Decimal weight precision handling + rrfRetrieverBuilder = new RRFRetrieverBuilder( + null, + List.of("field1^0.1", "field2^2.75", "field3^10.999"), + "test query", + DEFAULT_RANK_WINDOW_SIZE, + RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT, + new float[0] + ); + assertMultiFieldsParamsRewrite( + rrfRetrieverBuilder, + queryRewriteContext, + Map.of("field1", 0.1f, "field2", 2.75f, "field3", 10.999f), + Map.of(), + "test query", + null + ); + + // Test negative weight validation + RRFRetrieverBuilder negativeWeightBuilder = new RRFRetrieverBuilder( + null, + List.of("field_1^-1.0"), + "negative_test", + DEFAULT_RANK_WINDOW_SIZE, + RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT, + new float[0] + ); + + IllegalArgumentException iae = expectThrows( + IllegalArgumentException.class, + () -> negativeWeightBuilder.doRewrite(queryRewriteContext) ); + assertEquals("[rrf] per-field weights must be non-negative", iae.getMessage()); } public void testMultiIndexMultiFieldsParamsRewrite() { diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/310_rrf_retriever_simplified.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/310_rrf_retriever_simplified.yml index 2ae6ad778e030..4a42e5170ca44 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/310_rrf_retriever_simplified.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/310_rrf_retriever_simplified.yml @@ -2,7 +2,7 @@ setup: - requires: cluster_features: [ "rrf_retriever.multi_fields_query_format_support" ] reason: "RRF retriever multi-fields query format support" - test_runner_features: [ "contains" ] + test_runner_features: [ "contains", "headers", "close_to" ] - do: inference.put: @@ -218,7 +218,107 @@ setup: - match: { hits.hits.2._id: "3" } --- -"Per-field boosting is not supported": +"Lexical match per-field boosting using the simplified format": + - requires: + cluster_features: ["rrf_retriever.simplified_weighted_support"] + reason: "Simplified weighted fields syntax support" + + - do: + headers: + Content-Type: application/json + search: + index: test-index + body: + retriever: + rrf: + fields: [ "text_1", "text_2" ] + query: "foo 1 z" + + # Lexical-only match without weights + - match: { hits.total.value: 2 } + - length: { hits.hits: 2 } + - match: { hits.hits.0._id: "1" } + - gt: { hits.hits.0._score: 0.0 } + - match: { hits.hits.1._id: "3" } + - gt: { hits.hits.1._score: 0.0 } + + - do: + headers: + Content-Type: application/json + search: + index: test-index + body: + retriever: + rrf: + fields: ["text_1", "text_2^3"] + query: "foo 1 z" + + # Lexical-only match with text_2^3 weighting - doc "3" should rank higher due to text_2 boost + - match: { hits.total.value: 2 } + - length: { hits.hits: 2 } + - match: { hits.hits.0._id: "3" } + - gt: { hits.hits.0._score: 0.0 } + - match: { hits.hits.1._id: "1" } + - gt: { hits.hits.1._score: 0.0 } + +--- +"Semantic match per-field boosting using the simplified format": + - requires: + cluster_features: ["rrf_retriever.simplified_weighted_support"] + reason: "Simplified weighted fields syntax support" + + # The mock inference services generate synthetic vectors that don't accurately represent similarity to non-identical + # input, so it's hard to create a test that produces intuitive results. Instead, we rely on the fact that the inference + # services generate consistent vectors (i.e. same input -> same output) to demonstrate that per-field boosting on + # a semantic_text field can change the result order. + - do: + headers: + Content-Type: application/json + search: + index: test-index + body: + retriever: + rrf: + fields: [ "dense_inference", "sparse_inference" ] + query: "distributed, RESTful, search engine" + + # Semantic-only match, so max RRF score for rank 1 with default rank_constant (60) is 1/(60+1) = 0.01639 + - match: { hits.total.value: 3 } + - length: { hits.hits: 3 } + - match: { hits.hits.0._id: "2" } + - close_to: { hits.hits.0._score: { value: 0.01639, error: 0.0001 } } + - match: { hits.hits.1._id: "3" } + - lt: { hits.hits.1._score: 1.0 } + - match: { hits.hits.2._id: "1" } + - lt: { hits.hits.2._score: 1.0 } + + - do: + headers: + Content-Type: application/json + search: + index: test-index + body: + retriever: + rrf: + fields: [ "dense_inference^3", "sparse_inference" ] + query: "distributed, RESTful, search engine" + + # Semantic-only match with boosted dense_inference field, so max RRF score for rank 1 is still 1/(60+1) = 0.01639 + - match: { hits.total.value: 3 } + - length: { hits.hits: 3 } + - match: { hits.hits.0._id: "3" } + - close_to: { hits.hits.0._score: { value: 0.01639, error: 0.0001 } } + - match: { hits.hits.1._id: "2" } + - lt: { hits.hits.1._score: 1.0 } + - match: { hits.hits.2._id: "1" } + - lt: { hits.hits.2._score: 1.0 } + +--- +"Negative weight validation": + - requires: + cluster_features: ["rrf_retriever.simplified_weighted_support"] + reason: "Simplified weighted fields syntax support" + - do: catch: bad_request search: @@ -226,10 +326,47 @@ setup: body: retriever: rrf: - fields: [ "text_1", "text_2^3" ] + fields: ["text_1^-1"] + query: "foo" + + - match: { error.root_cause.0.reason: "[rrf] per-field weights must be non-negative" } + +--- +"Zero weight handling": + - requires: + cluster_features: ["rrf_retriever.simplified_weighted_support"] + reason: "Simplified weighted fields syntax support" + + - do: + search: + index: test-index + body: + retriever: + rrf: + fields: ["text_1^0", "text_2^1"] + query: "foo" + + - gte: { hits.total.value: 1 } + +--- +"Basic per-field boosting using the simplified format": + - requires: + cluster_features: ["rrf_retriever.simplified_weighted_support"] + reason: "Simplified weighted fields syntax support" + + - do: + search: + index: test-index + body: + retriever: + rrf: + fields: [ "text_1", "text_2^2" ] query: "foo" - - match: { error.root_cause.0.reason: "[rrf] does not support per-field weights in [fields]" } + # With weighted fields, verify basic functionality + - gte: { hits.total.value: 1 } + - length: { hits.hits: 1 } + # Verify that text_2^2 affects ranking (basic smoke test) --- "Can query text fields": @@ -483,6 +620,24 @@ setup: - match: { hits.hits.1._id: "1" } - match: { hits.hits.2._id: "3" } +--- +"Semantic field weighting": + - requires: + cluster_features: ["rrf_retriever.simplified_weighted_support"] + reason: "Simplified weighted fields syntax support" + + - do: + search: + index: test-index + body: + retriever: + rrf: + fields: ["dense_inference^2", "sparse_inference^1.5"] + query: "elasticsearch" + + - match: { hits.total.value: 3 } + - length: { hits.hits: 3 } + --- "Queries multiple indices using default_field":