Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/139777.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 139777
summary: Always do bulk scoring for rescoring when possible
area: Vector Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.ConjunctionUtils;
import org.apache.lucene.search.DocAndFloatFeatureBuffer;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
Expand All @@ -36,6 +38,7 @@

import java.io.IOException;
import java.util.List;
import java.util.Map;

public class DirectIOCapableLucene99FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat {

Expand Down Expand Up @@ -80,30 +83,73 @@ public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectI
new Lucene99FlatBulkScoringVectorsReader(
directIOState,
new Lucene99FlatVectorsReader(directIOState, vectorsScorer),
vectorsScorer
vectorsScorer,
true
),
new Lucene99FlatVectorsReader(state, vectorsScorer)
new Lucene99FlatBulkScoringVectorsReader(state, new Lucene99FlatVectorsReader(state, vectorsScorer), vectorsScorer, false)
);
} else {
return new Lucene99FlatVectorsReader(state, vectorsScorer);
return new Lucene99FlatBulkScoringVectorsReader(
state,
new Lucene99FlatVectorsReader(state, vectorsScorer),
vectorsScorer,
false
);
}
}

static class Lucene99FlatBulkScoringVectorsReader extends FlatVectorsReader {
private final Lucene99FlatVectorsReader inner;
private final SegmentReadState state;
private final boolean forcePreFetching;

Lucene99FlatBulkScoringVectorsReader(SegmentReadState state, Lucene99FlatVectorsReader inner, FlatVectorsScorer scorer) {
Lucene99FlatBulkScoringVectorsReader(
SegmentReadState state,
Lucene99FlatVectorsReader inner,
FlatVectorsScorer scorer,
boolean forcePreFetching
) {
super(scorer);
this.inner = inner;
this.state = state;
this.forcePreFetching = forcePreFetching;
}

@Override
public void close() throws IOException {
inner.close();
}

@Override
public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) {
return inner.getOffHeapByteSize(fieldInfo);
}

@Override
public void finishMerge() throws IOException {
inner.finishMerge();
}

@Override
public FlatVectorsReader getMergeInstance() throws IOException {
return inner.getMergeInstance();
}

@Override
public FlatVectorsScorer getFlatVectorScorer() {
return inner.getFlatVectorScorer();
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
inner.search(field, target, knnCollector, acceptDocs);
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
inner.search(field, target, knnCollector, acceptDocs);
}

@Override
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
return inner.getRandomVectorScorer(field, target);
Expand All @@ -122,11 +168,14 @@ public void checkIntegrity() throws IOException {
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FloatVectorValues vectorValues = inner.getFloatVectorValues(field);
if (vectorValues == null || vectorValues.size() == 0) {
if (vectorValues == null) {
return null;
}
if (vectorValues.size() == 0) {
return vectorValues;
}
FieldInfo info = state.fieldInfos.fieldInfo(field);
return new RescorerOffHeapVectorValues(vectorValues, info.getVectorSimilarityFunction(), vectorScorer);
return new RescorerOffHeapVectorValues(vectorValues, info.getVectorSimilarityFunction(), vectorScorer, forcePreFetching);
}

@Override
Expand All @@ -145,8 +194,14 @@ static class RescorerOffHeapVectorValues extends FloatVectorValues implements Bu
private final FloatVectorValues inner;
private final IndexInput inputSlice;
private final FlatVectorsScorer scorer;
private final boolean forcePreFetching;

RescorerOffHeapVectorValues(FloatVectorValues inner, VectorSimilarityFunction similarityFunction, FlatVectorsScorer scorer) {
RescorerOffHeapVectorValues(
FloatVectorValues inner,
VectorSimilarityFunction similarityFunction,
FlatVectorsScorer scorer,
boolean forcePreFetching
) {
this.inner = inner;
if (inner instanceof HasIndexSlice slice) {
this.inputSlice = slice.getSlice();
Expand All @@ -155,6 +210,7 @@ static class RescorerOffHeapVectorValues extends FloatVectorValues implements Bu
}
this.similarityFunction = similarityFunction;
this.scorer = scorer;
this.forcePreFetching = forcePreFetching;
}

@Override
Expand All @@ -179,7 +235,7 @@ public DocIndexIterator iterator() {

@Override
public RescorerOffHeapVectorValues copy() throws IOException {
return new RescorerOffHeapVectorValues(inner.copy(), similarityFunction, scorer);
return new RescorerOffHeapVectorValues(inner.copy(), similarityFunction, scorer, forcePreFetching);
}

@Override
Expand All @@ -191,16 +247,49 @@ public BulkVectorScorer bulkRescorer(float[] target) throws IOException {
public BulkVectorScorer bulkScorer(float[] target) throws IOException {
DocIndexIterator indexIterator = inner.iterator();
RandomVectorScorer randomScorer = scorer.getRandomVectorScorer(similarityFunction, inner, target);
return new PreFetchingFloatBulkScorer(randomScorer, indexIterator, inputSlice, dimension() * Float.BYTES);
return forcePreFetching && inputSlice != null
? new PreFetchingFloatBulkVectorScorer(randomScorer, indexIterator, inputSlice, dimension() * Float.BYTES)
: new FloatBulkVectorScorer(randomScorer, indexIterator);
}

@Override
public VectorScorer scorer(float[] target) throws IOException {
return inner.scorer(target);
}

@Override
public int ordToDoc(int ord) {
return inner.ordToDoc(ord);
}
}

private record FloatBulkVectorScorer(RandomVectorScorer inner, KnnVectorValues.DocIndexIterator indexIterator)
implements
BulkScorableVectorValues.BulkVectorScorer {

@Override
public float score() throws IOException {
return inner.score(indexIterator.index());
}

@Override
public DocIdSetIterator iterator() {
return indexIterator;
}

@Override
public BulkScorer bulkScore(DocIdSetIterator matchingDocs) throws IOException {
DocIdSetIterator conjunctionScorer = matchingDocs == null
? indexIterator
: ConjunctionUtils.intersectIterators(List.of(matchingDocs, indexIterator));
if (conjunctionScorer.docID() == -1) {
conjunctionScorer.nextDoc();
}
return new FloatBulkScorer(inner, 32, indexIterator, conjunctionScorer);
}
}

private record PreFetchingFloatBulkScorer(
private record PreFetchingFloatBulkVectorScorer(
RandomVectorScorer inner,
KnnVectorValues.DocIndexIterator indexIterator,
IndexInput inputSlice,
Expand All @@ -225,11 +314,56 @@ public BulkScorer bulkScore(DocIdSetIterator matchingDocs) throws IOException {
if (conjunctionScorer.docID() == -1) {
conjunctionScorer.nextDoc();
}
return new FloatBulkScorer(inner, inputSlice, byteSize, 32, indexIterator, conjunctionScorer);
return new PrefetchingFloatBulkScorer(inner, inputSlice, byteSize, 32, indexIterator, conjunctionScorer);
}
}

private static class FloatBulkScorer implements BulkScorableVectorValues.BulkVectorScorer.BulkScorer {
private final KnnVectorValues.DocIndexIterator indexIterator;
private final DocIdSetIterator matchingDocs;
private final RandomVectorScorer inner;
private final int bulkSize;
private final int[] docBuffer;
private final float[] scoreBuffer;

FloatBulkScorer(RandomVectorScorer fvv, int bulkSize, KnnVectorValues.DocIndexIterator iterator, DocIdSetIterator matchingDocs) {
this.indexIterator = iterator;
this.matchingDocs = matchingDocs;
this.inner = fvv;
this.bulkSize = bulkSize;
this.docBuffer = new int[bulkSize];
this.scoreBuffer = new float[bulkSize];
}

@Override
public void nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException {
buffer.growNoCopy(nextCount);
int size = 0;
for (int doc = matchingDocs.docID(); doc != DocIdSetIterator.NO_MORE_DOCS && size < nextCount; doc = matchingDocs.nextDoc()) {
if (liveDocs == null || liveDocs.get(doc)) {
buffer.docs[size++] = indexIterator.index();
}
}
final int loopBound = size - (size % bulkSize);
int i = 0;
for (; i < loopBound; i += bulkSize) {
System.arraycopy(buffer.docs, i, docBuffer, 0, bulkSize);
inner.bulkScore(docBuffer, scoreBuffer, bulkSize);
System.arraycopy(scoreBuffer, 0, buffer.features, i, bulkSize);
}
final int countLeft = size - i;
System.arraycopy(buffer.docs, i, docBuffer, 0, countLeft);
inner.bulkScore(docBuffer, scoreBuffer, countLeft);
System.arraycopy(scoreBuffer, 0, buffer.features, i, countLeft);
buffer.size = size;
// fix the docIds in buffer
for (int j = 0; j < size; j++) {
buffer.docs[j] = inner.ordToDoc(buffer.docs[j]);
}
}
}

private static class PrefetchingFloatBulkScorer implements BulkScorableVectorValues.BulkVectorScorer.BulkScorer {
private final KnnVectorValues.DocIndexIterator indexIterator;
private final DocIdSetIterator matchingDocs;
private final RandomVectorScorer inner;
Expand All @@ -239,7 +373,7 @@ private static class FloatBulkScorer implements BulkScorableVectorValues.BulkVec
private final int[] docBuffer;
private final float[] scoreBuffer;

FloatBulkScorer(
PrefetchingFloatBulkScorer(
RandomVectorScorer fvv,
IndexInput inputSlice,
int byteSize,
Expand Down