Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

package org.elasticsearch.simdvec;

import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
Expand All @@ -24,7 +25,10 @@
import java.util.Objects;
import java.util.Random;
import java.util.function.IntFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.elasticsearch.simdvec.VectorSimilarityType.COSINE;
import static org.elasticsearch.simdvec.VectorSimilarityType.DOT_PRODUCT;
import static org.elasticsearch.simdvec.VectorSimilarityType.EUCLIDEAN;
import static org.elasticsearch.simdvec.VectorSimilarityType.MAXIMUM_INNER_PRODUCT;
Expand Down Expand Up @@ -133,6 +137,156 @@ public void testDatasetGreaterThanChunkSize() throws IOException {
}
}

public void testSupplierBulkWithMMap() throws IOException {
assumeTrue(notSupportedMsg(), supported());
try (var dir = new MMapDirectory(createTempDir("testBulkWithMMap"))) {
testSupplierBulkImpl(dir);
}
}

private void testSupplierBulkImpl(Directory dir) throws IOException {
assumeTrue(notSupportedMsg(), supported());
var factory = AbstractVectorTestCase.factory.get();

final int dims = randomIntBetween(1, 4096);
final int size = randomIntBetween(2, 100);
final byte[][] vectors = new byte[size][];
String fileName = "testBulk-" + dir.getClass().getSimpleName() + "-" + dims;
logger.info("Testing " + fileName);
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
for (int i = 0; i < size; i++) {
byte[] vec = randomByteArrayOfLength(dims);
out.writeBytes(vec, vec.length);
vectors[i] = vec;
}
CodecUtil.writeFooter(out);
}
List<Integer> ids = IntStream.range(0, size).boxed().collect(Collectors.toList());
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
for (int times = 0; times < TIMES; times++) {
int idx0 = randomIntBetween(0, size - 1);
int[] nodes = shuffledList(ids).stream().mapToInt(i -> i).toArray();
for (var sim : List.of(DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) {
var values = vectorValues(dims, size, in, sim.function());
float[] expected = new float[size];
float[] scores = new float[size];
for (int i = 0; i < size; i++) {
expected[i] = luceneScore(sim, vectors[idx0], vectors[nodes[i]]);
}
var supplier = factory.getByteVectorScorerSupplier(sim, in, values).get();
var scorer = supplier.scorer();
scorer.setScoringOrdinal(idx0);
scorer.bulkScore(nodes, scores, nodes.length);
for (int i = 0; i < size; i++) {
double expectedDelta = Math.max(Math.abs(expected[i]) * DELTA, DELTA);
assertThat(sim.toString(), (double) scores[i], closeTo(expected[i], expectedDelta));
// assert single scoring returns the same expected score as bulk
assertThat(sim.toString(), (double) scorer.score(nodes[i]), closeTo(expected[i], expectedDelta));
}
}
}
}
}

// -- Query-side scorer tests (ByteVectorScorer via getByteVectorScorer, JDK 22+) --
// These test the query scorer which accepts both MMap and DirectAccessInput (SNAP).

public void testScorerWithMMap() throws IOException {
assumeTrue(notSupportedMsg(), supported());
assumeTrue("scorer only supported on JDK 22+", Runtime.version().feature() >= 22);
try (var dir = new MMapDirectory(createTempDir("testScorerWithMMap"))) {
testScorerImpl(dir);
}
}

private void testScorerImpl(Directory dir) throws IOException {
var factory = AbstractVectorTestCase.factory.get();
final int dims = randomIntBetween(1, 4096);
final int size = randomIntBetween(2, 100);
final byte[][] vectors = new byte[size][];
final byte[] queryVector = randomByteArrayOfLength(dims);

String fileName = "testScorerImpl-" + dir.getClass().getSimpleName() + "-" + dims;
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
for (int i = 0; i < size; i++) {
byte[] vec = randomByteArrayOfLength(dims);
out.writeBytes(vec, vec.length);
vectors[i] = vec;
}
CodecUtil.writeFooter(out);
}
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
for (int times = 0; times < TIMES; times++) {
int idx = randomIntBetween(0, size - 1);
for (var sim : List.of(DOT_PRODUCT, EUCLIDEAN, COSINE, MAXIMUM_INNER_PRODUCT)) {
var values = vectorValues(dims, size, in, sim.function());
float expected = luceneScore(sim, queryVector, vectors[idx]);
var scorer = factory.getByteVectorScorer(sim.function(), values, queryVector).get();
double expectedDelta = Math.max(Math.abs(expected) * DELTA, DELTA);
assertThat(sim.toString(), (double) scorer.score(idx), closeTo(expected, expectedDelta));
}
}
}
}

public void testScorerBulkWithMMap() throws IOException {
assumeTrue(notSupportedMsg(), supported());
assumeTrue("scorer only supported on JDK 22+", Runtime.version().feature() >= 22);
try (var dir = new MMapDirectory(createTempDir("testScorerBulkWithMMap"))) {
testScorerBulkImpl(dir);
}
}

public void testScorerBulkFallback() throws IOException {
assumeTrue(notSupportedMsg(), supported());
assumeTrue("scorer only supported on JDK 22+", Runtime.version().feature() >= 22);
// Small chunk size forces multi-segment mmap; segmentSliceOrNull(0, length) returns null,
// so bulkScoreWithSparse falls back to super.bulkScore() (one-at-a-time scoring).
try (var dir = new MMapDirectory(createTempDir("testScorerBulkFallback"), 32)) {
testScorerBulkImpl(dir);
}
}

private void testScorerBulkImpl(Directory dir) throws IOException {
var factory = AbstractVectorTestCase.factory.get();
final int dims = randomIntBetween(64, 4096);
final int size = randomIntBetween(2, 100);
final byte[][] vectors = new byte[size][];
final byte[] queryVector = randomByteArrayOfLength(dims);

String fileName = "testScorerBulk-" + dir.getClass().getSimpleName() + "-" + dims;
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
for (int i = 0; i < size; i++) {
byte[] vec = randomByteArrayOfLength(dims);
out.writeBytes(vec, vec.length);
vectors[i] = vec;
}
CodecUtil.writeFooter(out);
}
List<Integer> ids = IntStream.range(0, size).boxed().collect(Collectors.toList());
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
for (int times = 0; times < TIMES; times++) {
int[] nodes = shuffledList(ids).stream().mapToInt(i -> i).toArray();
for (var sim : List.of(DOT_PRODUCT, EUCLIDEAN, COSINE, MAXIMUM_INNER_PRODUCT)) {
var values = vectorValues(dims, size, in, sim.function());
float[] expected = new float[size];
float[] scores = new float[size];
for (int i = 0; i < size; i++) {
expected[i] = luceneScore(sim, queryVector, vectors[nodes[i]]);
}
var scorer = factory.getByteVectorScorer(sim.function(), values, queryVector).get();
scorer.bulkScore(nodes, scores, nodes.length);
for (int i = 0; i < size; i++) {
double expectedDelta = Math.max(Math.abs(expected[i]) * DELTA, DELTA);
assertThat(sim.toString(), (double) scores[i], closeTo(expected[i], expectedDelta));
// assert single scoring returns the same expected score as bulk
assertThat(sim.toString(), (double) scorer.score(nodes[i]), closeTo(expected[i], expectedDelta));
}
}
}
}
}

static ByteVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) {
return new OffHeapByteVectorValues.DenseOffHeapVectorValues(dims, size, in, dims, null, sim);
}
Expand Down
Loading