Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ public float decodeScore(long heapValue) {
return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32));
}

private int decodeNodeId(long heapValue) {
public int decodeNodeId(long heapValue) {
return (int) ~(order.apply(heapValue));
}

Expand All @@ -147,13 +147,13 @@ public long popRaw() {
* removes the current top element, returns its node id and adds the new element
* to the queue.
* */
public int popAndAddRaw(long raw) {
public long popRawAndAddRaw(long raw) {
long top = heap.top();
if (raw < top) {
return decodeNodeId(raw);
}
heap.updateTop(raw);
return decodeNodeId(top);
return top;
}

public void clear() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec.vectors.diskbbq;

import java.io.IOException;

/**
 * An iterator over centroids that provides posting list metadata.
 */
public interface CentroidIterator {
    /** Returns {@code true} if there is another centroid to visit. */
    boolean hasNext();

    /**
     * Advances to the next centroid and returns the metadata of its posting list.
     *
     * @return the {@link PostingMetadata} for the next centroid
     * @throws IOException if reading the underlying index data fails
     */
    PostingMetadata nextPosting() throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,14 @@ public boolean hasNext() {
}

@Override
public CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException {
int centroidOrdinal = neighborQueue.pop();
public PostingMetadata nextPosting() throws IOException {
long centroidOrdinalAndScore = neighborQueue.popRaw();
int centroidOrdinal = neighborQueue.decodeNodeId(centroidOrdinalAndScore);
float score = neighborQueue.decodeScore(centroidOrdinalAndScore);
centroids.seek(offset + (long) Long.BYTES * 2 * centroidOrdinal);
long postingListOffset = centroids.readLong();
long postingListLength = centroids.readLong();
return new CentroidOffsetAndLength(postingListOffset, postingListLength);
return new PostingMetadata(postingListOffset, postingListLength, centroidOrdinal, score);
}
};
}
Expand Down Expand Up @@ -244,18 +246,20 @@ public boolean hasNext() {
}

@Override
public CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException {
int centroidOrdinal = nextCentroid();
public PostingMetadata nextPosting() throws IOException {
long centroidOrdinalAndScore = nextCentroid();
int centroidOrdinal = neighborQueue.decodeNodeId(centroidOrdinalAndScore);
float score = neighborQueue.decodeScore(centroidOrdinalAndScore);
centroids.seek(childrenFileOffsets + (long) Long.BYTES * 2 * centroidOrdinal);
long postingListOffset = centroids.readLong();
long postingListLength = centroids.readLong();
return new CentroidOffsetAndLength(postingListOffset, postingListLength);
return new PostingMetadata(postingListOffset, postingListLength, centroidOrdinal, score);
}

private int nextCentroid() throws IOException {
private long nextCentroid() throws IOException {
if (currentParentQueue.size() > 0) {
// return next centroid and maybe add a children from the current parent queue
return neighborQueue.popAndAddRaw(currentParentQueue.popRaw());
return neighborQueue.popRawAndAddRaw(currentParentQueue.popRaw());
} else if (parentsQueue.size() > 0) {
// current parent queue is empty, populate it again with the next parent
int pop = parentsQueue.pop();
Expand All @@ -274,7 +278,7 @@ private int nextCentroid() throws IOException {
);
return nextCentroid();
} else {
return neighborQueue.pop();
return neighborQueue.popRaw();
}
}
};
Expand Down Expand Up @@ -414,9 +418,9 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
}

@Override
public int resetPostingsScorer(long offset) throws IOException {
public int resetPostingsScorer(PostingMetadata postingMetadata) throws IOException {
quantized = false;
indexInput.seek(offset);
indexInput.seek(postingMetadata.offset());
indexInput.readFloats(centroid, 0, centroid.length);
centroidDp = Float.intBitsToFloat(indexInput.readInt());
vectors = indexInput.readVInt();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,10 @@ public final void search(String field, float[] target, KnnCollector knnCollector
// filtering? E.g. keep exploring until we hit an expected number of parent documents vs. child vectors?
while (centroidPrefetchingIterator.hasNext()
&& (maxVectorVisited > expectedDocs || knnCollector.minCompetitiveSimilarity() == Float.NEGATIVE_INFINITY)) {
// todo do we actually need to know the score???
CentroidOffsetAndLength offsetAndLength = centroidPrefetchingIterator.nextPostingListOffsetAndLength();
PostingMetadata postingMetadata = centroidPrefetchingIterator.nextPosting();
// todo do we need direct access to the raw centroid???, this is used for quantizing, maybe hydrating and quantizing
// is enough?
expectedDocs += scorer.resetPostingsScorer(offsetAndLength.offset());
expectedDocs += scorer.resetPostingsScorer(postingMetadata);
actualDocs += scorer.visit(knnCollector);
if (knnCollector.getSearchStrategy() != null) {
knnCollector.getSearchStrategy().nextVectorsBlock();
Expand All @@ -341,8 +340,8 @@ public final void search(String field, float[] target, KnnCollector knnCollector
int filteredVectors = (int) Math.ceil(numVectors * percentFiltered);
float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
while (centroidPrefetchingIterator.hasNext() && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
CentroidOffsetAndLength offsetAndLength = centroidPrefetchingIterator.nextPostingListOffsetAndLength();
scorer.resetPostingsScorer(offsetAndLength.offset());
PostingMetadata postingMetadata = centroidPrefetchingIterator.nextPosting();
scorer.resetPostingsScorer(postingMetadata);
actualDocs += scorer.visit(knnCollector);
if (knnCollector.getSearchStrategy() != null) {
knnCollector.getSearchStrategy().nextVectorsBlock();
Expand Down Expand Up @@ -468,17 +467,9 @@ public int getBulkSize() {
public abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, Bits needsScoring)
throws IOException;

public record CentroidOffsetAndLength(long offset, long length) {}

public interface CentroidIterator {
boolean hasNext();

CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException;
}

public interface PostingVisitor {
/** returns the number of documents in the posting list */
int resetPostingsScorer(long offset) throws IOException;
int resetPostingsScorer(PostingMetadata metadata) throws IOException;

/** returns the number of scored documents */
int visit(KnnCollector collector) throws IOException;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec.vectors.diskbbq;

/**
 * Metadata about a posting list for a centroid. Immutable value carrier produced by
 * {@link CentroidIterator#nextPosting()} and consumed when resetting a postings scorer.
 *
 * @param offset The offset of the posting list in the index.
 * @param length The length of the posting list in bytes.
 * @param centroidOrdinal The ordinal of the centroid.
 * @param centroidScore The score of the centroid.
 */
public record PostingMetadata(long offset, long length, int centroidOrdinal, float centroidScore) {}
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
package org.elasticsearch.index.codec.vectors.diskbbq;

import org.apache.lucene.store.IndexInput;
import org.elasticsearch.index.codec.vectors.diskbbq.IVFVectorsReader.CentroidIterator;
import org.elasticsearch.index.codec.vectors.diskbbq.IVFVectorsReader.CentroidOffsetAndLength;

import java.io.IOException;

Expand All @@ -29,7 +27,7 @@ public final class PrefetchingCentroidIterator implements CentroidIterator {
private final int prefetchAhead;

// Ring buffer for prefetched offsets and lengths
private final CentroidOffsetAndLength[] prefetchBuffer;
private final PostingMetadata[] prefetchBuffer;
private int readIndex = 0; // Where we read from buffer
private int writeIndex = 0; // Where we write to buffer
private int bufferCount = 0; // Number of elements in buffer
Expand Down Expand Up @@ -61,7 +59,7 @@ public PrefetchingCentroidIterator(CentroidIterator delegate, IndexInput posting
this.delegate = delegate;
this.postingListSlice = postingListSlice;
this.prefetchAhead = prefetchAhead;
this.prefetchBuffer = new CentroidOffsetAndLength[prefetchAhead];
this.prefetchBuffer = new PostingMetadata[prefetchAhead];
// Initialize buffer by prefetching up to prefetchAhead elements
fillBuffer();
}
Expand All @@ -71,7 +69,7 @@ public PrefetchingCentroidIterator(CentroidIterator delegate, IndexInput posting
*/
private void fillBuffer() throws IOException {
while (bufferCount < prefetchAhead && delegate.hasNext()) {
CentroidOffsetAndLength offsetAndLength = delegate.nextPostingListOffsetAndLength();
PostingMetadata offsetAndLength = delegate.nextPosting();
prefetchBuffer[writeIndex] = offsetAndLength;
writeIndex = (writeIndex + 1) % prefetchAhead;
bufferCount++;
Expand All @@ -87,19 +85,19 @@ public boolean hasNext() {
}

@Override
public CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException {
public PostingMetadata nextPosting() throws IOException {
if (bufferCount == 0) {
throw new IllegalStateException("No more elements available");
}

// Get the next element from buffer
CentroidOffsetAndLength result = prefetchBuffer[readIndex];
PostingMetadata result = prefetchBuffer[readIndex];
readIndex = (readIndex + 1) % prefetchAhead;
bufferCount--;

// Try to fill buffer with one more element
if (delegate.hasNext()) {
CentroidOffsetAndLength offsetAndLength = delegate.nextPostingListOffsetAndLength();
PostingMetadata offsetAndLength = delegate.nextPosting();
prefetchBuffer[writeIndex] = offsetAndLength;
writeIndex = (writeIndex + 1) % prefetchAhead;
bufferCount++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
import org.elasticsearch.index.codec.vectors.GenericFlatVectorReaders;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import org.elasticsearch.index.codec.vectors.cluster.NeighborQueue;
import org.elasticsearch.index.codec.vectors.diskbbq.CentroidIterator;
import org.elasticsearch.index.codec.vectors.diskbbq.DocIdsWriter;
import org.elasticsearch.index.codec.vectors.diskbbq.IVFVectorsReader;
import org.elasticsearch.index.codec.vectors.diskbbq.PostingMetadata;
import org.elasticsearch.index.codec.vectors.diskbbq.PrefetchingCentroidIterator;
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
import org.elasticsearch.simdvec.ESNextOSQVectorsScorer;
Expand Down Expand Up @@ -275,12 +277,14 @@ public boolean hasNext() {
}

@Override
public CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException {
int centroidOrdinal = neighborQueue.pop();
centroids.seek(offset + (long) Long.BYTES * 2 * centroidOrdinal);
public PostingMetadata nextPosting() throws IOException {
long centroidOrdinalAndScore = neighborQueue.popRaw();
int centroidOrd = neighborQueue.decodeNodeId(centroidOrdinalAndScore);
float score = neighborQueue.decodeScore(centroidOrdinalAndScore);
centroids.seek(offset + (long) Long.BYTES * 2 * centroidOrd);
long postingListOffset = centroids.readLong();
long postingListLength = centroids.readLong();
return new CentroidOffsetAndLength(postingListOffset, postingListLength);
return new PostingMetadata(postingListOffset, postingListLength, centroidOrd, score);
}
};
}
Expand Down Expand Up @@ -314,7 +318,7 @@ public boolean hasNext() {
}

@Override
public CentroidOffsetAndLength nextPostingListOffsetAndLength() {
public PostingMetadata nextPosting() {
return null;
}
};
Expand Down Expand Up @@ -384,18 +388,20 @@ public boolean hasNext() {
}

@Override
public CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException {
int centroidOrdinal = nextCentroid();
public PostingMetadata nextPosting() throws IOException {
long centroidOrdinalAndScore = nextCentroid();
int centroidOrdinal = neighborQueue.decodeNodeId(centroidOrdinalAndScore);
float score = neighborQueue.decodeScore(centroidOrdinalAndScore);
centroids.seek(childrenFileOffsets + (long) Long.BYTES * 2 * centroidOrdinal);
long postingListOffset = centroids.readLong();
long postingListLength = centroids.readLong();
return new CentroidOffsetAndLength(postingListOffset, postingListLength);
return new PostingMetadata(postingListOffset, postingListLength, centroidOrdinal, score);
}

private int nextCentroid() throws IOException {
private long nextCentroid() throws IOException {
if (currentParentQueue.size() > 0) {
// return next centroid and maybe add a children from the current parent queue
return neighborQueue.popAndAddRaw(currentParentQueue.popRaw());
return neighborQueue.popRawAndAddRaw(currentParentQueue.popRaw());
} else if (parentsQueue.size() > 0) {
// current parent queue is empty, populate it again with the next parent
int pop = parentsQueue.pop();
Expand All @@ -416,7 +422,7 @@ private int nextCentroid() throws IOException {
);
return nextCentroid();
} else {
return neighborQueue.pop();
return neighborQueue.popRaw();
}
}
};
Expand Down Expand Up @@ -606,9 +612,9 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
}

@Override
public int resetPostingsScorer(long offset) throws IOException {
public int resetPostingsScorer(PostingMetadata metadata) throws IOException {
quantized = false;
indexInput.seek(offset);
indexInput.seek(metadata.offset());
indexInput.readFloats(centroid, 0, centroid.length);
centroidDp = Float.intBitsToFloat(indexInput.readInt());
vectors = indexInput.readVInt();
Expand Down