From 69b1a9ac2df4bf3002e9fac385a957cda2b606d4 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 18 Dec 2025 13:40:15 -0500 Subject: [PATCH 1/3] Always do bulk scoring for rescoring when possible --- ...ectIOCapableLucene99FlatVectorsFormat.java | 115 ++++++++++++++++-- 1 file changed, 104 insertions(+), 11 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java index 06572b7de608e..d78fafaff6e85 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java @@ -80,23 +80,36 @@ public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectI new Lucene99FlatBulkScoringVectorsReader( directIOState, new Lucene99FlatVectorsReader(directIOState, vectorsScorer), - vectorsScorer + vectorsScorer, + true ), - new Lucene99FlatVectorsReader(state, vectorsScorer) + new Lucene99FlatBulkScoringVectorsReader(state, new Lucene99FlatVectorsReader(state, vectorsScorer), vectorsScorer, false) ); } else { - return new Lucene99FlatVectorsReader(state, vectorsScorer); + return new Lucene99FlatBulkScoringVectorsReader( + state, + new Lucene99FlatVectorsReader(state, vectorsScorer), + vectorsScorer, + false + ); } } static class Lucene99FlatBulkScoringVectorsReader extends FlatVectorsReader { private final Lucene99FlatVectorsReader inner; private final SegmentReadState state; + private final boolean forcePreFetching; - Lucene99FlatBulkScoringVectorsReader(SegmentReadState state, Lucene99FlatVectorsReader inner, FlatVectorsScorer scorer) { + Lucene99FlatBulkScoringVectorsReader( + SegmentReadState state, + Lucene99FlatVectorsReader inner, + FlatVectorsScorer scorer, + boolean forcePreFetching + ) { super(scorer); this.inner = inner; this.state = state; + this.forcePreFetching = forcePreFetching; } @Override @@ -126,7 +139,7 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { return null; } FieldInfo info = state.fieldInfos.fieldInfo(field); - return new RescorerOffHeapVectorValues(vectorValues, info.getVectorSimilarityFunction(), vectorScorer); + return new RescorerOffHeapVectorValues(vectorValues, info.getVectorSimilarityFunction(), vectorScorer, forcePreFetching); } @Override @@ -145,8 +158,14 @@ static class RescorerOffHeapVectorValues extends FloatVectorValues implements Bu private final FloatVectorValues inner; private final IndexInput inputSlice; private final FlatVectorsScorer scorer; + private final boolean forcePreFetching; - RescorerOffHeapVectorValues(FloatVectorValues inner, VectorSimilarityFunction similarityFunction, FlatVectorsScorer scorer) { + RescorerOffHeapVectorValues( + FloatVectorValues inner, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer scorer, + boolean forcePreFetching + ) { this.inner = inner; if (inner instanceof HasIndexSlice slice) { this.inputSlice = slice.getSlice(); @@ -155,6 +174,7 @@ static class RescorerOffHeapVectorValues extends FloatVectorValues implements Bu } this.similarityFunction = similarityFunction; this.scorer = scorer; + this.forcePreFetching = forcePreFetching; } @Override @@ -179,7 +199,7 @@ public DocIndexIterator iterator() { @Override public RescorerOffHeapVectorValues copy() throws IOException { - return new RescorerOffHeapVectorValues(inner.copy(), similarityFunction, scorer); + return new RescorerOffHeapVectorValues(inner.copy(), similarityFunction, scorer, forcePreFetching); } @Override @@ -191,7 +211,9 @@ public BulkVectorScorer bulkRescorer(float[] target) throws IOException { public BulkVectorScorer bulkScorer(float[] target) throws IOException { DocIndexIterator indexIterator = inner.iterator(); RandomVectorScorer randomScorer = scorer.getRandomVectorScorer(similarityFunction, inner, target); - return new PreFetchingFloatBulkScorer(randomScorer, indexIterator, inputSlice, dimension() * Float.BYTES); + return forcePreFetching + ? new PreFetchingFloatBulkVectorScorer(randomScorer, indexIterator, inputSlice, dimension() * Float.BYTES) + : new FloatBulkVectorScorer(randomScorer, indexIterator); } @Override @@ -200,7 +222,33 @@ public VectorScorer scorer(float[] target) throws IOException { } } - private record PreFetchingFloatBulkScorer( + private record FloatBulkVectorScorer(RandomVectorScorer inner, KnnVectorValues.DocIndexIterator indexIterator) + implements + BulkScorableVectorValues.BulkVectorScorer { + + @Override + public float score() throws IOException { + return inner.score(indexIterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return indexIterator; + } + + @Override + public BulkScorer bulkScore(DocIdSetIterator matchingDocs) throws IOException { + DocIdSetIterator conjunctionScorer = matchingDocs == null + ? indexIterator + : ConjunctionUtils.intersectIterators(List.of(matchingDocs, indexIterator)); + if (conjunctionScorer.docID() == -1) { + conjunctionScorer.nextDoc(); + } + return new FloatBulkScorer(inner, 32, indexIterator, conjunctionScorer); + } + } + + private record PreFetchingFloatBulkVectorScorer( RandomVectorScorer inner, KnnVectorValues.DocIndexIterator indexIterator, IndexInput inputSlice, @@ -225,11 +273,56 @@ public BulkScorer bulkScore(DocIdSetIterator matchingDocs) throws IOException { if (conjunctionScorer.docID() == -1) { conjunctionScorer.nextDoc(); } - return new FloatBulkScorer(inner, inputSlice, byteSize, 32, indexIterator, conjunctionScorer); + return new PrefetchingFloatBulkScorer(inner, inputSlice, byteSize, 32, indexIterator, conjunctionScorer); } } private static class FloatBulkScorer implements BulkScorableVectorValues.BulkVectorScorer.BulkScorer { + private final KnnVectorValues.DocIndexIterator indexIterator; + private final DocIdSetIterator matchingDocs; + private final RandomVectorScorer inner; + private final int bulkSize; + private final int[] docBuffer; + private final float[] scoreBuffer; + + FloatBulkScorer(RandomVectorScorer fvv, int bulkSize, KnnVectorValues.DocIndexIterator iterator, DocIdSetIterator matchingDocs) { + this.indexIterator = iterator; + this.matchingDocs = matchingDocs; + this.inner = fvv; + this.bulkSize = bulkSize; + this.docBuffer = new int[bulkSize]; + this.scoreBuffer = new float[bulkSize]; + } + + @Override + public void nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException { + buffer.growNoCopy(nextCount); + int size = 0; + for (int doc = matchingDocs.docID(); doc != DocIdSetIterator.NO_MORE_DOCS && size < nextCount; doc = matchingDocs.nextDoc()) { + if (liveDocs == null || liveDocs.get(doc)) { + buffer.docs[size++] = indexIterator.index(); + } + } + final int loopBound = size - (size % bulkSize); + int i = 0; + for (; i < loopBound; i += bulkSize) { + System.arraycopy(buffer.docs, i, docBuffer, 0, bulkSize); + inner.bulkScore(docBuffer, scoreBuffer, bulkSize); + System.arraycopy(scoreBuffer, 0, buffer.features, i, bulkSize); + } + final int countLeft = size - i; + System.arraycopy(buffer.docs, i, docBuffer, 0, countLeft); + inner.bulkScore(docBuffer, scoreBuffer, countLeft); + System.arraycopy(scoreBuffer, 0, buffer.features, i, countLeft); + buffer.size = size; + // fix the docIds in buffer + for (int j = 0; j < size; j++) { + buffer.docs[j] = inner.ordToDoc(buffer.docs[j]); + } + } + } + + private static class PrefetchingFloatBulkScorer implements BulkScorableVectorValues.BulkVectorScorer.BulkScorer { private final KnnVectorValues.DocIndexIterator indexIterator; private final DocIdSetIterator matchingDocs; private final RandomVectorScorer inner; @@ -239,7 +332,7 @@ private static class FloatBulkScorer implements BulkScorableVectorValues.BulkVec private final int[] docBuffer; private final float[] scoreBuffer; - FloatBulkScorer( + PrefetchingFloatBulkScorer( RandomVectorScorer fvv, IndexInput inputSlice, int byteSize, From 4ce8f40850afc5bf6fd83e0d0589c916ffb78d51 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 18 Dec 2025 13:43:49 -0500 Subject: [PATCH 2/3] Update docs/changelog/139777.yaml --- docs/changelog/139777.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/139777.yaml diff --git a/docs/changelog/139777.yaml b/docs/changelog/139777.yaml new file mode 100644 index 0000000000000..0ce8bb50e5435 --- /dev/null +++ b/docs/changelog/139777.yaml @@ -0,0 +1,5 @@ +pr: 139777 +summary: Always do bulk scoring for rescoring when possible +area: Vector Search +type: enhancement +issues: [] From 891eb987ee31de9638fce58f32e3a6b7cdb65f5b Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 18 Dec 2025 16:39:03 -0500 Subject: [PATCH 3/3] iter --- ...ectIOCapableLucene99FlatVectorsFormat.java | 45 ++++++++++++++++++- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java index d78fafaff6e85..63d2ddbdf37eb 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java @@ -21,9 +21,11 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.ConjunctionUtils; import org.apache.lucene.search.DocAndFloatFeatureBuffer; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; @@ -36,6 +38,7 @@ import java.io.IOException; import java.util.List; +import java.util.Map; public class DirectIOCapableLucene99FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat { @@ -117,6 +120,36 @@ public void close() throws IOException { inner.close(); } + @Override + public Map getOffHeapByteSize(FieldInfo fieldInfo) { + return inner.getOffHeapByteSize(fieldInfo); + } + + @Override + public void finishMerge() throws IOException { + inner.finishMerge(); + } + + @Override + public FlatVectorsReader getMergeInstance() throws IOException { + return inner.getMergeInstance(); + } + + @Override + public FlatVectorsScorer getFlatVectorScorer() { + return inner.getFlatVectorScorer(); + } + + @Override + public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { + inner.search(field, target, knnCollector, acceptDocs); + } + + @Override + public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { + inner.search(field, target, knnCollector, acceptDocs); + } + @Override public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { return inner.getRandomVectorScorer(field, target); @@ -135,9 +168,12 @@ public void checkIntegrity() throws IOException { @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { FloatVectorValues vectorValues = inner.getFloatVectorValues(field); - if (vectorValues == null || vectorValues.size() == 0) { + if (vectorValues == null) { return null; } + if (vectorValues.size() == 0) { + return vectorValues; + } FieldInfo info = state.fieldInfos.fieldInfo(field); return new RescorerOffHeapVectorValues(vectorValues, info.getVectorSimilarityFunction(), vectorScorer, forcePreFetching); } @@ -211,7 +247,7 @@ public BulkVectorScorer bulkRescorer(float[] target) throws IOException { public BulkVectorScorer bulkScorer(float[] target) throws IOException { DocIndexIterator indexIterator = inner.iterator(); RandomVectorScorer randomScorer = scorer.getRandomVectorScorer(similarityFunction, inner, target); - return forcePreFetching + return forcePreFetching && inputSlice != null ? new PreFetchingFloatBulkVectorScorer(randomScorer, indexIterator, inputSlice, dimension() * Float.BYTES) : new FloatBulkVectorScorer(randomScorer, indexIterator); } @@ -220,6 +256,11 @@ public BulkVectorScorer bulkScorer(float[] target) throws IOException { public VectorScorer scorer(float[] target) throws IOException { return inner.scorer(target); } + + @Override + public int ordToDoc(int ord) { + return inner.ordToDoc(ord); + } } private record FloatBulkVectorScorer(RandomVectorScorer inner, KnnVectorValues.DocIndexIterator indexIterator)