Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ synchronized void next(float[] dest) throws IOException {

/**
 * Reads the next byte vector into {@code dest}.
 *
 * <p>NOTE(review): this region looks like a rendered diff whose +/- markers were lost. The hunk
 * header just above reports 7 lines on each side, so the two consecutive reads in the body are
 * presumably the removed line followed by its replacement, not two statements that both run.
 * Confirm against the merged file before treating this as compilable source.
 *
 * @param dest array that receives the bytes; its length determines how many bytes are copied
 *             (assumes {@code bytes} is a {@code java.nio.ByteBuffer} — TODO confirm)
 * @throws IOException declared by the signature; presumably raised by {@code readNext()}, which
 *                     is not visible here
 */
synchronized void next(byte[] dest) throws IOException {
// Presumably loads the next vector into `bytes`; the helper is defined outside this view.
readNext();
// Presumably the pre-change line: copies from `bytes` into dest and leaves the buffer position advanced.
bytes.get(dest);
// Presumably the post-change line: same copy, then resets the position of `bytes` to 0.
bytes.get(dest).position(0);
}
}
}
22 changes: 11 additions & 11 deletions qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -393,12 +393,10 @@ private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes, SearchParameter
long startNS = System.nanoTime();
// TODO: enable computing NN from high precision vectors when
// checking low-precision recall
int[][] nn;
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
nn = computeExactNNByte(queryPath, filterQuery, vectorFileOffsetBytes);
} else {
nn = computeExactNN(queryPath, filterQuery, searchParameters.topK(), vectorFileOffsetBytes);
}
int[][] nn = switch (vectorEncoding) {
case BYTE -> computeExactNNByte(queryPath, filterQuery, searchParameters.topK(), vectorFileOffsetBytes);
case FLOAT32 -> computeExactNN(queryPath, filterQuery, searchParameters.topK(), vectorFileOffsetBytes);
};
writeExactNN(nn, nnPath);
long elapsedMS = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS); // ns -> ms
logger.info("computed " + numQueryVectors + " exact matches in " + elapsedMS + " ms");
Expand Down Expand Up @@ -562,7 +560,7 @@ private int[][] computeExactNN(Path queryPath, Query filterQuery, int topK, int
}
}

private int[][] computeExactNNByte(Path queryPath, Query filterQuery, int vectorFileOffsetBytes) throws IOException {
private int[][] computeExactNNByte(Path queryPath, Query filterQuery, int topK, int vectorFileOffsetBytes) throws IOException {
int[][] result = new int[numQueryVectors][];
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
List<Callable<Void>> tasks = new ArrayList<>();
Expand All @@ -571,7 +569,7 @@ private int[][] computeExactNNByte(Path queryPath, Query filterQuery, int vector
for (int i = 0; i < numQueryVectors; i++) {
byte[] queryVector = new byte[dim];
queryReader.next(queryVector);
tasks.add(new ComputeNNByteTask(i, queryVector, result, reader, filterQuery, similarityFunction));
tasks.add(new ComputeNNByteTask(i, topK, queryVector, result, reader, filterQuery, similarityFunction));
}
ForkJoinPool.commonPool().invokeAll(tasks);
}
Expand Down Expand Up @@ -637,11 +635,13 @@ static class ComputeNNByteTask implements Callable<Void> {
private final byte[] query;
private final int[][] result;
private final IndexReader reader;
private final Query filterQuery;
private final VectorSimilarityFunction similarityFunction;
private final Query filterQuery;
private final int topK;

ComputeNNByteTask(
int queryOrd,
int topK,
byte[] query,
int[][] result,
IndexReader reader,
Expand All @@ -652,14 +652,14 @@ static class ComputeNNByteTask implements Callable<Void> {
this.query = query;
this.result = result;
this.reader = reader;
this.filterQuery = filterQuery;
this.similarityFunction = similarityFunction;
this.filterQuery = filterQuery;
this.topK = topK;
}

@Override
public Void call() {
IndexSearcher searcher = new IndexSearcher(reader);
int topK = result[0].length;
try {
var queryVector = new ConstKnnByteVectorValueSource(query);
var docVectors = new ByteKnnVectorFieldSource(VECTOR_FIELD);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ public Builder setDirectoryType(String directoryType) {
"<corpus_2>.fvec"
],
"queries": "<queries>.fvec",
"vector_encoding": "float32", // optional, default float32
"dimensions": 512,
"vector_space": "cosine",
"num_doc_vectors": 10000000,
Expand Down Expand Up @@ -568,6 +569,7 @@ private void resolveDataset() throws Exception {
).getFirst();
}

Object vectorEncoding = dsData.get("vector_encoding");
String vectorSpace = dsData.get("vector_space").toString();
int numDocVectors = ((Number) dsData.get("num_doc_vectors")).intValue();
int numQueryVectors = ((Number) dsData.get("num_query_vectors")).intValue();
Expand All @@ -583,6 +585,10 @@ private void resolveDataset() throws Exception {

docVectors = data;
queryVectors = queries;
// vector encoding is optional (default float32)
if (vectorEncoding != null) {
setVectorEncoding(vectorEncoding.toString());
}
setDimensions(-1); // dataset dimensions is documentation, the tester reads the dimensions from the fvec files
// vector space might already be set explicitly from the config file
if (this.vectorSpace == null) {
Expand Down