-
Notifications
You must be signed in to change notification settings - Fork 25.9k
ES|QL: Optimize MMR by reducing cache size and lookup #145014
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
58a1be0
a04be0b
e75f379
7e34522
3bed78a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -16,11 +16,11 @@ | |||||||||
|
|
||||||||||
| import java.io.IOException; | ||||||||||
| import java.util.ArrayList; | ||||||||||
| import java.util.Comparator; | ||||||||||
| import java.util.HashMap; | ||||||||||
| import java.util.Arrays; | ||||||||||
| import java.util.LinkedHashMap; | ||||||||||
| import java.util.List; | ||||||||||
| import java.util.Map; | ||||||||||
| import java.util.stream.IntStream; | ||||||||||
|
|
||||||||||
| public class MMRResultDiversification extends ResultDiversification<MMRResultDiversificationContext> { | ||||||||||
|
|
||||||||||
|
|
@@ -46,14 +46,17 @@ public RankDoc[] diversify(RankDoc[] docs) throws IOException { | |||||||||
| List<Integer> selectedDocRanks = new ArrayList<>(); | ||||||||||
|
|
||||||||||
| // cache the similarity scores for the query vector vs. searchHits | ||||||||||
| Map<Integer, Float> querySimilarity = getQuerySimilarityForDocs(docs); | ||||||||||
|
|
||||||||||
| float[] querySimilarities = getQuerySimilarityForDocs(docs); | ||||||||||
| // always add the highest relevant doc to the list | ||||||||||
| selectedDocRanks.add(getHighestRelevantDocRank(docs, querySimilarity)); | ||||||||||
| int prevSelectedDocRank = 1 + IntStream.range(0, querySimilarities.length) | ||||||||||
| .reduce(0, (a, b) -> querySimilarities[a] >= querySimilarities[b] ? a : b); | ||||||||||
|
|
||||||||||
| Map<Integer, Map<Integer, Float>> cachedSimilarities = new HashMap<>(); | ||||||||||
| selectedDocRanks.add(prevSelectedDocRank); | ||||||||||
| int topDocsSize = context.getSize(); | ||||||||||
|
|
||||||||||
| float[] maxSimilarityToSelected = new float[docs.length]; | ||||||||||
| Arrays.fill(maxSimilarityToSelected, Float.NEGATIVE_INFINITY); | ||||||||||
|
|
||||||||||
| for (int x = 0; x < topDocsSize && selectedDocRanks.size() < topDocsSize && selectedDocRanks.size() < docs.length; x++) { | ||||||||||
| int thisMaxMMRDocRank = -1; | ||||||||||
| float thisMaxMMRScore = Float.NEGATIVE_INFINITY; | ||||||||||
|
|
@@ -69,29 +72,26 @@ public RankDoc[] diversify(RankDoc[] docs) throws IOException { | |||||||||
| continue; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| var cachedScoresForDoc = cachedSimilarities.getOrDefault(docRank, new LinkedHashMap<>()); | ||||||||||
|
|
||||||||||
| // compute MMR scores for remaining searchHits | ||||||||||
| float highestSimilarityScoreToSelected = getHighestSimilarityScoreToSelectedVectors( | ||||||||||
| selectedDocRanks, | ||||||||||
| float similarityToLastSelected = getVectorComparisonScore( | ||||||||||
| similarityFunction, | ||||||||||
| thisDocVector, | ||||||||||
| cachedScoresForDoc | ||||||||||
| context.getFieldVector(prevSelectedDocRank) | ||||||||||
| ); | ||||||||||
| maxSimilarityToSelected[docRank - 1] = Float.max(similarityToLastSelected, maxSimilarityToSelected[docRank - 1]); | ||||||||||
|
|
||||||||||
| // compute MMR | ||||||||||
| float querySimilarityScore = querySimilarity.getOrDefault(doc.rank, 0.0f); | ||||||||||
| float mmr = (context.getLambda() * querySimilarityScore) - ((1 - context.getLambda()) * highestSimilarityScoreToSelected); | ||||||||||
| float mmr = context.getLambda() * querySimilarities[docRank - 1] - (1 - context.getLambda()) | ||||||||||
| * maxSimilarityToSelected[docRank - 1]; | ||||||||||
|
|
||||||||||
| if (mmr > thisMaxMMRScore) { | ||||||||||
| thisMaxMMRScore = mmr; | ||||||||||
| thisMaxMMRDocRank = docRank; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| // cache these scores | ||||||||||
| cachedSimilarities.put(docRank, cachedScoresForDoc); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| if (thisMaxMMRDocRank >= 0) { | ||||||||||
| selectedDocRanks.add(thisMaxMMRDocRank); | ||||||||||
| prevSelectedDocRank = thisMaxMMRDocRank; | ||||||||||
| } | ||||||||||
| } | ||||||||||
|
|
||||||||||
|
|
@@ -111,50 +111,9 @@ public RankDoc[] diversify(RankDoc[] docs) throws IOException { | |||||||||
| return ret; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| private Integer getHighestRelevantDocRank(RankDoc[] docs, Map<Integer, Float> querySimilarity) { | ||||||||||
| Map.Entry<Integer, Float> highestRelevantDoc = querySimilarity.entrySet() | ||||||||||
| .stream() | ||||||||||
| .max(Comparator.comparingDouble(Map.Entry::getValue)) | ||||||||||
| .orElse(null); | ||||||||||
|
|
||||||||||
| if (highestRelevantDoc != null) { | ||||||||||
| return highestRelevantDoc.getKey(); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| // no query vectors? Just use the first document in the order | ||||||||||
| return docs[0].rank; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| private float getHighestSimilarityScoreToSelectedVectors( | ||||||||||
| List<Integer> selectedDocRanks, | ||||||||||
| VectorData thisDocVector, | ||||||||||
| Map<Integer, Float> cachedScoresForDoc | ||||||||||
| ) { | ||||||||||
| float highestScore = Float.NEGATIVE_INFINITY; | ||||||||||
| for (Integer compareToDocRank : selectedDocRanks) { | ||||||||||
| Float similarityScore = cachedScoresForDoc.getOrDefault(compareToDocRank, null); | ||||||||||
| if (similarityScore == null) { | ||||||||||
| VectorData comparisonVector = context.getFieldVector(compareToDocRank); | ||||||||||
| if (comparisonVector != null) { | ||||||||||
| if (comparisonVector.size() == 0) { | ||||||||||
| cachedScoresForDoc.put(compareToDocRank, Float.NEGATIVE_INFINITY); | ||||||||||
| continue; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| similarityScore = getVectorComparisonScore(similarityFunction, thisDocVector, comparisonVector); | ||||||||||
| cachedScoresForDoc.put(compareToDocRank, similarityScore); | ||||||||||
| } | ||||||||||
| } | ||||||||||
| if (similarityScore != null && similarityScore > highestScore) { | ||||||||||
| highestScore = similarityScore; | ||||||||||
| } | ||||||||||
| } | ||||||||||
| return highestScore == Float.NEGATIVE_INFINITY ? 0.0f : highestScore; | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 💭 Minor edge case, just from comparing implementations (not sure if it's a valid one): in the previous code, if there was no valid similarity score, we would get `0.0f`, but now we will get `Float.NEGATIVE_INFINITY`.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't actually get to this path, because as we iterate through candidates, we skip those that don't have a vector value: Lines 67 to 70 in dc43d5c
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh right 🤦 |
||||||||||
| } | ||||||||||
|
|
||||||||||
| protected Map<Integer, Float> getQuerySimilarityForDocs(RankDoc[] docs) { | ||||||||||
| Map<Integer, Float> querySimilarity = new HashMap<>(); | ||||||||||
|
|
||||||||||
| protected float[] getQuerySimilarityForDocs(RankDoc[] docs) { | ||||||||||
| float[] querySimilarity = new float[docs.length]; | ||||||||||
| Arrays.fill(querySimilarity, 0.0f); | ||||||||||
| VectorData queryVector = context.getQueryVector(); | ||||||||||
| if (queryVector == null) { | ||||||||||
| return querySimilarity; | ||||||||||
|
|
@@ -163,8 +122,7 @@ protected Map<Integer, Float> getQuerySimilarityForDocs(RankDoc[] docs) { | |||||||||
| for (RankDoc doc : docs) { | ||||||||||
| VectorData vectorData = context.getFieldVector(doc.rank); | ||||||||||
| if (vectorData != null) { | ||||||||||
| float querySimilarityScore = getVectorComparisonScore(similarityFunction, vectorData, queryVector); | ||||||||||
| querySimilarity.put(doc.rank, querySimilarityScore); | ||||||||||
| querySimilarity[doc.rank - 1] = getVectorComparisonScore(similarityFunction, vectorData, queryVector); | ||||||||||
| } | ||||||||||
| } | ||||||||||
| return querySimilarity; | ||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I get what the optimization is here -- you're only comparing the current document to the last selected document, correct? (The original implementation, and the implementation in the paper, compute MMR with respect to all the previously selected documents.)
I think this will work, but I'm still unsure whether it will produce the most correct results...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We still compute MMR wrt all previously selected documents.
We keep an array of the computed max similarity between each doc and the selected set.
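Then, each time we select a new diversified doc, we iterate through the remaining docs to find the next candidate:
- We compute the similarity to the last selected doc and use it to update `maxSimilarityToSelected[docRank - 1]` for the current doc.
- We compute MMR using the `maxSimilarityToSelected` value for the current doc.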