diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ByteVectorScorerFactoryTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ByteVectorScorerFactoryTests.java index c2a6a734863bf..b1665c60b2aba 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ByteVectorScorerFactoryTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ByteVectorScorerFactoryTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.simdvec; +import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; @@ -24,7 +25,10 @@ import java.util.Objects; import java.util.Random; import java.util.function.IntFunction; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import static org.elasticsearch.simdvec.VectorSimilarityType.COSINE; import static org.elasticsearch.simdvec.VectorSimilarityType.DOT_PRODUCT; import static org.elasticsearch.simdvec.VectorSimilarityType.EUCLIDEAN; import static org.elasticsearch.simdvec.VectorSimilarityType.MAXIMUM_INNER_PRODUCT; @@ -133,6 +137,156 @@ public void testDatasetGreaterThanChunkSize() throws IOException { } } + public void testSupplierBulkWithMMap() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + try (var dir = new MMapDirectory(createTempDir("testBulkWithMMap"))) { + testSupplierBulkImpl(dir); + } + } + + private void testSupplierBulkImpl(Directory dir) throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = AbstractVectorTestCase.factory.get(); + + final int dims = randomIntBetween(1, 4096); + final int size = randomIntBetween(2, 100); + final byte[][] vectors = new byte[size][]; + String fileName = "testBulk-" + dir.getClass().getSimpleName() + "-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + byte[] vec = randomByteArrayOfLength(dims); + out.writeBytes(vec, vec.length); + vectors[i] = vec; + } + CodecUtil.writeFooter(out); + } + List ids = IntStream.range(0, size).boxed().collect(Collectors.toList()); + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int[] nodes = shuffledList(ids).stream().mapToInt(i -> i).toArray(); + for (var sim : List.of(DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var values = vectorValues(dims, size, in, sim.function()); + float[] expected = new float[size]; + float[] scores = new float[size]; + for (int i = 0; i < size; i++) { + expected[i] = luceneScore(sim, vectors[idx0], vectors[nodes[i]]); + } + var supplier = factory.getByteVectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx0); + scorer.bulkScore(nodes, scores, nodes.length); + for (int i = 0; i < size; i++) { + double expectedDelta = Math.max(Math.abs(expected[i]) * DELTA, DELTA); + assertThat(sim.toString(), (double) scores[i], closeTo(expected[i], expectedDelta)); + // assert single scoring returns the same expected score as bulk + assertThat(sim.toString(), (double) scorer.score(nodes[i]), closeTo(expected[i], expectedDelta)); + } + } + } + } + } + + // -- Query-side scorer tests (ByteVectorScorer via getByteVectorScorer, JDK 22+) -- + // These test the query scorer which accepts both MMap and DirectAccessInput (SNAP). + + public void testScorerWithMMap() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + assumeTrue("scorer only supported on JDK 22+", Runtime.version().feature() >= 22); + try (var dir = new MMapDirectory(createTempDir("testScorerWithMMap"))) { + testScorerImpl(dir); + } + } + + private void testScorerImpl(Directory dir) throws IOException { + var factory = AbstractVectorTestCase.factory.get(); + final int dims = randomIntBetween(1, 4096); + final int size = randomIntBetween(2, 100); + final byte[][] vectors = new byte[size][]; + final byte[] queryVector = randomByteArrayOfLength(dims); + + String fileName = "testScorerImpl-" + dir.getClass().getSimpleName() + "-" + dims; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + byte[] vec = randomByteArrayOfLength(dims); + out.writeBytes(vec, vec.length); + vectors[i] = vec; + } + CodecUtil.writeFooter(out); + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx = randomIntBetween(0, size - 1); + for (var sim : List.of(DOT_PRODUCT, EUCLIDEAN, COSINE, MAXIMUM_INNER_PRODUCT)) { + var values = vectorValues(dims, size, in, sim.function()); + float expected = luceneScore(sim, queryVector, vectors[idx]); + var scorer = factory.getByteVectorScorer(sim.function(), values, queryVector).get(); + double expectedDelta = Math.max(Math.abs(expected) * DELTA, DELTA); + assertThat(sim.toString(), (double) scorer.score(idx), closeTo(expected, expectedDelta)); + } + } + } + } + + public void testScorerBulkWithMMap() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + assumeTrue("scorer only supported on JDK 22+", Runtime.version().feature() >= 22); + try (var dir = new MMapDirectory(createTempDir("testScorerBulkWithMMap"))) { + testScorerBulkImpl(dir); + } + } + + public void testScorerBulkFallback() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + assumeTrue("scorer only supported on JDK 22+", Runtime.version().feature() >= 22); + // Small chunk size forces multi-segment mmap; segmentSliceOrNull(0, length) returns null, + // so bulkScoreWithSparse falls back to super.bulkScore() (one-at-a-time scoring). + try (var dir = new MMapDirectory(createTempDir("testScorerBulkFallback"), 32)) { + testScorerBulkImpl(dir); + } + } + + private void testScorerBulkImpl(Directory dir) throws IOException { + var factory = AbstractVectorTestCase.factory.get(); + final int dims = randomIntBetween(64, 4096); + final int size = randomIntBetween(2, 100); + final byte[][] vectors = new byte[size][]; + final byte[] queryVector = randomByteArrayOfLength(dims); + + String fileName = "testScorerBulk-" + dir.getClass().getSimpleName() + "-" + dims; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + byte[] vec = randomByteArrayOfLength(dims); + out.writeBytes(vec, vec.length); + vectors[i] = vec; + } + CodecUtil.writeFooter(out); + } + List ids = IntStream.range(0, size).boxed().collect(Collectors.toList()); + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int[] nodes = shuffledList(ids).stream().mapToInt(i -> i).toArray(); + for (var sim : List.of(DOT_PRODUCT, EUCLIDEAN, COSINE, MAXIMUM_INNER_PRODUCT)) { + var values = vectorValues(dims, size, in, sim.function()); + float[] expected = new float[size]; + float[] scores = new float[size]; + for (int i = 0; i < size; i++) { + expected[i] = luceneScore(sim, queryVector, vectors[nodes[i]]); + } + var scorer = factory.getByteVectorScorer(sim.function(), values, queryVector).get(); + scorer.bulkScore(nodes, scores, nodes.length); + for (int i = 0; i < size; i++) { + double expectedDelta = Math.max(Math.abs(expected[i]) * DELTA, DELTA); + assertThat(sim.toString(), (double) scores[i], closeTo(expected[i], expectedDelta)); + // assert single scoring returns the same expected score as bulk + assertThat(sim.toString(), (double) scorer.score(nodes[i]), closeTo(expected[i], expectedDelta)); + } + } + } + } + } + static ByteVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) { return new OffHeapByteVectorValues.DenseOffHeapVectorValues(dims, size, in, dims, null, sim); }