From cff7ab50111424f24b9bedff6f0cac76074523ce Mon Sep 17 00:00:00 2001 From: "Mark J. Hoy" Date: Fri, 9 Jan 2026 15:33:39 -0500 Subject: [PATCH] Stabilize and Optimize MMR Result Diversification (#140066) * stabilize/optimize MMR result diversification * fix unit tests (cherry picked from commit 88c87b09c837fc0010a3bc1c8f2a4bd7c92e634d) # Conflicts: # muted-tests.yml --- muted-tests.yml | 3 - ...0_mmr_result_diversification_retriever.yml | 4 +- .../ResultDiversification.java | 14 ++-- .../mmr/MMRResultDiversification.java | 64 ++++++++----------- .../mmr/MMRResultDiversificationTests.java | 2 +- 5 files changed, 34 insertions(+), 53 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index 4a0455ccc27c5..b2f37961f8bc9 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -342,9 +342,6 @@ tests: - class: org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceDateTruncBucketWithRoundToTests method: testReductionPlanForTopNWithPushedDownFunctions issue: https://github.com/elastic/elasticsearch/issues/139493 -- class: org.elasticsearch.xpack.security.CoreWithSecurityClientYamlTestSuiteIT - method: test {yaml=search.retrievers/result-diversification/10_mmr_result_diversification_retriever/Test MMR result diversification single index float type} - issue: https://github.com/elastic/elasticsearch/issues/139527 - class: org.elasticsearch.repositories.gcs.GoogleCloudStorageBlobStoreRepositoryTests method: testMultipleSnapshotAndRollback issue: https://github.com/elastic/elasticsearch/issues/139556 diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/result-diversification/10_mmr_result_diversification_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/result-diversification/10_mmr_result_diversification_retriever.yml index 5fa8ec26b7bf7..871e97ae8e386 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/result-diversification/10_mmr_result_diversification_retriever.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/result-diversification/10_mmr_result_diversification_retriever.yml @@ -389,8 +389,8 @@ teardown: - match: { hits.total.value: 7 } - length: { hits.hits: 3 } - match: { hits.hits.0._source.textbody: "first text" } - - match: { hits.hits.1._source.textbody: "fourth text" } - - match: { hits.hits.2._source.textbody: "fifth text" } + - match: { hits.hits.1._source.textbody: "third text" } + - match: { hits.hits.2._source.textbody: "fourth text" } --- "Test MMR diversification with arbitrary shards": diff --git a/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversification.java b/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversification.java index d70d97eaacb60..a09369050d752 100644 --- a/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversification.java +++ b/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversification.java @@ -28,19 +28,13 @@ protected ResultDiversification(C context) { public abstract RankDoc[] diversify(RankDoc[] docs) throws IOException; - protected float getFloatVectorComparisonScore( + protected float getVectorComparisonScore( VectorSimilarityFunction similarityFunction, VectorData thisDocVector, VectorData comparisonVector ) { - return similarityFunction.compare(thisDocVector.floatVector(), comparisonVector.floatVector()); - } - - protected float getByteVectorComparisonScore( - VectorSimilarityFunction similarityFunction, - VectorData thisDocVector, - VectorData comparisonVector - ) { - return similarityFunction.compare(thisDocVector.byteVector(), comparisonVector.byteVector()); + return thisDocVector.isFloat() + ? similarityFunction.compare(thisDocVector.floatVector(), comparisonVector.floatVector()) + : similarityFunction.compare(thisDocVector.byteVector(), comparisonVector.byteVector()); } } diff --git a/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java b/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java index 861b7e9130a63..99d6f6305173d 100644 --- a/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java +++ b/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java @@ -20,6 +20,7 @@ import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -37,7 +38,8 @@ public RankDoc[] diversify(RankDoc[] docs) throws IOException { return docs; } - Map docIdIndexMapping = new HashMap<>(); + // keep the original ranking so we can output in the same order + Map docIdIndexMapping = new LinkedHashMap<>(); for (int i = 0; i < docs.length; i++) { docIdIndexMapping.put(docs[i].rank, i); } @@ -45,12 +47,8 @@ public RankDoc[] diversify(RankDoc[] docs) throws IOException { // our chosen DocIDs to keep List selectedDocRanks = new ArrayList<>(); - // test the vector to see if we are using floats or bytes - VectorData firstVec = context.getFieldVector(docs[0].rank); - boolean useFloat = firstVec.isFloat(); - // cache the similarity scores for the query vector vs. searchHits - Map querySimilarity = getQuerySimilarityForDocs(docs, useFloat, context); + Map querySimilarity = getQuerySimilarityForDocs(docs, context); // always add the highest relevant doc to the list selectedDocRanks.add(getHighestRelevantDocRank(docs, querySimilarity)); @@ -73,14 +71,18 @@ public RankDoc[] diversify(RankDoc[] docs) throws IOException { continue; } - var cachedScoresForDoc = cachedSimilarities.getOrDefault(docRank, new HashMap<>()); + var cachedScoresForDoc = cachedSimilarities.getOrDefault(docRank, new LinkedHashMap<>()); // compute MMR scores for remaining searchHits - float highestMMRScore = getHighestScoreForSelectedVectors(docRank, context, useFloat, thisDocVector, cachedScoresForDoc); + float highestSimilarityScoreToSelected = getHighestSimilarityScoreToSelectedVectors( + selectedDocRanks, + thisDocVector, + cachedScoresForDoc + ); // compute MMR float querySimilarityScore = querySimilarity.getOrDefault(doc.rank, 0.0f); - float mmr = (context.getLambda() * querySimilarityScore) - ((1 - context.getLambda()) * highestMMRScore); + float mmr = (context.getLambda() * querySimilarityScore) - ((1 - context.getLambda()) * highestSimilarityScoreToSelected); if (mmr > thisMaxMMRScore) { thisMaxMMRScore = mmr; thisMaxMMRDocRank = docRank; @@ -125,39 +127,29 @@ private Integer getHighestRelevantDocRank(RankDoc[] docs, Map qu return highestScoreDoc.rank; } - private float getHighestScoreForSelectedVectors( - int docRank, - MMRResultDiversificationContext context, - boolean useFloat, + private float getHighestSimilarityScoreToSelectedVectors( + List selectedDocRanks, VectorData thisDocVector, Map cachedScoresForDoc ) { - float highestScore = Float.MIN_VALUE; - for (var vec : context.getFieldVectorsEntrySet()) { - if (vec.getKey().equals(docRank)) { - continue; - } - - if (cachedScoresForDoc.containsKey(vec.getKey())) { - float score = cachedScoresForDoc.get(vec.getKey()); - if (score > highestScore) { - highestScore = score; - } - } else { - VectorData comparisonVector = vec.getValue(); - float score = useFloat - ? getFloatVectorComparisonScore(similarityFunction, thisDocVector, comparisonVector) - : getByteVectorComparisonScore(similarityFunction, thisDocVector, comparisonVector); - cachedScoresForDoc.put(vec.getKey(), score); - if (score > highestScore) { - highestScore = score; + float highestScore = Float.NEGATIVE_INFINITY; + for (Integer compareToDocRank : selectedDocRanks) { + Float similarityScore = cachedScoresForDoc.getOrDefault(compareToDocRank, null); + if (similarityScore == null) { + VectorData comparisonVector = context.getFieldVector(compareToDocRank); + if (comparisonVector != null) { + similarityScore = getVectorComparisonScore(similarityFunction, thisDocVector, comparisonVector); + cachedScoresForDoc.put(compareToDocRank, similarityScore); } } + if (similarityScore != null && similarityScore > highestScore) { + highestScore = similarityScore; + } } - return highestScore; + return highestScore == Float.NEGATIVE_INFINITY ? 0.0f : highestScore; } - protected Map getQuerySimilarityForDocs(RankDoc[] docs, boolean useFloat, ResultDiversificationContext context) { + protected Map getQuerySimilarityForDocs(RankDoc[] docs, ResultDiversificationContext context) { Map querySimilarity = new HashMap<>(); VectorData queryVector = context.getQueryVector(); @@ -168,9 +160,7 @@ protected Map getQuerySimilarityForDocs(RankDoc[] docs, boolean for (RankDoc doc : docs) { VectorData vectorData = context.getFieldVector(doc.rank); if (vectorData != null) { - float querySimilarityScore = useFloat - ? getFloatVectorComparisonScore(similarityFunction, vectorData, queryVector) - : getByteVectorComparisonScore(similarityFunction, vectorData, queryVector); + float querySimilarityScore = getVectorComparisonScore(similarityFunction, vectorData, queryVector); querySimilarity.put(doc.rank, querySimilarityScore); } } diff --git a/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java b/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java index 7f8eb3509613d..0e809e38268a8 100644 --- a/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java +++ b/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java @@ -125,7 +125,7 @@ private MMRResultDiversificationContext getRandomByteContext(List expec ) ); - expectedDocIds.addAll(List.of(2, 3, 6)); + expectedDocIds.addAll(List.of(3, 4, 6)); return diversificationContext; }