Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions muted-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,6 @@ tests:
- class: org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceDateTruncBucketWithRoundToTests
method: testReductionPlanForTopNWithPushedDownFunctions
issue: https://github.com/elastic/elasticsearch/issues/139493
- class: org.elasticsearch.xpack.security.CoreWithSecurityClientYamlTestSuiteIT
method: test {yaml=search.retrievers/result-diversification/10_mmr_result_diversification_retriever/Test MMR result diversification single index float type}
issue: https://github.com/elastic/elasticsearch/issues/139527
- class: org.elasticsearch.repositories.gcs.GoogleCloudStorageBlobStoreRepositoryTests
method: testMultipleSnapshotAndRollback
issue: https://github.com/elastic/elasticsearch/issues/139556
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,8 @@ teardown:
- match: { hits.total.value: 7 }
- length: { hits.hits: 3 }
- match: { hits.hits.0._source.textbody: "first text" }
- match: { hits.hits.1._source.textbody: "fourth text" }
- match: { hits.hits.2._source.textbody: "fifth text" }
- match: { hits.hits.1._source.textbody: "third text" }
- match: { hits.hits.2._source.textbody: "fourth text" }

---
"Test MMR diversification with arbitrary shards":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,13 @@ protected ResultDiversification(C context) {

public abstract RankDoc[] diversify(RankDoc[] docs) throws IOException;

protected float getFloatVectorComparisonScore(
protected float getVectorComparisonScore(
VectorSimilarityFunction similarityFunction,
VectorData thisDocVector,
VectorData comparisonVector
) {
return similarityFunction.compare(thisDocVector.floatVector(), comparisonVector.floatVector());
}

protected float getByteVectorComparisonScore(
VectorSimilarityFunction similarityFunction,
VectorData thisDocVector,
VectorData comparisonVector
) {
return similarityFunction.compare(thisDocVector.byteVector(), comparisonVector.byteVector());
return thisDocVector.isFloat()
? similarityFunction.compare(thisDocVector.floatVector(), comparisonVector.floatVector())
: similarityFunction.compare(thisDocVector.byteVector(), comparisonVector.byteVector());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

Expand All @@ -37,20 +38,17 @@ public RankDoc[] diversify(RankDoc[] docs) throws IOException {
return docs;
}

Map<Integer, Integer> docIdIndexMapping = new HashMap<>();
// keep the original ranking so we can output in the same order
Map<Integer, Integer> docIdIndexMapping = new LinkedHashMap<>();
for (int i = 0; i < docs.length; i++) {
docIdIndexMapping.put(docs[i].rank, i);
}

// our chosen DocIDs to keep
List<Integer> selectedDocRanks = new ArrayList<>();

// test the vector to see if we are using floats or bytes
VectorData firstVec = context.getFieldVector(docs[0].rank);
boolean useFloat = firstVec.isFloat();

// cache the similarity scores for the query vector vs. searchHits
Map<Integer, Float> querySimilarity = getQuerySimilarityForDocs(docs, useFloat, context);
Map<Integer, Float> querySimilarity = getQuerySimilarityForDocs(docs, context);

// always add the highest relevant doc to the list
selectedDocRanks.add(getHighestRelevantDocRank(docs, querySimilarity));
Expand All @@ -73,14 +71,18 @@ public RankDoc[] diversify(RankDoc[] docs) throws IOException {
continue;
}

var cachedScoresForDoc = cachedSimilarities.getOrDefault(docRank, new HashMap<>());
var cachedScoresForDoc = cachedSimilarities.getOrDefault(docRank, new LinkedHashMap<>());

// compute MMR scores for remaining searchHits
float highestMMRScore = getHighestScoreForSelectedVectors(docRank, context, useFloat, thisDocVector, cachedScoresForDoc);
float highestSimilarityScoreToSelected = getHighestSimilarityScoreToSelectedVectors(
selectedDocRanks,
thisDocVector,
cachedScoresForDoc
);

// compute MMR
float querySimilarityScore = querySimilarity.getOrDefault(doc.rank, 0.0f);
float mmr = (context.getLambda() * querySimilarityScore) - ((1 - context.getLambda()) * highestMMRScore);
float mmr = (context.getLambda() * querySimilarityScore) - ((1 - context.getLambda()) * highestSimilarityScoreToSelected);
if (mmr > thisMaxMMRScore) {
thisMaxMMRScore = mmr;
thisMaxMMRDocRank = docRank;
Expand Down Expand Up @@ -125,39 +127,29 @@ private Integer getHighestRelevantDocRank(RankDoc[] docs, Map<Integer, Float> qu
return highestScoreDoc.rank;
}

private float getHighestScoreForSelectedVectors(
int docRank,
MMRResultDiversificationContext context,
boolean useFloat,
private float getHighestSimilarityScoreToSelectedVectors(
List<Integer> selectedDocRanks,
VectorData thisDocVector,
Map<Integer, Float> cachedScoresForDoc
) {
float highestScore = Float.MIN_VALUE;
for (var vec : context.getFieldVectorsEntrySet()) {
if (vec.getKey().equals(docRank)) {
continue;
}

if (cachedScoresForDoc.containsKey(vec.getKey())) {
float score = cachedScoresForDoc.get(vec.getKey());
if (score > highestScore) {
highestScore = score;
}
} else {
VectorData comparisonVector = vec.getValue();
float score = useFloat
? getFloatVectorComparisonScore(similarityFunction, thisDocVector, comparisonVector)
: getByteVectorComparisonScore(similarityFunction, thisDocVector, comparisonVector);
cachedScoresForDoc.put(vec.getKey(), score);
if (score > highestScore) {
highestScore = score;
float highestScore = Float.NEGATIVE_INFINITY;
for (Integer compareToDocRank : selectedDocRanks) {
Float similarityScore = cachedScoresForDoc.getOrDefault(compareToDocRank, null);
if (similarityScore == null) {
VectorData comparisonVector = context.getFieldVector(compareToDocRank);
if (comparisonVector != null) {
similarityScore = getVectorComparisonScore(similarityFunction, thisDocVector, comparisonVector);
cachedScoresForDoc.put(compareToDocRank, similarityScore);
}
}
if (similarityScore != null && similarityScore > highestScore) {
highestScore = similarityScore;
}
}
return highestScore;
return highestScore == Float.NEGATIVE_INFINITY ? 0.0f : highestScore;
}

protected Map<Integer, Float> getQuerySimilarityForDocs(RankDoc[] docs, boolean useFloat, ResultDiversificationContext context) {
protected Map<Integer, Float> getQuerySimilarityForDocs(RankDoc[] docs, ResultDiversificationContext context) {
Map<Integer, Float> querySimilarity = new HashMap<>();

VectorData queryVector = context.getQueryVector();
Expand All @@ -168,9 +160,7 @@ protected Map<Integer, Float> getQuerySimilarityForDocs(RankDoc[] docs, boolean
for (RankDoc doc : docs) {
VectorData vectorData = context.getFieldVector(doc.rank);
if (vectorData != null) {
float querySimilarityScore = useFloat
? getFloatVectorComparisonScore(similarityFunction, vectorData, queryVector)
: getByteVectorComparisonScore(similarityFunction, vectorData, queryVector);
float querySimilarityScore = getVectorComparisonScore(similarityFunction, vectorData, queryVector);
querySimilarity.put(doc.rank, querySimilarityScore);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ private MMRResultDiversificationContext getRandomByteContext(List<Integer> expec
)
);

expectedDocIds.addAll(List.of(2, 3, 6));
expectedDocIds.addAll(List.of(3, 4, 6));

return diversificationContext;
}
Expand Down