diff --git a/benchmarks/build.gradle b/benchmarks/build.gradle index 1fd34a0a84235..f3fa0f9ff4515 100644 --- a/benchmarks/build.gradle +++ b/benchmarks/build.gradle @@ -51,6 +51,7 @@ dependencies { api(project(':x-pack:plugin:analytics')) api(project(':x-pack:plugin:logsdb')) implementation project(path: ':libs:native') + implementation(testFixtures(project(':libs:native'))) implementation project(path: ':libs:simdvec') implementation (testFixtures(project(path: ':libs:simdvec'))) implementation project(path: ':libs:swisshash') diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/Int4BenchmarkUtils.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/Int4BenchmarkUtils.java index deec4d9ea4201..46597e90888b1 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/Int4BenchmarkUtils.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/Int4BenchmarkUtils.java @@ -12,7 +12,9 @@ import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat; import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; @@ -23,87 +25,20 @@ import java.util.concurrent.ThreadLocalRandom; import static org.elasticsearch.benchmark.vector.scorer.ScalarOperations.applyI4Corrections; -import static org.elasticsearch.benchmark.vector.scorer.ScalarOperations.dotProductI4SinglePacked; -import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.unpackNibbles; +import static org.elasticsearch.nativeaccess.Int4TestUtils.dotProductI4SinglePacked; +import static 
org.elasticsearch.nativeaccess.Int4TestUtils.unpackNibbles; +import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.writePackedVectorWithCorrection; public class Int4BenchmarkUtils { - /** - * In-memory implementation of {@link QuantizedByteVectorValues} for int4 (PACKED_NIBBLE) benchmarks. - * Stores pre-quantized packed nibble vectors with synthetic corrective terms. - */ - static class InMemoryInt4QuantizedByteVectorValues extends QuantizedByteVectorValues { - - private final int dims; - private final byte[][] packedVectors; - private final OptimizedScalarQuantizer.QuantizationResult[] correctiveTerms; - private final float[] centroid; - private final float centroidDP; - private final OptimizedScalarQuantizer quantizer; - - InMemoryInt4QuantizedByteVectorValues( - int dims, - byte[][] packedVectors, - OptimizedScalarQuantizer.QuantizationResult[] correctiveTerms, - float[] centroid, - float centroidDP - ) { - this.dims = dims; - this.packedVectors = packedVectors; - this.correctiveTerms = correctiveTerms; - this.centroid = centroid; - this.centroidDP = centroidDP; - this.quantizer = new OptimizedScalarQuantizer(VectorSimilarityFunction.DOT_PRODUCT); - } - - @Override - public int dimension() { - return dims; - } - - @Override - public int size() { - return packedVectors.length; - } - - @Override - public byte[] vectorValue(int ord) throws IOException { - return packedVectors[ord]; - } - - @Override - public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int vectorOrd) throws IOException { - return correctiveTerms[vectorOrd]; - } - - @Override - public OptimizedScalarQuantizer getQuantizer() { - return quantizer; - } - - @Override - public Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding getScalarEncoding() { - return Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.PACKED_NIBBLE; - } - - @Override - public float[] getCentroid() throws IOException { - return centroid; - } - - @Override - public float 
getCentroidDP() throws IOException { - return centroidDP; - } + static final String VECTOR_DATA_FILE = "int4-vector.data"; - @Override - public VectorScorer scorer(float[] query) throws IOException { - return null; - } - - @Override - public InMemoryInt4QuantizedByteVectorValues copy() throws IOException { - return new InMemoryInt4QuantizedByteVectorValues(dims, packedVectors, correctiveTerms, centroid, centroidDP); + static void writeI4VectorData(Directory dir, byte[][] packedVectors, OptimizedScalarQuantizer.QuantizationResult[] corrections) + throws IOException { + try (IndexOutput out = dir.createOutput(VECTOR_DATA_FILE, IOContext.DEFAULT)) { + for (int i = 0; i < packedVectors.length; i++) { + writePackedVectorWithCorrection(out, packedVectors[i], corrections[i]); + } } } @@ -142,25 +77,6 @@ public void setScoringOrdinal(int node) throws IOException { } } - static QuantizedByteVectorValues createI4QuantizedVectorValues(int dims, byte[][] packedVectors) { - var random = ThreadLocalRandom.current(); - var correctiveTerms = new OptimizedScalarQuantizer.QuantizationResult[packedVectors.length]; - for (int i = 0; i < packedVectors.length; i++) { - correctiveTerms[i] = new OptimizedScalarQuantizer.QuantizationResult( - random.nextFloat(-1f, 1f), - random.nextFloat(-1f, 1f), - random.nextFloat(-1f, 1f), - random.nextInt(0, dims * 15) - ); - } - float[] centroid = new float[dims]; - for (int i = 0; i < dims; i++) { - centroid[i] = random.nextFloat(); - } - float centroidDP = random.nextFloat(); - return new InMemoryInt4QuantizedByteVectorValues(dims, packedVectors, correctiveTerms, centroid, centroidDP); - } - static UpdateableRandomVectorScorer createI4ScalarScorer( QuantizedByteVectorValues values, VectorSimilarityFunction similarityFunction @@ -179,7 +95,7 @@ static RandomVectorScorer createI4ScalarQueryScorer( Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding encoding = values.getScalarEncoding(); byte[] queryQuantized = new 
byte[encoding.getDiscreteDimensions(dims)]; - float[] queryCopy = Arrays.copyOf(queryVector, queryVector.length); + float[] queryCopy = queryVector.clone(); if (similarityFunction == VectorSimilarityFunction.COSINE) { VectorUtil.l2normalize(queryCopy); } @@ -196,4 +112,50 @@ public float score(int node) throws IOException { } }; } + + static OptimizedScalarQuantizer.QuantizationResult[] generateCorrectiveTerms(int dims, int numVectors) { + var random = ThreadLocalRandom.current(); + var correctiveTerms = new OptimizedScalarQuantizer.QuantizationResult[numVectors]; + for (int i = 0; i < numVectors; i++) { + correctiveTerms[i] = new OptimizedScalarQuantizer.QuantizationResult( + random.nextFloat(-1f, 1f), + random.nextFloat(-1f, 1f), + random.nextFloat(-1f, 1f), + random.nextInt(0, dims * 15) + ); + } + return correctiveTerms; + } + + static float[] generateCentroid(int dims) { + var random = ThreadLocalRandom.current(); + float[] centroid = new float[dims]; + for (int i = 0; i < dims; i++) { + centroid[i] = random.nextFloat(); + } + return centroid; + } + + /** + * Quantizes a float query vector for use with the native Int4 scorer. + * Returns the unpacked quantized bytes (one byte per dimension, 0-15 range). 
+ */ + static QuantizedQuery quantizeQuery(QuantizedByteVectorValues values, VectorSimilarityFunction sim, float[] queryVector) + throws IOException { + int dims = values.dimension(); + OptimizedScalarQuantizer quantizer = values.getQuantizer(); + float[] centroid = values.getCentroid(); + var encoding = values.getScalarEncoding(); + + byte[] scratch = new byte[encoding.getDiscreteDimensions(dims)]; + float[] queryCopy = queryVector.clone(); + if (sim == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(queryCopy); + } + var corrections = quantizer.scalarQuantize(queryCopy, scratch, encoding.getQueryBits(), centroid); + byte[] unpackedQuery = Arrays.copyOf(scratch, dims); + return new QuantizedQuery(unpackedQuery, corrections); + } + + record QuantizedQuery(byte[] unpackedQuery, OptimizedScalarQuantizer.QuantizationResult corrections) {} } diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/ScalarOperations.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/ScalarOperations.java index 5ef0ba5a9982a..78c334be2e20d 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/ScalarOperations.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/ScalarOperations.java @@ -75,18 +75,6 @@ static int squareDistance(byte[] a, byte[] b) { return res; } - static int dotProductI4SinglePacked(byte[] unpacked, byte[] packed) { - int total = 0; - for (int i = 0; i < packed.length; i++) { - byte packedByte = packed[i]; - byte unpacked1 = unpacked[i]; - byte unpacked2 = unpacked[i + packed.length]; - total += (packedByte & 0x0F) * unpacked2; - total += ((packedByte & 0xFF) >> 4) * unpacked1; - } - return total; - } - public static float applyI4Corrections( int rawDot, int dims, diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4Benchmark.java 
b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4Benchmark.java index f943ec9fb08ad..15ba65fe495b1 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4Benchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4Benchmark.java @@ -11,9 +11,15 @@ import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; import org.elasticsearch.benchmark.Utils; +import org.elasticsearch.core.IOUtils; import org.elasticsearch.simdvec.VectorSimilarityType; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; @@ -29,21 +35,30 @@ import org.openjdk.jmh.annotations.Warmup; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.getScorerFactoryOrDie; import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.lucene104ScoreSupplier; import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.lucene104Scorer; import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.supportsHeapSegments; -import static org.elasticsearch.benchmark.vector.scorer.Int4BenchmarkUtils.createI4QuantizedVectorValues; +import static org.elasticsearch.benchmark.vector.scorer.Int4BenchmarkUtils.VECTOR_DATA_FILE; import static 
org.elasticsearch.benchmark.vector.scorer.Int4BenchmarkUtils.createI4ScalarQueryScorer; import static org.elasticsearch.benchmark.vector.scorer.Int4BenchmarkUtils.createI4ScalarScorer; -import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.packNibbles; +import static org.elasticsearch.benchmark.vector.scorer.Int4BenchmarkUtils.generateCentroid; +import static org.elasticsearch.benchmark.vector.scorer.Int4BenchmarkUtils.generateCorrectiveTerms; +import static org.elasticsearch.benchmark.vector.scorer.Int4BenchmarkUtils.quantizeQuery; +import static org.elasticsearch.benchmark.vector.scorer.Int4BenchmarkUtils.writeI4VectorData; +import static org.elasticsearch.nativeaccess.Int4TestUtils.packNibbles; +import static org.elasticsearch.simdvec.ESVectorUtil.dotProduct; +import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.createDenseInt4VectorValues; import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.randomInt4Bytes; /** * Benchmark that compares int4 packed-nibble quantized vector similarity scoring: - * scalar vs Lucene's Lucene104ScalarQuantizedVectorScorer. + * scalar vs Lucene's Lucene104ScalarQuantizedVectorScorer vs native. 
* Run with ./gradlew -p benchmarks run --args 'VectorScorerInt4Benchmark' */ @Fork(value = 1, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) @@ -62,28 +77,37 @@ public class VectorScorerInt4Benchmark { public int dims; public static int numVectors = 2; - @Param({ "SCALAR", "LUCENE" }) + @Param public VectorImplementation implementation; @Param({ "DOT_PRODUCT", "EUCLIDEAN" }) public VectorSimilarityType function; + private Path path; + private Directory dir; + private IndexInput in; + private UpdateableRandomVectorScorer scorer; private RandomVectorScorer queryScorer; static class VectorData { - final QuantizedByteVectorValues values; + final byte[][] packedVectors; + final OptimizedScalarQuantizer.QuantizationResult[] corrections; + final float[] centroid; + final float centroidDp; final float[] queryVector; VectorData(int dims) { - byte[][] packedVectors = new byte[numVectors][]; + packedVectors = new byte[numVectors][]; ThreadLocalRandom random = ThreadLocalRandom.current(); for (int v = 0; v < numVectors; v++) { byte[] unpacked = new byte[dims]; randomInt4Bytes(random, unpacked); packedVectors[v] = packNibbles(unpacked); } - values = createI4QuantizedVectorValues(dims, packedVectors); + corrections = generateCorrectiveTerms(dims, numVectors); + centroid = generateCentroid(dims); + centroidDp = dotProduct(centroid, centroid); queryVector = new float[dims]; for (int i = 0; i < dims; i++) { queryVector[i] = random.nextFloat(); @@ -98,7 +122,20 @@ public void setup() throws IOException { void setup(VectorData vectorData) throws IOException { VectorSimilarityFunction similarityFunction = function.function(); - var values = vectorData.values; + + path = Files.createTempDirectory("Int4ScorerBenchmark"); + dir = new MMapDirectory(path); + writeI4VectorData(dir, vectorData.packedVectors, vectorData.corrections); + in = dir.openInput(VECTOR_DATA_FILE, IOContext.DEFAULT); + + QuantizedByteVectorValues values = createDenseInt4VectorValues( + dims, + numVectors, 
+ vectorData.centroid, + vectorData.centroidDp, + in, + similarityFunction + ); switch (implementation) { case SCALAR: @@ -111,13 +148,32 @@ void setup(VectorData vectorData) throws IOException { queryScorer = lucene104Scorer(values, similarityFunction, vectorData.queryVector); } break; + case NATIVE: + var factory = getScorerFactoryOrDie(); + scorer = factory.getInt4VectorScorerSupplier(function, in, values).orElseThrow().scorer(); + if (supportsHeapSegments()) { + var qQuery = quantizeQuery(values, similarityFunction, vectorData.queryVector); + queryScorer = factory.getInt4VectorScorer( + similarityFunction, + values, + qQuery.unpackedQuery(), + qQuery.corrections().lowerInterval(), + qQuery.corrections().upperInterval(), + qQuery.corrections().additionalCorrection(), + qQuery.corrections().quantizedComponentSum() + ).orElseThrow(); + } + break; } scorer.setScoringOrdinal(0); } @TearDown - public void teardown() throws IOException {} + public void teardown() throws IOException { + IOUtils.close(in, dir); + IOUtils.rm(path); + } @Benchmark public float score() throws IOException { diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4BulkBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4BulkBenchmark.java index e5331fa3cbdb1..8ce6cba75cda2 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4BulkBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4BulkBenchmark.java @@ -11,9 +11,15 @@ import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.util.hnsw.RandomVectorScorer; import 
org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; import org.elasticsearch.benchmark.Utils; +import org.elasticsearch.core.IOUtils; import org.elasticsearch.simdvec.VectorSimilarityType; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; @@ -29,6 +35,8 @@ import org.openjdk.jmh.annotations.Warmup; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.Collections; import java.util.List; import java.util.concurrent.ThreadLocalRandom; @@ -36,19 +44,26 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.getScorerFactoryOrDie; import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.lucene104ScoreSupplier; import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.lucene104Scorer; import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.supportsHeapSegments; -import static org.elasticsearch.benchmark.vector.scorer.Int4BenchmarkUtils.createI4QuantizedVectorValues; +import static org.elasticsearch.benchmark.vector.scorer.Int4BenchmarkUtils.VECTOR_DATA_FILE; import static org.elasticsearch.benchmark.vector.scorer.Int4BenchmarkUtils.createI4ScalarQueryScorer; import static org.elasticsearch.benchmark.vector.scorer.Int4BenchmarkUtils.createI4ScalarScorer; -import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.packNibbles; +import static org.elasticsearch.benchmark.vector.scorer.Int4BenchmarkUtils.generateCentroid; +import static org.elasticsearch.benchmark.vector.scorer.Int4BenchmarkUtils.generateCorrectiveTerms; +import static org.elasticsearch.benchmark.vector.scorer.Int4BenchmarkUtils.quantizeQuery; +import static org.elasticsearch.benchmark.vector.scorer.Int4BenchmarkUtils.writeI4VectorData; +import static 
org.elasticsearch.nativeaccess.Int4TestUtils.packNibbles; +import static org.elasticsearch.simdvec.ESVectorUtil.dotProduct; +import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.createDenseInt4VectorValues; import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.randomInt4Bytes; /** * Benchmark that compares bulk scoring of int4 packed-nibble quantized vectors: - * scalar vs Lucene's Lucene104ScalarQuantizedVectorScorer, across sequential - * and random access patterns. + * scalar vs Lucene's Lucene104ScalarQuantizedVectorScorer vs native, + * across sequential and random access patterns. * Run with ./gradlew -p benchmarks run --args 'VectorScorerInt4BulkBenchmark' */ @Fork(value = 1, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) @@ -73,12 +88,16 @@ public class VectorScorerInt4BulkBenchmark { @Param({ "16", "32", "64", "256", "1024" }) public int bulkSize; - @Param({ "SCALAR", "LUCENE" }) + @Param public VectorImplementation implementation; @Param({ "DOT_PRODUCT", "EUCLIDEAN" }) public VectorSimilarityType function; + private Path path; + private Directory dir; + private IndexInput in; + private float[] scores; private int[] ordinals; private int[] ids; @@ -89,21 +108,26 @@ public class VectorScorerInt4BulkBenchmark { static class VectorData { final int numVectorsToScore; - final QuantizedByteVectorValues values; + final byte[][] packedVectors; + final OptimizedScalarQuantizer.QuantizationResult[] corrections; + final float[] centroid; + final float centroidDp; final int[] ordinals; final int targetOrd; final float[] queryVector; VectorData(int dims, int numVectors, int numVectorsToScore) { this.numVectorsToScore = numVectorsToScore; - byte[][] packedVectors = new byte[numVectors][]; + packedVectors = new byte[numVectors][]; ThreadLocalRandom random = ThreadLocalRandom.current(); for (int v = 0; v < numVectors; v++) { byte[] unpacked = new byte[dims]; randomInt4Bytes(random, unpacked); 
packedVectors[v] = packNibbles(unpacked); } - values = createI4QuantizedVectorValues(dims, packedVectors); + corrections = generateCorrectiveTerms(dims, numVectors); + centroid = generateCentroid(dims); + centroidDp = dotProduct(centroid, centroid); List list = IntStream.range(0, numVectors).boxed().collect(Collectors.toList()); Collections.shuffle(list, random); @@ -124,7 +148,20 @@ public void setup() throws IOException { void setup(VectorData vectorData) throws IOException { VectorSimilarityFunction similarityFunction = function.function(); - var values = vectorData.values; + + path = Files.createTempDirectory("Int4BulkScorerBenchmark"); + dir = new MMapDirectory(path); + writeI4VectorData(dir, vectorData.packedVectors, vectorData.corrections); + in = dir.openInput(VECTOR_DATA_FILE, IOContext.DEFAULT); + + QuantizedByteVectorValues values = createDenseInt4VectorValues( + dims, + vectorData.packedVectors.length, + vectorData.centroid, + vectorData.centroidDp, + in, + similarityFunction + ); numVectorsToScore = vectorData.numVectorsToScore; scores = new float[bulkSize]; @@ -143,13 +180,32 @@ void setup(VectorData vectorData) throws IOException { queryScorer = lucene104Scorer(values, similarityFunction, vectorData.queryVector); } break; + case NATIVE: + var factory = getScorerFactoryOrDie(); + scorer = factory.getInt4VectorScorerSupplier(function, in, values).orElseThrow().scorer(); + if (supportsHeapSegments()) { + var qQuery = quantizeQuery(values, similarityFunction, vectorData.queryVector); + queryScorer = factory.getInt4VectorScorer( + similarityFunction, + values, + qQuery.unpackedQuery(), + qQuery.corrections().lowerInterval(), + qQuery.corrections().upperInterval(), + qQuery.corrections().additionalCorrection(), + qQuery.corrections().quantizedComponentSum() + ).orElseThrow(); + } + break; } scorer.setScoringOrdinal(vectorData.targetOrd); } @TearDown - public void teardown() throws IOException {} + public void teardown() throws IOException { + 
IOUtils.close(in, dir); + IOUtils.rm(path); + } @Benchmark public float[] scoreMultipleSequential() throws IOException { diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4OperationBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4OperationBenchmark.java index a6a222ba1f044..3f5ee399245fe 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4OperationBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4OperationBenchmark.java @@ -26,7 +26,8 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; -import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.packNibbles; +import static org.elasticsearch.nativeaccess.Int4TestUtils.dotProductI4SinglePacked; +import static org.elasticsearch.nativeaccess.Int4TestUtils.packNibbles; import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.randomInt4Bytes; /** @@ -61,7 +62,7 @@ public void init() { @Benchmark public int scalar() { - return ScalarOperations.dotProductI4SinglePacked(unpacked, packed); + return dotProductI4SinglePacked(unpacked, packed); } @Benchmark diff --git a/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4BenchmarkTests.java b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4BenchmarkTests.java index dd632058bb962..820794ac43599 100644 --- a/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4BenchmarkTests.java +++ b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4BenchmarkTests.java @@ -28,28 +28,30 @@ public VectorScorerInt4BenchmarkTests(VectorSimilarityType function, int dims) { this.dims = dims; } + private VectorScorerInt4Benchmark createBench(VectorImplementation impl, 
VectorScorerInt4Benchmark.VectorData data) throws Exception { + var bench = new VectorScorerInt4Benchmark(); + bench.function = function; + bench.implementation = impl; + bench.dims = dims; + bench.setup(data); + return bench; + } + public void testScores() throws Exception { for (int i = 0; i < 100; i++) { var data = new VectorScorerInt4Benchmark.VectorData(dims); - - var scalar = new VectorScorerInt4Benchmark(); - scalar.function = function; - scalar.implementation = VectorImplementation.SCALAR; - scalar.dims = dims; - scalar.setup(data); - - var lucene = new VectorScorerInt4Benchmark(); - lucene.function = function; - lucene.implementation = VectorImplementation.LUCENE; - lucene.dims = dims; - lucene.setup(data); + var scalar = createBench(VectorImplementation.SCALAR, data); + var lucene = createBench(VectorImplementation.LUCENE, data); + var nativeBench = createBench(VectorImplementation.NATIVE, data); try { float expected = scalar.score(); assertEquals("LUCENE score", expected, lucene.score(), delta); + assertEquals("NATIVE score", expected, nativeBench.score(), delta); } finally { scalar.teardown(); lucene.teardown(); + nativeBench.teardown(); } } } @@ -57,25 +59,18 @@ public void testScores() throws Exception { public void testQueryScores() throws Exception { for (int i = 0; i < 100; i++) { var data = new VectorScorerInt4Benchmark.VectorData(dims); - - var scalar = new VectorScorerInt4Benchmark(); - scalar.function = function; - scalar.implementation = VectorImplementation.SCALAR; - scalar.dims = dims; - scalar.setup(data); - - var lucene = new VectorScorerInt4Benchmark(); - lucene.function = function; - lucene.implementation = VectorImplementation.LUCENE; - lucene.dims = dims; - lucene.setup(data); + var scalar = createBench(VectorImplementation.SCALAR, data); + var lucene = createBench(VectorImplementation.LUCENE, data); + var nativeBench = createBench(VectorImplementation.NATIVE, data); try { float expected = scalar.scoreQuery(); assertEquals("LUCENE 
scoreQuery", expected, lucene.scoreQuery(), delta); + assertEquals("NATIVE scoreQuery", expected, nativeBench.scoreQuery(), delta); } finally { scalar.teardown(); lucene.teardown(); + nativeBench.teardown(); } } } diff --git a/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4BulkBenchmarkTests.java b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4BulkBenchmarkTests.java index 8af55d30e6a97..31b42407d8698 100644 --- a/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4BulkBenchmarkTests.java +++ b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt4BulkBenchmarkTests.java @@ -46,13 +46,16 @@ public void testSequential() throws Exception { var vectorData = new VectorScorerInt4BulkBenchmark.VectorData(dims, 1000, 200); var scalar = createBench(VectorImplementation.SCALAR, vectorData); var lucene = createBench(VectorImplementation.LUCENE, vectorData); + var nativeBench = createBench(VectorImplementation.NATIVE, vectorData); try { float[] expected = scalar.scoreMultipleSequential(); assertArrayEquals("LUCENE sequential", expected, lucene.scoreMultipleSequential(), delta); + assertArrayEquals("NATIVE sequential", expected, nativeBench.scoreMultipleSequential(), delta); } finally { scalar.teardown(); lucene.teardown(); + nativeBench.teardown(); } } } @@ -62,13 +65,16 @@ public void testRandom() throws Exception { var vectorData = new VectorScorerInt4BulkBenchmark.VectorData(dims, 1000, 200); var scalar = createBench(VectorImplementation.SCALAR, vectorData); var lucene = createBench(VectorImplementation.LUCENE, vectorData); + var nativeBench = createBench(VectorImplementation.NATIVE, vectorData); try { float[] expected = scalar.scoreMultipleRandom(); assertArrayEquals("LUCENE random", expected, lucene.scoreMultipleRandom(), delta); + assertArrayEquals("NATIVE random", expected, nativeBench.scoreMultipleRandom(), delta); } finally { 
scalar.teardown(); lucene.teardown(); + nativeBench.teardown(); } } } @@ -78,13 +84,16 @@ public void testQueryRandom() throws Exception { var vectorData = new VectorScorerInt4BulkBenchmark.VectorData(dims, 1000, 200); var scalar = createBench(VectorImplementation.SCALAR, vectorData); var lucene = createBench(VectorImplementation.LUCENE, vectorData); + var nativeBench = createBench(VectorImplementation.NATIVE, vectorData); try { float[] expected = scalar.scoreQueryMultipleRandom(); assertArrayEquals("LUCENE queryRandom", expected, lucene.scoreQueryMultipleRandom(), delta); + assertArrayEquals("NATIVE queryRandom", expected, nativeBench.scoreQueryMultipleRandom(), delta); } finally { scalar.teardown(); lucene.teardown(); + nativeBench.teardown(); } } } @@ -94,13 +103,16 @@ public void testSequentialBulk() throws Exception { var vectorData = new VectorScorerInt4BulkBenchmark.VectorData(dims, 1000, 200); var scalar = createBench(VectorImplementation.SCALAR, vectorData); var lucene = createBench(VectorImplementation.LUCENE, vectorData); + var nativeBench = createBench(VectorImplementation.NATIVE, vectorData); try { float[] expected = scalar.scoreMultipleSequentialBulk(); assertArrayEquals("LUCENE sequentialBulk", expected, lucene.scoreMultipleSequentialBulk(), delta); + assertArrayEquals("NATIVE sequentialBulk", expected, nativeBench.scoreMultipleSequentialBulk(), delta); } finally { scalar.teardown(); lucene.teardown(); + nativeBench.teardown(); } } } @@ -110,13 +122,16 @@ public void testRandomBulk() throws Exception { var vectorData = new VectorScorerInt4BulkBenchmark.VectorData(dims, 1000, 200); var scalar = createBench(VectorImplementation.SCALAR, vectorData); var lucene = createBench(VectorImplementation.LUCENE, vectorData); + var nativeBench = createBench(VectorImplementation.NATIVE, vectorData); try { float[] expected = scalar.scoreMultipleRandomBulk(); assertArrayEquals("LUCENE randomBulk", expected, lucene.scoreMultipleRandomBulk(), delta); + 
assertArrayEquals("NATIVE randomBulk", expected, nativeBench.scoreMultipleRandomBulk(), delta); } finally { scalar.teardown(); lucene.teardown(); + nativeBench.teardown(); } } } @@ -126,13 +141,16 @@ public void testQueryRandomBulk() throws Exception { var vectorData = new VectorScorerInt4BulkBenchmark.VectorData(dims, 1000, 200); var scalar = createBench(VectorImplementation.SCALAR, vectorData); var lucene = createBench(VectorImplementation.LUCENE, vectorData); + var nativeBench = createBench(VectorImplementation.NATIVE, vectorData); try { float[] expected = scalar.scoreQueryMultipleRandomBulk(); assertArrayEquals("LUCENE queryRandomBulk", expected, lucene.scoreQueryMultipleRandomBulk(), delta); + assertArrayEquals("NATIVE queryRandomBulk", expected, nativeBench.scoreQueryMultipleRandomBulk(), delta); } finally { scalar.teardown(); lucene.teardown(); + nativeBench.teardown(); } } } diff --git a/libs/native/build.gradle b/libs/native/build.gradle index 50263017c2565..18a277b2f344b 100644 --- a/libs/native/build.gradle +++ b/libs/native/build.gradle @@ -12,6 +12,7 @@ import org.elasticsearch.gradle.internal.precommit.CheckForbiddenApisTask apply plugin: 'elasticsearch.publish' apply plugin: 'elasticsearch.build' apply plugin: 'elasticsearch.mrjar' +apply plugin: 'java-test-fixtures' dependencies { api project(':libs:core') diff --git a/libs/native/libraries/build.gradle b/libs/native/libraries/build.gradle index 3ce074dc4d807..7d5c4b3e6b62b 100644 --- a/libs/native/libraries/build.gradle +++ b/libs/native/libraries/build.gradle @@ -19,7 +19,7 @@ configurations { } var zstdVersion = "1.5.7" -var vecVersion = "1.0.52" +var vecVersion = "1.0.53" repositories { exclusiveContent { diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java index 8836b3c69cef4..41412de6a4c41 100644 --- 
a/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java @@ -39,24 +39,28 @@ enum DataType { /** * Unsigned int7. Single vector score returns results as an int. */ - INT7U(Byte.BYTES), + INT7U(Byte.BYTES * 8), + /** + * 4-bit packed nibble. Two values per byte; single vector score returns results as an int. + */ + INT4(4), /** * 1-byte int. Single vector score returns results as an int. */ - INT8(Byte.BYTES), + INT8(Byte.BYTES * 8), /** * 4-byte float. Single vector score returns results as a float. */ - FLOAT32(Float.BYTES); + FLOAT32(Float.BYTES * 8); - private final int bytes; + private final int bits; - DataType(int bytes) { - this.bytes = bytes; + DataType(int bits) { + this.bits = bits; } - public int bytes() { - return bytes; + public int bits() { + return bits; } } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java index 3621c93bfa638..9cd148a923bc2 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java @@ -123,16 +123,20 @@ private static MethodHandle bindFunction(String functionName, int capability, Fu // Only byte vectors have cosine // as floats are normalized to unit length to use dot_product instead if (f == Function.COSINE && type != DataType.INT8) continue; + // Only DOT_PRODUCT is needed for int4 — other functions are computed by + // applying correction terms on top of the raw dot product result. 
+ if (f != Function.DOT_PRODUCT && type == DataType.INT4) continue; String typeName = switch (type) { case INT7U -> "i7u"; + case INT4 -> "i4"; case INT8 -> "i8"; case FLOAT32 -> "f32"; }; FunctionDescriptor descriptor = switch (op) { case SINGLE -> switch (type) { - case INT7U -> intSingle; + case INT7U, INT4 -> intSingle; case INT8, FLOAT32 -> floatSingle; }; case BULK -> bulk; @@ -262,8 +266,8 @@ private static Error invocationError(Throwable t, MemorySegment segment1, Memory return new AssertionError(msg, t); } - static boolean checkBulk(int elementSize, MemorySegment a, MemorySegment b, int length, int count, MemorySegment result) { - Objects.checkFromIndexSize(0L, (long) length * count * elementSize, a.byteSize()); + static boolean checkBulk(int elementBits, MemorySegment a, MemorySegment b, int length, int count, MemorySegment result) { + Objects.checkFromIndexSize(0L, (long) length * count * elementBits / 8, a.byteSize()); Objects.checkFromIndexSize(0L, length, b.byteSize()); Objects.checkFromIndexSize(0L, (long) count * Float.BYTES, result.byteSize()); return true; @@ -286,7 +290,7 @@ static boolean checkBBQBulk( } static boolean checkBulkOffsets( - int elementSize, + int elementBits, MemorySegment a, MemorySegment b, int length, @@ -295,9 +299,10 @@ static boolean checkBulkOffsets( int count, MemorySegment result ) { - if (pitch < length * elementSize) throw new IllegalArgumentException("Pitch needs to be at least " + length); + long rowBytes = (long) length * elementBits / 8; + if (pitch < rowBytes) throw new IllegalArgumentException("Pitch needs to be at least " + rowBytes); Objects.checkFromIndexSize(0L, (long) pitch * count, a.byteSize()); - Objects.checkFromIndexSize(0L, (long) length * elementSize, b.byteSize()); + Objects.checkFromIndexSize(0L, rowBytes, b.byteSize()); Objects.checkFromIndexSize(0L, (long) count * Integer.BYTES, offsets.byteSize()); Objects.checkFromIndexSize(0L, (long) count * Float.BYTES, result.byteSize()); return true; @@ -345,6
+350,16 @@ static int squareDistanceI7u(MemorySegment a, MemorySegment b, int length) { return callSingleDistanceInt(squareI7uHandle, a, b, length); } + private static final MethodHandle dotI4Handle = HANDLES.get( + new OperationSignature<>(Function.DOT_PRODUCT, DataType.INT4, Operation.SINGLE) + ); + + static int dotProductI4(MemorySegment a, MemorySegment b, int elementCount) { + Objects.checkFromIndexSize(0L, 2L * elementCount, a.byteSize()); + Objects.checkFromIndexSize(0L, elementCount, b.byteSize()); + return callSingleDistanceInt(dotI4Handle, a, b, elementCount); + } + private static final MethodHandle cosI8Handle = HANDLES.get( new OperationSignature<>(Function.COSINE, DataType.INT8, Operation.SINGLE) ); @@ -561,6 +576,10 @@ private static float applyCorrectionsDotProductBulk( type = MethodType.methodType(int.class, MemorySegment.class, MemorySegment.class, int.class); checkMethod += "I7u"; break; + case INT4: + type = MethodType.methodType(int.class, MemorySegment.class, MemorySegment.class, int.class); + checkMethod += "I4"; + break; case INT8: type = MethodType.methodType(float.class, MemorySegment.class, MemorySegment.class, int.class); checkMethod += "I8"; @@ -623,7 +642,7 @@ private static float applyCorrectionsDotProductBulk( ) ); yield MethodHandles.guardWithTest( - MethodHandles.insertArguments(checkMethod, 0, dt.bytes()), + MethodHandles.insertArguments(checkMethod, 0, dt.bits()), op.getValue(), MethodHandles.empty(op.getValue().type()) ); @@ -674,7 +693,7 @@ private static float applyCorrectionsDotProductBulk( ) ); yield MethodHandles.guardWithTest( - MethodHandles.insertArguments(checkMethod, 0, dt.bytes()), + MethodHandles.insertArguments(checkMethod, 0, dt.bits()), op.getValue(), MethodHandles.empty(op.getValue().type()) ); diff --git a/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt4Tests.java b/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt4Tests.java new file mode 100644 
index 0000000000000..525c82db46ff0 --- /dev/null +++ b/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt4Tests.java @@ -0,0 +1,334 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.nativeaccess.jdk; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.nativeaccess.VectorSimilarityFunctions; +import org.elasticsearch.nativeaccess.VectorSimilarityFunctionsTests; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.util.List; + +import static java.lang.foreign.ValueLayout.JAVA_FLOAT_UNALIGNED; +import static org.elasticsearch.nativeaccess.Int4TestUtils.dotProductI4SinglePacked; +import static org.elasticsearch.nativeaccess.Int4TestUtils.packNibbles; +import static org.hamcrest.Matchers.containsString; + +/** + * Low-level tests for native Int4 (packed-nibble) dot product functions. + * + *

Int4 vectors are asymmetric: the "unpacked" query has {@code 2 * packedLen} bytes + * (one value per byte, range 0-15), while the "packed" document has {@code packedLen} bytes + * (two nibbles per byte). The third argument to native functions is {@code packedLen}, + * not the logical dimension count. + */ +public class JDKVectorLibraryInt4Tests extends VectorSimilarityFunctionsTests { + + static final byte MIN_INT4_VALUE = 0; + static final byte MAX_INT4_VALUE = 0x0F; + + public JDKVectorLibraryInt4Tests(VectorSimilarityFunctions.Function function, int size) { + super(function, size); + } + + @ParametersFactory + public static Iterable parametersFactory() { + List baseParams = CollectionUtils.iterableAsArrayList(VectorSimilarityFunctionsTests.parametersFactory()); + // Int4 only supports dot product + baseParams.removeIf(os -> os[0] != VectorSimilarityFunctions.Function.DOT_PRODUCT); + // Int4 requires even dimensions (two nibbles per packed byte) + baseParams.removeIf(os -> (Integer) os[1] % 2 != 0); + return baseParams; + } + + @BeforeClass + public static void beforeClass() { + VectorSimilarityFunctionsTests.setup(); + } + + @AfterClass + public static void afterClass() { + VectorSimilarityFunctionsTests.cleanup(); + } + + public void testInt4BinaryVectors() { + assumeTrue(notSupportedMsg(), supported()); + final int dims = size; + final int packedLen = dims / 2; + final int numVecs = randomIntBetween(2, 101); + + var unpackedValues = new byte[numVecs][dims]; + var packedValues = new byte[numVecs][packedLen]; + var unpackedSegment = arena.allocate((long) dims * numVecs); + var packedSegment = arena.allocate((long) packedLen * numVecs); + + for (int i = 0; i < numVecs; i++) { + randomBytesBetween(unpackedValues[i], MIN_INT4_VALUE, MAX_INT4_VALUE); + packedValues[i] = packNibbles(unpackedValues[i]); + MemorySegment.copy(MemorySegment.ofArray(unpackedValues[i]), 0L, unpackedSegment, (long) i * dims, dims); + 
MemorySegment.copy(MemorySegment.ofArray(packedValues[i]), 0L, packedSegment, (long) i * packedLen, packedLen); + } + + final int loopTimes = 1000; + for (int i = 0; i < loopTimes; i++) { + int first = randomInt(numVecs - 1); + int second = randomInt(numVecs - 1); + var nativeUnpacked = unpackedSegment.asSlice((long) first * dims, dims); + var nativePacked = packedSegment.asSlice((long) second * packedLen, packedLen); + + int expected = dotProductI4SinglePacked(unpackedValues[first], packedValues[second]); + assertEquals(expected, similarity(nativeUnpacked, nativePacked, packedLen)); + + if (supportsHeapSegments()) { + var heapUnpacked = MemorySegment.ofArray(unpackedValues[first]); + var heapPacked = MemorySegment.ofArray(packedValues[second]); + assertEquals(expected, similarity(heapUnpacked, heapPacked, packedLen)); + assertEquals(expected, similarity(nativeUnpacked, heapPacked, packedLen)); + assertEquals(expected, similarity(heapUnpacked, nativePacked, packedLen)); + + // trivial bulk with a single vector + float[] bulkScore = new float[1]; + similarityBulk(nativePacked, nativeUnpacked, packedLen, 1, MemorySegment.ofArray(bulkScore)); + assertEquals(expected, bulkScore[0], 0f); + } + } + } + + public void testInt4Bulk() { + assumeTrue(notSupportedMsg(), supported()); + final int dims = size; + final int packedLen = dims / 2; + final int numVecs = randomIntBetween(2, 101); + + var unpackedValues = new byte[numVecs][dims]; + var packedValues = new byte[numVecs][packedLen]; + var packedSegment = arena.allocate((long) packedLen * numVecs); + + for (int i = 0; i < numVecs; i++) { + randomBytesBetween(unpackedValues[i], MIN_INT4_VALUE, MAX_INT4_VALUE); + packedValues[i] = packNibbles(unpackedValues[i]); + MemorySegment.copy(MemorySegment.ofArray(packedValues[i]), 0L, packedSegment, (long) i * packedLen, packedLen); + } + + int queryOrd = randomInt(numVecs - 1); + float[] expectedScores = new float[numVecs]; + scalarSimilarityBulk(unpackedValues[queryOrd], 
packedValues, expectedScores); + + var nativeQuerySeg = MemorySegment.ofArray(unpackedValues[queryOrd]); + var bulkScoresSeg = arena.allocate((long) numVecs * Float.BYTES); + similarityBulk(packedSegment, nativeQuerySeg, packedLen, numVecs, bulkScoresSeg); + assertScoresEquals(expectedScores, bulkScoresSeg); + + if (supportsHeapSegments()) { + float[] bulkScores = new float[numVecs]; + similarityBulk(packedSegment, nativeQuerySeg, packedLen, numVecs, MemorySegment.ofArray(bulkScores)); + assertArrayEquals(expectedScores, bulkScores, 0f); + } + } + + public void testInt4BulkWithOffsets() { + assumeTrue(notSupportedMsg(), supported()); + final int dims = size; + final int packedLen = dims / 2; + final int numVecs = randomIntBetween(2, 101); + + var offsets = new int[numVecs]; + var unpackedValues = new byte[numVecs][dims]; + var packedValues = new byte[numVecs][packedLen]; + var packedSegment = arena.allocate((long) packedLen * numVecs); + var offsetsSegment = arena.allocate((long) numVecs * Integer.BYTES); + + for (int i = 0; i < numVecs; i++) { + offsets[i] = randomInt(numVecs - 1); + offsetsSegment.setAtIndex(ValueLayout.JAVA_INT, i, offsets[i]); + randomBytesBetween(unpackedValues[i], MIN_INT4_VALUE, MAX_INT4_VALUE); + packedValues[i] = packNibbles(unpackedValues[i]); + MemorySegment.copy(packedValues[i], 0, packedSegment, ValueLayout.JAVA_BYTE, (long) i * packedLen, packedLen); + } + + int queryOrd = randomInt(numVecs - 1); + float[] expectedScores = new float[numVecs]; + scalarSimilarityBulkWithOffsets(unpackedValues[queryOrd], packedValues, offsets, expectedScores); + + var nativeQuerySeg = MemorySegment.ofArray(unpackedValues[queryOrd]); + var bulkScoresSeg = arena.allocate((long) numVecs * Float.BYTES); + + similarityBulkWithOffsets(packedSegment, nativeQuerySeg, packedLen, packedLen, offsetsSegment, numVecs, bulkScoresSeg); + assertScoresEquals(expectedScores, bulkScoresSeg); + } + + public void testInt4BulkWithOffsetsAndPitch() { + 
assumeTrue(notSupportedMsg(), supported()); + final int dims = size; + final int packedLen = dims / 2; + final int numVecs = randomIntBetween(2, 101); + + var offsets = new int[numVecs]; + var unpackedValues = new byte[numVecs][dims]; + var packedValues = new byte[numVecs][packedLen]; + + int pitch = packedLen + Float.BYTES; + var packedSegment = arena.allocate((long) numVecs * pitch); + var offsetsSegment = arena.allocate((long) numVecs * Integer.BYTES); + + for (int i = 0; i < numVecs; i++) { + offsets[i] = randomInt(numVecs - 1); + offsetsSegment.setAtIndex(ValueLayout.JAVA_INT, i, offsets[i]); + randomBytesBetween(unpackedValues[i], MIN_INT4_VALUE, MAX_INT4_VALUE); + packedValues[i] = packNibbles(unpackedValues[i]); + MemorySegment.copy(packedValues[i], 0, packedSegment, ValueLayout.JAVA_BYTE, (long) i * pitch, packedLen); + } + + int queryOrd = randomInt(numVecs - 1); + float[] expectedScores = new float[numVecs]; + scalarSimilarityBulkWithOffsets(unpackedValues[queryOrd], packedValues, offsets, expectedScores); + + var nativeQuerySeg = MemorySegment.ofArray(unpackedValues[queryOrd]); + var bulkScoresSeg = arena.allocate((long) numVecs * Float.BYTES); + + similarityBulkWithOffsets(packedSegment, nativeQuerySeg, packedLen, pitch, offsetsSegment, numVecs, bulkScoresSeg); + assertScoresEquals(expectedScores, bulkScoresSeg); + } + + public void testInt4BulkWithOffsetsHeapSegments() { + assumeTrue(notSupportedMsg(), supported()); + assumeTrue("Requires support for heap MemorySegments", supportsHeapSegments()); + final int dims = size; + final int packedLen = dims / 2; + final int numVecs = randomIntBetween(2, 101); + + var offsets = new int[numVecs]; + var unpackedValues = new byte[numVecs][dims]; + var packedValues = new byte[numVecs][packedLen]; + var packedSegment = arena.allocate((long) packedLen * numVecs); + + for (int i = 0; i < numVecs; i++) { + offsets[i] = randomInt(numVecs - 1); + randomBytesBetween(unpackedValues[i], MIN_INT4_VALUE, MAX_INT4_VALUE); + 
packedValues[i] = packNibbles(unpackedValues[i]); + MemorySegment.copy(MemorySegment.ofArray(packedValues[i]), 0L, packedSegment, (long) i * packedLen, packedLen); + } + + int queryOrd = randomInt(numVecs - 1); + float[] expectedScores = new float[numVecs]; + scalarSimilarityBulkWithOffsets(unpackedValues[queryOrd], packedValues, offsets, expectedScores); + + float[] bulkScores = new float[numVecs]; + similarityBulkWithOffsets( + packedSegment, + MemorySegment.ofArray(unpackedValues[queryOrd]), + packedLen, + packedLen, + MemorySegment.ofArray(offsets), + numVecs, + MemorySegment.ofArray(bulkScores) + ); + assertArrayEquals(expectedScores, bulkScores, 0f); + } + + public void testIllegalDims() { + assumeTrue(notSupportedMsg(), supported()); + int packedLen = size / 2; + var unpacked = arena.allocate((long) size); + var packed = arena.allocate((long) packedLen + 1); + + var ex = expectThrows(IOOBE, () -> similarity(unpacked, packed.asSlice(0L, packedLen), packedLen + 1)); + assertThat(ex.getMessage(), containsString("out of bounds for length")); + + ex = expectThrows(IOOBE, () -> similarity(unpacked, packed.asSlice(0L, packedLen), -1)); + assertThat(ex.getMessage(), containsString("out of bounds for length")); + } + + public void testBulkIllegalDims() { + assumeTrue(notSupportedMsg(), supported()); + int packedLen = size / 2; + var segA = arena.allocate((long) packedLen - 1); + var segB = arena.allocate(size); + var segS = arena.allocate((long) size * Float.BYTES); + + Exception ex = expectThrows(IOOBE, () -> similarityBulk(segA, segB, packedLen, 4, segS)); + assertThat(ex.getMessage(), containsString("out of bounds for length")); + + ex = expectThrows(IOOBE, () -> similarityBulk(segA, segB, packedLen, -1, segS)); + assertThat(ex.getMessage(), containsString("out of bounds for length")); + + ex = expectThrows(IOOBE, () -> similarityBulk(segA, segB, -1, 3, segS)); + assertThat(ex.getMessage(), containsString("out of bounds for length")); + + var tooSmall = 
arena.allocate((long) 3 * Float.BYTES - 1); + ex = expectThrows(IOOBE, () -> similarityBulk(segA, segB, packedLen, 3, tooSmall)); + assertThat(ex.getMessage(), containsString("out of bounds for length")); + } + + int similarity(MemorySegment unpacked, MemorySegment packed, int packedLen) { + try { + return (int) getVectorDistance().getHandle( + function, + VectorSimilarityFunctions.DataType.INT4, + VectorSimilarityFunctions.Operation.SINGLE + ).invokeExact(unpacked, packed, packedLen); + } catch (Throwable t) { + throw rethrow(t); + } + } + + void similarityBulk(MemorySegment packedDocs, MemorySegment unpackedQuery, int packedLen, int count, MemorySegment result) { + try { + getVectorDistance().getHandle(function, VectorSimilarityFunctions.DataType.INT4, VectorSimilarityFunctions.Operation.BULK) + .invokeExact(packedDocs, unpackedQuery, packedLen, count, result); + } catch (Throwable t) { + throw rethrow(t); + } + } + + void similarityBulkWithOffsets( + MemorySegment packedDocs, + MemorySegment unpackedQuery, + int packedLen, + int pitch, + MemorySegment offsets, + int count, + MemorySegment result + ) { + try { + getVectorDistance().getHandle( + function, + VectorSimilarityFunctions.DataType.INT4, + VectorSimilarityFunctions.Operation.BULK_OFFSETS + ).invokeExact(packedDocs, unpackedQuery, packedLen, pitch, offsets, count, result); + } catch (Throwable t) { + throw rethrow(t); + } + } + + static void scalarSimilarityBulk(byte[] unpackedQuery, byte[][] packedData, float[] scores) { + for (int i = 0; i < packedData.length; i++) { + scores[i] = dotProductI4SinglePacked(unpackedQuery, packedData[i]); + } + } + + static void scalarSimilarityBulkWithOffsets(byte[] unpackedQuery, byte[][] packedData, int[] offsets, float[] scores) { + for (int i = 0; i < packedData.length; i++) { + scores[i] = dotProductI4SinglePacked(unpackedQuery, packedData[offsets[i]]); + } + } + + static void assertScoresEquals(float[] expectedScores, MemorySegment expectedScoresSeg) { + assert 
expectedScores.length == (expectedScoresSeg.byteSize() / Float.BYTES); + for (int i = 0; i < expectedScores.length; i++) { + assertEquals(expectedScores[i], expectedScoresSeg.get(JAVA_FLOAT_UNALIGNED, (long) i * Float.BYTES), 0f); + } + } +} diff --git a/libs/native/src/testFixtures/java/org/elasticsearch/nativeaccess/Int4TestUtils.java b/libs/native/src/testFixtures/java/org/elasticsearch/nativeaccess/Int4TestUtils.java new file mode 100644 index 0000000000000..d7bf59d38caeb --- /dev/null +++ b/libs/native/src/testFixtures/java/org/elasticsearch/nativeaccess/Int4TestUtils.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.nativeaccess; + +/** + * Shared test utilities for Int4 packed-nibble vector operations. + * + *

Int4 vectors use two representations: + *

+ *

+ * The unpacked input comes from {@code OptimizedScalarQuantizer#scalarQuantize}, which quantizes a float + * vector into one byte per element in natural order: unpacked = [v0, v1, v2, ..., v_{N-1}] where N = dims. + *

+ * The packed format pairs elements that are packedLength ({@code unpacked} length / 2) apart. For example, + * with dims=8, unpacked.length is 8, and packedLength is 4: + * - {@code packed[0] = (v0 << 4) | v4} + * - {@code packed[1] = (v1 << 4) | v5} + * - {@code packed[2] = (v2 << 4) | v6} + * - {@code packed[3] = (v3 << 4) | v7} + *

+ * Or, visually, + * UNPACKED (8 bytes, natural vector order, one 4-bit value per byte): + * index: 0 1 2 3 4 5 6 7 + * [v0] [v1] [v2] [v3] [v4] [v5] [v6] [v7] + * PACKED (4 bytes, on disk, two 4-bit values per byte): + * index: 0 1 2 3 + * [v0 | v4] [v1 | v5] [v2 | v6] [v3 | v7] + * hi lo hi lo hi lo hi lo + * 7..4 3..0 7..4 3..0 7..4 3..0 7..4 3..0 + */ +public final class Int4TestUtils { + + private Int4TestUtils() {} + + /** + * Packs unpacked int4 values (one value per byte) into Lucene nibble-packed format (two values per byte) + * written by {@code Lucene104ScalarQuantizedVectorsWriter} (ScalarEncoding#PACKED_NIBBLE format). + * The input layout is {@code [high0, high1, ..., low0, low1, ...]} with length {@code 2 * packedLen}. + */ + public static byte[] packNibbles(byte[] unpacked) { + int packedLength = unpacked.length / 2; + byte[] packed = new byte[packedLength]; + for (int i = 0; i < packedLength; i++) { + packed[i] = (byte) ((unpacked[i] << 4) | (unpacked[i + packedLength] & 0x0F)); + } + return packed; + } + + /** + * Unpacks "nibble-packed" int4 values (two values per byte) into the unpacked int4 format (byte[], one value per byte). + * @param packed the packed bytes (each holding two 4-bit values) + * @param dims the total number of 4-bit elements ({@code 2 * packed.length}) + */ + public static byte[] unpackNibbles(byte[] packed, int dims) { + byte[] unpacked = new byte[dims]; + int packedLen = packed.length; + for (int i = 0; i < packedLen; i++) { + unpacked[i] = (byte) ((packed[i] & 0xFF) >>> 4); + unpacked[i + packedLen] = (byte) (packed[i] & 0x0F); + } + return unpacked; + } + + /** + * Computes the dot product between an unpacked query vector and a packed document vector, + * matching the native {@code doti4_inner} implementation. 
+ */ + public static int dotProductI4SinglePacked(byte[] unpacked, byte[] packed) { + int total = 0; + for (int i = 0; i < packed.length; i++) { + byte packedByte = packed[i]; + total += ((packedByte & 0xFF) >> 4) * unpacked[i]; + total += (packedByte & 0x0F) * unpacked[i + packed.length]; + } + return total; + } +} diff --git a/libs/simdvec/build.gradle b/libs/simdvec/build.gradle index d9dcb25a93fbb..62dfc3672a675 100644 --- a/libs/simdvec/build.gradle +++ b/libs/simdvec/build.gradle @@ -20,6 +20,7 @@ dependencies { implementation "org.apache.lucene:lucene-core:${versions.lucene}" testImplementation(testArtifact(project(':x-pack:plugin:searchable-snapshots'))) + testImplementation(testFixtures(project(':libs:native'))) testImplementation(project(":test:framework")) { exclude group: 'org.elasticsearch', module: 'native' } diff --git a/libs/simdvec/native/publish_vec_binaries.sh b/libs/simdvec/native/publish_vec_binaries.sh index d68da9afb268f..c0d047156abf9 100755 --- a/libs/simdvec/native/publish_vec_binaries.sh +++ b/libs/simdvec/native/publish_vec_binaries.sh @@ -20,7 +20,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then exit 1; fi -VERSION="1.0.52" +VERSION="1.0.53" ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}" TEMP=$(mktemp -d) diff --git a/libs/simdvec/native/src/vec/c/aarch64/vec_i4_1.cpp b/libs/simdvec/native/src/vec/c/aarch64/vec_i4_1.cpp new file mode 100644 index 0000000000000..6522a36639901 --- /dev/null +++ b/libs/simdvec/native/src/vec/c/aarch64/vec_i4_1.cpp @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +// Scalar implementations for int4 packed-nibble vector operations. +// The "unpacked" vector has 2*packed_len bytes (high nibbles in [0..packed_len), +// low nibbles in [packed_len..2*packed_len)). The "packed" vector has packed_len +// bytes, each holding two 4-bit values. + +#include +#include "vec.h" +#include "vec_common.h" + +static inline int32_t doti4_inner(const int8_t* unpacked, const uint8_t* packed, int32_t packed_len) { + int32_t total = 0; + for (int32_t i = 0; i < packed_len; i++) { + uint8_t p = packed[i]; + total += (p >> 4) * unpacked[i]; + total += (p & 0x0F) * unpacked[i + packed_len]; + } + return total; +} + +EXPORT int32_t vec_doti4(const int8_t* unpacked, const uint8_t* packed, int32_t packed_len) { + return doti4_inner(unpacked, packed, packed_len); +} + +EXPORT void vec_doti4_bulk(const uint8_t* a, const int8_t* b, int32_t packed_len, int32_t count, f32_t* results) { + for (int c = 0; c < count; c++) { + results[c] = (f32_t)doti4_inner(b, a + (int64_t)c * packed_len, packed_len); + } +} + +EXPORT void vec_doti4_bulk_offsets( + const uint8_t* a, + const int8_t* b, + int32_t packed_len, + int32_t pitch, + const int32_t* offsets, + int32_t count, + f32_t* results +) { + for (int c = 0; c < count; c++) { + const uint8_t* doc = a + (int64_t)offsets[c] * pitch; + results[c] = (f32_t)doti4_inner(b, doc, packed_len); + } +} diff --git a/libs/simdvec/native/src/vec/c/amd64/vec_i4_1.cpp b/libs/simdvec/native/src/vec/c/amd64/vec_i4_1.cpp new file mode 100644 index 0000000000000..6522a36639901 --- /dev/null +++ b/libs/simdvec/native/src/vec/c/amd64/vec_i4_1.cpp @@ -0,0 +1,52 @@ +/* + * Copyright 
Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +// Scalar implementations for int4 packed-nibble vector operations. +// The "unpacked" vector has 2*packed_len bytes (high nibbles in [0..packed_len), +// low nibbles in [packed_len..2*packed_len)). The "packed" vector has packed_len +// bytes, each holding two 4-bit values. + +#include +#include "vec.h" +#include "vec_common.h" + +static inline int32_t doti4_inner(const int8_t* unpacked, const uint8_t* packed, int32_t packed_len) { + int32_t total = 0; + for (int32_t i = 0; i < packed_len; i++) { + uint8_t p = packed[i]; + total += (p >> 4) * unpacked[i]; + total += (p & 0x0F) * unpacked[i + packed_len]; + } + return total; +} + +EXPORT int32_t vec_doti4(const int8_t* unpacked, const uint8_t* packed, int32_t packed_len) { + return doti4_inner(unpacked, packed, packed_len); +} + +EXPORT void vec_doti4_bulk(const uint8_t* a, const int8_t* b, int32_t packed_len, int32_t count, f32_t* results) { + for (int c = 0; c < count; c++) { + results[c] = (f32_t)doti4_inner(b, a + (int64_t)c * packed_len, packed_len); + } +} + +EXPORT void vec_doti4_bulk_offsets( + const uint8_t* a, + const int8_t* b, + int32_t packed_len, + int32_t pitch, + const int32_t* offsets, + int32_t count, + f32_t* results +) { + for (int c = 0; c < count; c++) { + const uint8_t* doc = a + (int64_t)offsets[c] * pitch; + results[c] = (f32_t)doti4_inner(b, doc, packed_len); + } +} diff --git a/libs/simdvec/native/src/vec/headers/vec_common.h b/libs/simdvec/native/src/vec/headers/vec_common.h index 6e4a8fc9868ae..379430b2e970d 100644 --- 
a/libs/simdvec/native/src/vec/headers/vec_common.h +++ b/libs/simdvec/native/src/vec/headers/vec_common.h @@ -5,6 +5,7 @@ #include #include #include +#include template static inline uintptr_t align_downwards(const void* ptr) { diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java index 1a848eb659084..38583c8c221ea 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java @@ -144,4 +144,42 @@ Optional getInt7uOSQVectorScorer( float additionalCorrection, int quantizedComponentSum ); + + /** + * Returns an optional containing an int4 packed-nibble vector score supplier + * for the given parameters, or an empty optional if a scorer is not supported. + * + * @param similarityType the similarity type + * @param input the index input containing the vector data + * @param values the random access vector values + * @return an optional containing the vector scorer supplier, or empty + */ + Optional getInt4VectorScorerSupplier( + VectorSimilarityType similarityType, + IndexInput input, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values + ); + + /** + * Returns an optional containing an int4 packed-nibble query-time vector scorer + * for the given parameters, or an empty optional if a scorer is not supported. 
+ * + * @param sim the similarity function + * @param values the quantized vector values + * @param unpackedQuery the quantized query bytes (one byte per dimension, 0-15) + * @param lowerInterval query corrective term + * @param upperInterval query corrective term + * @param additionalCorrection query corrective term + * @param quantizedComponentSum query corrective term + * @return an optional containing the vector scorer, or empty + */ + Optional getInt4VectorScorer( + VectorSimilarityFunction sim, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values, + byte[] unpackedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ); } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java index 4a237c8ffbae5..11bbe607a8cf4 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java @@ -97,4 +97,26 @@ public Optional getInt7uOSQVectorScorer( ) { throw new UnsupportedOperationException("should not reach here"); } + + @Override + public Optional getInt4VectorScorerSupplier( + VectorSimilarityType similarityType, + IndexInput input, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values + ) { + throw new UnsupportedOperationException("should not reach here"); + } + + @Override + public Optional getInt4VectorScorer( + VectorSimilarityFunction sim, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values, + byte[] unpackedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + throw new UnsupportedOperationException("should not reach here"); + } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java 
b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java index 6450811ff82d1..bbbd32a782bb2 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java @@ -23,6 +23,8 @@ import org.elasticsearch.simdvec.internal.ByteVectorScorerSupplier; import org.elasticsearch.simdvec.internal.FloatVectorScorer; import org.elasticsearch.simdvec.internal.FloatVectorScorerSupplier; +import org.elasticsearch.simdvec.internal.Int4VectorScorer; +import org.elasticsearch.simdvec.internal.Int4VectorScorerSupplier; import org.elasticsearch.simdvec.internal.Int7SQVectorScorer; import org.elasticsearch.simdvec.internal.Int7SQVectorScorerSupplier; import org.elasticsearch.simdvec.internal.Int7uOSQVectorScorer; @@ -162,6 +164,39 @@ public Optional getInt7uOSQVectorScorer( ); } + @Override + public Optional getInt4VectorScorerSupplier( + VectorSimilarityType similarityType, + IndexInput input, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values + ) { + input = FilterIndexInput.unwrapOnlyTest(input); + input = MemorySegmentAccessInputAccess.unwrap(input); + checkInvariants(values.size(), values.dimension() / 2, input); + return Optional.of(new Int4VectorScorerSupplier(input, values, similarityType)); + } + + @Override + public Optional getInt4VectorScorer( + VectorSimilarityFunction sim, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values, + byte[] unpackedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + return Int4VectorScorer.create( + sim, + values, + unpackedQuery, + lowerInterval, + upperInterval, + additionalCorrection, + quantizedComponentSum + ); + } + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { if (input.length() < (long) vectorByteLength * maxOrd) { throw new 
IllegalArgumentException("input length is less than expected vector data"); diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int4Corrections.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int4Corrections.java new file mode 100644 index 0000000000000..be9ef0c1b9f87 --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int4Corrections.java @@ -0,0 +1,223 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal; + +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorScorer; +import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.simdvec.VectorSimilarityType; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; + +/** + * Shared correction formulas for int4 packed-nibble scoring. Used by both the + * scorer supplier (ordinal-vs-ordinal) and the query-time scorer paths. + * Correction formulas are the same as in {@link Lucene104ScalarQuantizedVectorScorer}, specialized for the INT4 case. 
+ */ +final class Int4Corrections { + + static final float LIMIT_SCALE = 1f / ((1 << 4) - 1); + + @FunctionalInterface + interface SingleCorrection { + float apply( + QuantizedByteVectorValues values, + int dims, + float rawScore, + int ord, + float qLower, + float qUpper, + float qAdditional, + int qComponentSum + ) throws IOException; + } + + @FunctionalInterface + interface BulkCorrection { + float apply( + QuantizedByteVectorValues values, + int dims, + MemorySegment scores, + MemorySegment ordinals, + int numNodes, + float qLower, + float qUpper, + float qAdditional, + int qComponentSum + ) throws IOException; + } + + static SingleCorrection singleCorrectionFor(VectorSimilarityType type) { + return switch (type) { + case COSINE, DOT_PRODUCT -> Int4Corrections::dotProduct; + case EUCLIDEAN -> Int4Corrections::euclidean; + case MAXIMUM_INNER_PRODUCT -> Int4Corrections::maxInnerProduct; + }; + } + + static BulkCorrection bulkCorrectionFor(VectorSimilarityType type) { + return switch (type) { + case COSINE, DOT_PRODUCT -> Int4Corrections::dotProductBulk; + case EUCLIDEAN -> Int4Corrections::euclideanBulk; + case MAXIMUM_INNER_PRODUCT -> Int4Corrections::maxInnerProductBulk; + }; + } + + private Int4Corrections() {} + + static float dotProduct( + QuantizedByteVectorValues values, + int dims, + float rawScore, + int ord, + float queryLower, + float queryUpper, + float queryAdditional, + int queryComponentSum + ) throws IOException { + var ct = values.getCorrectiveTerms(ord); + float ax = ct.lowerInterval(); + float lx = (ct.upperInterval() - ax) * LIMIT_SCALE; + float ay = queryLower; + float ly = (queryUpper - ay) * LIMIT_SCALE; + float score = ax * ay * dims + ay * lx * ct.quantizedComponentSum() + ax * ly * queryComponentSum + lx * ly * rawScore; + score += queryAdditional + ct.additionalCorrection() - values.getCentroidDP(); + return VectorUtil.normalizeToUnitInterval(Math.clamp(score, -1, 1)); + } + + static float dotProductBulk( + QuantizedByteVectorValues 
values, + int dims, + MemorySegment scoreSeg, + MemorySegment ordinalsSeg, + int numNodes, + float queryLower, + float queryUpper, + float queryAdditional, + int queryComponentSum + ) throws IOException { + float ay = queryLower; + float ly = (queryUpper - ay) * LIMIT_SCALE; + float max = Float.NEGATIVE_INFINITY; + for (int i = 0; i < numNodes; i++) { + float raw = scoreSeg.getAtIndex(ValueLayout.JAVA_FLOAT, i); + int nodeOrd = ordinalsSeg.getAtIndex(ValueLayout.JAVA_INT, i); + var ct = values.getCorrectiveTerms(nodeOrd); + float ax = ct.lowerInterval(); + float lx = (ct.upperInterval() - ax) * LIMIT_SCALE; + float score = ax * ay * dims + ay * lx * ct.quantizedComponentSum() + ax * ly * queryComponentSum + lx * ly * raw; + score += queryAdditional + ct.additionalCorrection() - values.getCentroidDP(); + float normalized = VectorUtil.normalizeToUnitInterval(Math.clamp(score, -1, 1)); + scoreSeg.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalized); + max = Math.max(max, normalized); + } + return max; + } + + static float euclidean( + QuantizedByteVectorValues values, + int dims, + float rawScore, + int ord, + float queryLower, + float queryUpper, + float queryAdditional, + int queryComponentSum + ) throws IOException { + var ct = values.getCorrectiveTerms(ord); + float ax = ct.lowerInterval(); + float lx = (ct.upperInterval() - ax) * LIMIT_SCALE; + float ay = queryLower; + float ly = (queryUpper - ay) * LIMIT_SCALE; + float score = ax * ay * dims + ay * lx * ct.quantizedComponentSum() + ax * ly * queryComponentSum + lx * ly * rawScore; + score = queryAdditional + ct.additionalCorrection() - 2 * score; + return VectorUtil.normalizeDistanceToUnitInterval(Math.max(score, 0f)); + } + + static float euclideanBulk( + QuantizedByteVectorValues values, + int dims, + MemorySegment scoreSeg, + MemorySegment ordinalsSeg, + int numNodes, + float queryLower, + float queryUpper, + float queryAdditional, + int queryComponentSum + ) throws IOException { + float ay = queryLower; + float 
ly = (queryUpper - ay) * LIMIT_SCALE; + float max = Float.NEGATIVE_INFINITY; + for (int i = 0; i < numNodes; i++) { + float raw = scoreSeg.getAtIndex(ValueLayout.JAVA_FLOAT, i); + int nodeOrd = ordinalsSeg.getAtIndex(ValueLayout.JAVA_INT, i); + var ct = values.getCorrectiveTerms(nodeOrd); + float ax = ct.lowerInterval(); + float lx = (ct.upperInterval() - ax) * LIMIT_SCALE; + float score = ax * ay * dims + ay * lx * ct.quantizedComponentSum() + ax * ly * queryComponentSum + lx * ly * raw; + score = queryAdditional + ct.additionalCorrection() - 2 * score; + float normalized = VectorUtil.normalizeDistanceToUnitInterval(Math.max(score, 0f)); + scoreSeg.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalized); + max = Math.max(max, normalized); + } + return max; + } + + static float maxInnerProduct( + QuantizedByteVectorValues values, + int dims, + float rawScore, + int ord, + float queryLower, + float queryUpper, + float queryAdditional, + int queryComponentSum + ) throws IOException { + var ct = values.getCorrectiveTerms(ord); + float ax = ct.lowerInterval(); + float lx = (ct.upperInterval() - ax) * LIMIT_SCALE; + float ay = queryLower; + float ly = (queryUpper - ay) * LIMIT_SCALE; + float score = ax * ay * dims + ay * lx * ct.quantizedComponentSum() + ax * ly * queryComponentSum + lx * ly * rawScore; + score += queryAdditional + ct.additionalCorrection() - values.getCentroidDP(); + return VectorUtil.scaleMaxInnerProductScore(score); + } + + static float maxInnerProductBulk( + QuantizedByteVectorValues values, + int dims, + MemorySegment scoreSeg, + MemorySegment ordinalsSeg, + int numNodes, + float queryLower, + float queryUpper, + float queryAdditional, + int queryComponentSum + ) throws IOException { + float ay = queryLower; + float ly = (queryUpper - ay) * LIMIT_SCALE; + float max = Float.NEGATIVE_INFINITY; + for (int i = 0; i < numNodes; i++) { + float raw = scoreSeg.getAtIndex(ValueLayout.JAVA_FLOAT, i); + int nodeOrd = ordinalsSeg.getAtIndex(ValueLayout.JAVA_INT, i); 
+ var ct = values.getCorrectiveTerms(nodeOrd); + float ax = ct.lowerInterval(); + float lx = (ct.upperInterval() - ax) * LIMIT_SCALE; + float score = ax * ay * dims + ay * lx * ct.quantizedComponentSum() + ax * ly * queryComponentSum + lx * ly * raw; + score += queryAdditional + ct.additionalCorrection() - values.getCentroidDP(); + float normalizedScore = VectorUtil.scaleMaxInnerProductScore(score); + scoreSeg.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalizedScore); + max = Math.max(max, normalizedScore); + } + return max; + } +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int4VectorScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int4VectorScorer.java new file mode 100644 index 0000000000000..fd7c89227dbaa --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int4VectorScorer.java @@ -0,0 +1,261 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.simdvec.internal; + +import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.elasticsearch.simdvec.MemorySegmentAccessInputAccess; +import org.elasticsearch.simdvec.VectorSimilarityType; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.util.Optional; + +import static org.elasticsearch.simdvec.internal.Similarities.dotProductI4; +import static org.elasticsearch.simdvec.internal.Similarities.dotProductI4BulkWithOffsets; + +/** + * Int4 packed-nibble query-time scorer. The float query is quantized externally + * and passed in as unpacked bytes (one byte per dimension, 0-15 range) along + * with corrective terms. Each stored vector is {@code dims/2} packed bytes + * followed by corrective terms (3 floats + 1 int). + */ +public final class Int4VectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { + + private static final boolean SUPPORTS_HEAP_SEGMENTS = Runtime.version().feature() >= 22; + + private final ScorerImpl scorerImpl; + private final QueryContext query; + + /** + * Creates an int4 query-time scorer if the input supports efficient access. 
+ * + * @param sim the similarity function + * @param values the quantized vector values + * @param unpackedQuery the quantized query (dims bytes, one per dimension, 0-15) + * @param lowerInterval query corrective term + * @param upperInterval query corrective term + * @param additionalCorrection query corrective term + * @param quantizedComponentSum query corrective term + * @return an optional scorer or empty if the input doesn't support native access + */ + public static Optional create( + VectorSimilarityFunction sim, + QuantizedByteVectorValues values, + byte[] unpackedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + IndexInput input = values.getSlice(); + if (input == null) { + return Optional.empty(); + } + input = FilterIndexInput.unwrapOnlyTest(input); + input = MemorySegmentAccessInputAccess.unwrap(input); + return Optional.of( + new Int4VectorScorer( + input, + values, + VectorSimilarityType.of(sim), + unpackedQuery, + lowerInterval, + upperInterval, + additionalCorrection, + quantizedComponentSum + ) + ); + } + + Int4VectorScorer( + IndexInput input, + QuantizedByteVectorValues values, + VectorSimilarityType similarityType, + byte[] unpackedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + super(values); + IndexInputUtils.checkInputType(input); + int dims = values.dimension(); + int packedDims = dims / 2; + long vectorPitch = packedDims + 3L * Float.BYTES + Integer.BYTES; + + this.scorerImpl = new ScorerImpl( + input, + values, + dims, + packedDims, + vectorPitch, + Int4Corrections.singleCorrectionFor(similarityType), + Int4Corrections.bulkCorrectionFor(similarityType) + ); + + final MemorySegment unpackedQuerySegment; + if (SUPPORTS_HEAP_SEGMENTS) { + unpackedQuerySegment = MemorySegment.ofArray(unpackedQuery); + } else { + unpackedQuerySegment = Arena.ofAuto().allocate(unpackedQuery.length, 32); + 
MemorySegment.copy(unpackedQuery, 0, unpackedQuerySegment, ValueLayout.JAVA_BYTE, 0, unpackedQuery.length); + } + + this.query = new QueryContext(lowerInterval, upperInterval, additionalCorrection, quantizedComponentSum, unpackedQuerySegment); + } + + @Override + public float score(int node) throws IOException { + return scorerImpl.scoreWithQuery(query, node); + } + + @Override + public float bulkScore(int[] ordinals, float[] scores, int numNodes) throws IOException { + return scorerImpl.bulkScoreWithQuery(query, ordinals, scores, numNodes); + } + + /** + * Shared scoring implementation used by both {@link Int4VectorScorer} (query-time) and + * {@link Int4VectorScorerSupplier} (graph-build / reranking). + * Not thread-safe under all conditions (due to mutable state (scratch) used by IndexInput): + * each supplier/scorer should own its own instance. + */ + static class ScorerImpl { + private final IndexInput input; + private final QuantizedByteVectorValues values; + private final int dims; + private final int packedDims; + private final long vectorPitch; + private final Int4Corrections.SingleCorrection correction; + private final Int4Corrections.BulkCorrection bulkCorrection; + private byte[] scratch; + + ScorerImpl( + IndexInput input, + QuantizedByteVectorValues values, + int dims, + int packedDims, + long vectorPitch, + Int4Corrections.SingleCorrection correction, + Int4Corrections.BulkCorrection bulkCorrection + ) { + this.input = input; + this.values = values; + this.dims = dims; + this.packedDims = packedDims; + this.vectorPitch = vectorPitch; + this.correction = correction; + this.bulkCorrection = bulkCorrection; + } + + void checkOrdinal(int ord) { + if (ord < 0 || ord >= values.size()) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + private byte[] getScratch(int len) { + if (scratch == null || scratch.length < len) { + scratch = new byte[len]; + } + return scratch; + } + + private float applyCorrections(float rawScore, int 
ord, QueryContext query) throws IOException { + return correction.apply( + values, + dims, + rawScore, + ord, + query.lowerInterval(), + query.upperInterval(), + query.additionalCorrection(), + query.quantizedComponentSum() + ); + } + + private float applyCorrectionsBulk(MemorySegment scores, MemorySegment ordinals, int numNodes, QueryContext query) + throws IOException { + return bulkCorrection.apply( + values, + dims, + scores, + ordinals, + numNodes, + query.lowerInterval(), + query.upperInterval(), + query.additionalCorrection(), + query.quantizedComponentSum() + ); + } + + float scoreWithQuery(QueryContext query, int node) throws IOException { + checkOrdinal(node); + long nodeOffset = (long) node * vectorPitch; + input.seek(nodeOffset); + return IndexInputUtils.withSlice(input, packedDims, this::getScratch, packedTarget -> { + int rawScore = dotProductI4(query.unpackedQuery(), packedTarget, packedDims); + return applyCorrections(rawScore, node, query); + }); + } + + float bulkScoreWithQuery(QueryContext query, int[] ordinals, float[] scores, int numNodes) throws IOException { + input.seek(0); + return IndexInputUtils.withSlice(input, input.length(), this::getScratch, vectors -> { + if (SUPPORTS_HEAP_SEGMENTS) { + var ordinalsSeg = MemorySegment.ofArray(ordinals); + var scoresSeg = MemorySegment.ofArray(scores); + dotProductI4BulkWithOffsets( + vectors, + query.unpackedQuery(), + packedDims, + (int) vectorPitch, + ordinalsSeg, + numNodes, + scoresSeg + ); + return applyCorrectionsBulk(scoresSeg, ordinalsSeg, numNodes, query); + } else { + try (Arena arena = Arena.ofConfined()) { + MemorySegment ordinalsSeg = arena.allocate((long) numNodes * Integer.BYTES, Integer.BYTES); + MemorySegment scoresSeg = arena.allocate((long) numNodes * Float.BYTES, Float.BYTES); + MemorySegment.copy(ordinals, 0, ordinalsSeg, ValueLayout.JAVA_INT, 0, numNodes); + dotProductI4BulkWithOffsets( + vectors, + query.unpackedQuery(), + packedDims, + (int) vectorPitch, + ordinalsSeg, + 
numNodes, + scoresSeg + ); + float max = applyCorrectionsBulk(scoresSeg, ordinalsSeg, numNodes, query); + MemorySegment.copy(scoresSeg, ValueLayout.JAVA_FLOAT, 0, scores, 0, numNodes); + return max; + } + } + }); + } + } + + record QueryContext( + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum, + MemorySegment unpackedQuery + ) {} +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int4VectorScorerSupplier.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int4VectorScorerSupplier.java new file mode 100644 index 0000000000000..1f293385db0a2 --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int4VectorScorerSupplier.java @@ -0,0 +1,125 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal; + +import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.elasticsearch.simdvec.VectorSimilarityType; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; + +/** + * Int4 packed-nibble scorer supplier. + * Each stored vector is {@code dims/2} packed bytes (two 4-bit values per byte), followed by + * corrective terms (3 floats + 1 int). The query is unpacked to {@code dims} bytes before scoring. 
+ */ +public final class Int4VectorScorerSupplier implements RandomVectorScorerSupplier { + + private final IndexInput input; + private final QuantizedByteVectorValues values; + private final VectorSimilarityType similarityType; + private final int packedDims; + private final long vectorPitch; + private final Int4VectorScorer.ScorerImpl scorerImpl; + private final MemorySegment unpackedQuerySegment; + + public Int4VectorScorerSupplier(IndexInput input, QuantizedByteVectorValues values, VectorSimilarityType similarityType) { + IndexInputUtils.checkInputType(input); + int dims = values.dimension(); + + this.input = input; + this.values = values; + this.similarityType = similarityType; + this.packedDims = dims / 2; + this.vectorPitch = packedDims + 3L * Float.BYTES + Integer.BYTES; + this.unpackedQuerySegment = Arena.ofAuto().allocate(dims, 32); + this.scorerImpl = new Int4VectorScorer.ScorerImpl( + input, + values, + dims, + packedDims, + vectorPitch, + Int4Corrections.singleCorrectionFor(similarityType), + Int4Corrections.bulkCorrectionFor(similarityType) + ); + } + + private Int4VectorScorer.QueryContext createQueryContext(int ord) throws IOException { + var correctiveTerms = values.getCorrectiveTerms(ord); + long offset = (long) ord * vectorPitch; + input.seek(offset); + byte[] packed = new byte[packedDims]; + input.readBytes(packed, 0, packedDims); + unpackNibbles(packed); + return new Int4VectorScorer.QueryContext( + correctiveTerms.lowerInterval(), + correctiveTerms.upperInterval(), + correctiveTerms.additionalCorrection(), + correctiveTerms.quantizedComponentSum(), + unpackedQuerySegment + ); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new Int4VectorScorerSupplier(input.clone(), values.copy(), similarityType); + } + + @Override + public UpdateableRandomVectorScorer scorer() { + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { + /** QueryContext instances used by this scorer 
are all backed by the same pre-allocated segment + * (see {@link Int4VectorScorerSupplier#createQueryContext}). + * The segment is reused across setScoringOrdinal calls; only the most recent one is valid. + * This makes this scorer and supplier not thread-safe. + */ + private Int4VectorScorer.QueryContext query; + + @Override + public float score(int node) throws IOException { + if (query == null) { + throw new IllegalStateException("scoring ordinal is not set"); + } + return scorerImpl.scoreWithQuery(query, node); + } + + @Override + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + if (query == null) { + throw new IllegalStateException("scoring ordinal is not set"); + } + return scorerImpl.bulkScoreWithQuery(query, nodes, scores, numNodes); + } + + @Override + public void setScoringOrdinal(int node) throws IOException { + scorerImpl.checkOrdinal(node); + query = createQueryContext(node); + } + }; + } + + public QuantizedByteVectorValues get() { + return values; + } + + private void unpackNibbles(byte[] packed) { + int packedLen = packed.length; + for (int i = 0; i < packedLen; i++) { + unpackedQuerySegment.setAtIndex(ValueLayout.JAVA_BYTE, i, (byte) ((packed[i] & 0xFF) >>> 4)); + unpackedQuerySegment.setAtIndex(ValueLayout.JAVA_BYTE, i + packedLen, (byte) (packed[i] & 0x0F)); + } + } +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java index 84bb7278facb6..4fb4bd5a04497 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java @@ -40,6 +40,13 @@ public class Similarities { Operation.BULK_OFFSETS ); + static final MethodHandle DOT_PRODUCT_I4 = DISTANCE_FUNCS.getHandle(Function.DOT_PRODUCT, DataType.INT4, Operation.SINGLE); + static final MethodHandle DOT_PRODUCT_I4_BULK = 
DISTANCE_FUNCS.getHandle(Function.DOT_PRODUCT, DataType.INT4, Operation.BULK); + static final MethodHandle DOT_PRODUCT_I4_BULK_WITH_OFFSETS = DISTANCE_FUNCS.getHandle( + Function.DOT_PRODUCT, + DataType.INT4, + Operation.BULK_OFFSETS + ); static final MethodHandle COSINE_I8 = DISTANCE_FUNCS.getHandle(Function.COSINE, DataType.INT8, Operation.SINGLE); static final MethodHandle COSINE_I8_BULK = DISTANCE_FUNCS.getHandle(Function.COSINE, DataType.INT8, Operation.BULK); static final MethodHandle COSINE_I8_BULK_WITH_OFFSETS = DISTANCE_FUNCS.getHandle( @@ -176,6 +183,38 @@ static void squareDistanceI7uBulkWithOffsets( } } + static int dotProductI4(MemorySegment unpacked, MemorySegment packed, int packedLen) { + try { + return (int) DOT_PRODUCT_I4.invokeExact(unpacked, packed, packedLen); + } catch (Throwable e) { + throw rethrow(e); + } + } + + static void dotProductI4Bulk(MemorySegment a, MemorySegment b, int packedLen, int count, MemorySegment scores) { + try { + DOT_PRODUCT_I4_BULK.invokeExact(a, b, packedLen, count, scores); + } catch (Throwable e) { + throw rethrow(e); + } + } + + static void dotProductI4BulkWithOffsets( + MemorySegment a, + MemorySegment b, + int packedLen, + int pitch, + MemorySegment offsets, + int count, + MemorySegment scores + ) { + try { + DOT_PRODUCT_I4_BULK_WITH_OFFSETS.invokeExact(a, b, packedLen, pitch, offsets, count, scores); + } catch (Throwable e) { + throw rethrow(e); + } + } + public static float cosineI8(MemorySegment a, MemorySegment b, int length) { try { return (float) COSINE_I8.invokeExact(a, b, length); diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java index c08afba09c2be..90f0a7f44b74e 100644 --- a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java +++ 
b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java @@ -42,7 +42,7 @@ public long int7DotProduct(byte[] q) throws IOException { private long nativeInt7DotProduct(byte[] q) throws IOException { return IndexInputUtils.withSlice(in, dimensions, this::getScratch, segment -> { final MemorySegment querySegment = MemorySegment.ofArray(q); - return Similarities.dotProductI7u(segment, querySegment, dimensions); + return (long) Similarities.dotProductI7u(segment, querySegment, dimensions); }); } diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/AbstractVectorTestCase.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/AbstractVectorTestCase.java index 259e5bafb765e..cd497e85c3952 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/AbstractVectorTestCase.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/AbstractVectorTestCase.java @@ -16,7 +16,9 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.Arrays; import java.util.Optional; +import java.util.function.IntFunction; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent; import static org.hamcrest.Matchers.not; @@ -25,6 +27,17 @@ public abstract class AbstractVectorTestCase extends ESTestCase { static Optional factory; + protected static final float DELTA = 1e-6f; + + /** + * Use a slightly larger delta for bulk scoring to account for floating point precision + * issues: applying the corrections in even a slightly different order can impact the score. 
+ */ + protected static final float BULK_DELTA = 2e-5f; + + // Support for passing on-heap arrays/segments to native + protected static boolean SUPPORTS_HEAP_SEGMENTS = Runtime.version().feature() >= 22; + @BeforeClass public static void getVectorScorerFactory() { factory = org.elasticsearch.simdvec.VectorScorerFactory.instance(); @@ -61,11 +74,6 @@ public static String platformMsg() { return "JDK=" + jdkVersion + ", os=" + osName + ", arch=" + arch; } - // Support for passing on-heap arrays/segments to native - protected static boolean supportsHeapSegments() { - return Runtime.version().feature() >= 22; - } - /** Converts a float value to a byte array. */ public static byte[] floatToByteArray(float value) { return ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putFloat(value).array(); @@ -80,4 +88,18 @@ public static byte[] concat(byte[]... arrays) throws IOException { return baos.toByteArray(); } } + + static IntFunction FLOAT_ARRAY_RANDOM_FUNC = size -> { + float[] fa = new float[size]; + for (int i = 0; i < size; i++) { + fa[i] = randomFloat(); + } + return fa; + }; + + static IntFunction FLOAT_ARRAY_MAX_FUNC = size -> { + float[] fa = new float[size]; + Arrays.fill(fa, Float.MAX_VALUE); + return fa; + }; } diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int4VectorScorerFactoryTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int4VectorScorerFactoryTests.java new file mode 100644 index 0000000000000..56b93497920da --- /dev/null +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int4VectorScorerFactoryTests.java @@ -0,0 +1,983 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import com.carrotsearch.randomizedtesting.generators.RandomNumbers; + +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorScorer; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat; +import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.store.NIOFSDirectory; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.elasticsearch.core.SuppressForbidden; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.function.IntFunction; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.elasticsearch.nativeaccess.Int4TestUtils.packNibbles; +import static 
org.elasticsearch.nativeaccess.Int4TestUtils.unpackNibbles; +import static org.elasticsearch.simdvec.VectorSimilarityType.DOT_PRODUCT; +import static org.elasticsearch.simdvec.VectorSimilarityType.EUCLIDEAN; +import static org.elasticsearch.simdvec.VectorSimilarityType.MAXIMUM_INNER_PRODUCT; +import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.createDenseInt4VectorValues; +import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.writePackedVectorWithCorrection; +import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; +import static org.hamcrest.Matchers.equalTo; + +public class Int4VectorScorerFactoryTests extends AbstractVectorTestCase { + private static final float LIMIT_SCALE = 1f / ((1 << 4) - 1); + + private final VectorSimilarityType similarityType; + + public Int4VectorScorerFactoryTests(VectorSimilarityType similarityType) { + this.similarityType = similarityType; + } + + @SuppressForbidden(reason = "require usage of OptimizedScalarQuantizer") + private static OptimizedScalarQuantizer scalarQuantizer(VectorSimilarityFunction sim) { + return new OptimizedScalarQuantizer(sim); + } + + // bounds of the range of values for int4 packed nibble (4-bit) + static final byte MIN_INT4_VALUE = 0; + static final byte MAX_INT4_VALUE = 0x0F; + + // Tests that the provider instance is present or not on expected platforms/architectures + public void testSupport() { + supported(); + } + + public void testSimple() throws IOException { + testSimpleImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE); + } + + public void testSimpleMaxChunkSizeSmall() throws IOException { + long maxChunkSize = randomLongBetween(4, 16); + logger.info("maxChunkSize=" + maxChunkSize); + testSimpleImpl(maxChunkSize); + } + + void testSimpleImpl(long maxChunkSize) throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = AbstractVectorTestCase.factory.get(); + + try (Directory dir = new 
MMapDirectory(createTempDir("testSimpleImpl"), maxChunkSize)) { + var scalarQuantizer = scalarQuantizer(similarityType.function()); + var encoding = Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.PACKED_NIBBLE; + for (int dims : List.of(30, 32, 34)) { + float[] query1 = new float[dims]; + float[] query2 = new float[dims]; + float[] centroid = new float[dims]; + float centroidDP = 0f; + byte[] scratch = new byte[encoding.getDiscreteDimensions(dims)]; + OptimizedScalarQuantizer.QuantizationResult vec1Correction, vec2Correction; + byte[] packed1, packed2; + String fileName = "testSimpleImpl-" + similarityType + "-" + dims + ".vex"; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < dims; i++) { + query1[i] = (float) i; + query2[i] = (float) (dims - i); + centroid[i] = (query1[i] + query2[i]) / 2f; + centroidDP += centroid[i] * centroid[i]; + } + vec1Correction = scalarQuantizer.scalarQuantize(query1, scratch, (byte) 4, centroid); + packed1 = packNibbles(Arrays.copyOf(scratch, dims)); + vec2Correction = scalarQuantizer.scalarQuantize(query2, scratch, (byte) 4, centroid); + packed2 = packNibbles(Arrays.copyOf(scratch, dims)); + writePackedVectorWithCorrection(out, packed1, vec1Correction); + writePackedVectorWithCorrection(out, packed2, vec2Correction); + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + var values = createDenseInt4VectorValues(dims, 2, centroid, centroidDP, in, similarityType.function()); + float expected = luceneScore(similarityType, packed1, packed2, dims, centroidDP, vec1Correction, vec2Correction); + + var luceneSupplier = luceneScoreSupplier(values, similarityType.function()).scorer(); + luceneSupplier.setScoringOrdinal(1); + assertFloatEquals(expected, luceneSupplier.score(0), DELTA); + var supplier = factory.getInt4VectorScorerSupplier(similarityType, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(1); + assertFloatEquals(expected, 
scorer.score(0), DELTA); + + if (SUPPORTS_HEAP_SEGMENTS) { + byte[] unpackedQuery = unpackNibbles(packed2, dims); + var qScorer = factory.getInt4VectorScorer( + similarityType.function(), + values, + unpackedQuery, + vec2Correction.lowerInterval(), + vec2Correction.upperInterval(), + vec2Correction.additionalCorrection(), + vec2Correction.quantizedComponentSum() + ).get(); + assertFloatEquals(expected, qScorer.score(0), DELTA); + } + } + } + } + } + + public void testRandomMMap() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + try (Directory dir = new MMapDirectory(createTempDir("testRandomMMap"))) { + testRandomSupplier(dir, BYTE_ARRAY_RANDOM_INT4_FUNC); + } + } + + public void testRandomNIO() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + try (Directory dir = new NIOFSDirectory(createTempDir("testRandomNIO"))) { + testRandomSupplier(dir, BYTE_ARRAY_RANDOM_INT4_FUNC); + } + } + + public void testRandomMaxChunkSizeSmall() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + long maxChunkSize = randomLongBetween(32, 128); + logger.info("maxChunkSize=" + maxChunkSize); + try (Directory dir = new MMapDirectory(createTempDir("testRandomMaxChunkSizeSmall"), maxChunkSize)) { + testRandomSupplier(dir, BYTE_ARRAY_RANDOM_INT4_FUNC); + } + } + + public void testRandomMax() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + try (Directory dir = new MMapDirectory(createTempDir("testRandomMax"))) { + testRandomSupplier(dir, BYTE_ARRAY_MAX_INT4_FUNC); + } + } + + public void testRandomMin() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + try (Directory dir = new MMapDirectory(createTempDir("testRandomMin"))) { + testRandomSupplier(dir, BYTE_ARRAY_MIN_INT4_FUNC); + } + } + + void testRandomSupplier(Directory dir, IntFunction packedByteArraySupplier) throws IOException { + var factory = AbstractVectorTestCase.factory.get(); + + final int dims = randomIntBetween(1, 2048) * 2; + final 
int size = randomIntBetween(2, 100); + final byte[][] packedVectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] quantizationResults = new OptimizedScalarQuantizer.QuantizationResult[size]; + final float[] centroid = new float[dims]; + + String fileName = "testRandom-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var packed = packedByteArraySupplier.apply(dims); + int componentSum = componentSumUnpacked(packed); + float lowerInterval = randomFloat(); + float upperInterval = randomFloat() + lowerInterval; + quantizationResults[i] = new OptimizedScalarQuantizer.QuantizationResult( + lowerInterval, + upperInterval, + randomFloat(), + componentSum + ); + writePackedVectorWithCorrection(out, packed, quantizationResults[i]); + packedVectors[i] = packed; + } + } + for (int i = 0; i < dims; i++) { + centroid[i] = randomFloat(); + } + float centroidDP = VectorUtil.dotProduct(centroid, centroid); + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = randomIntBetween(0, size - 1); + var values = createDenseInt4VectorValues(dims, size, centroid, centroidDP, in, similarityType.function()); + float expected = luceneScore( + similarityType, + packedVectors[idx0], + packedVectors[idx1], + dims, + centroidDP, + quantizationResults[idx0], + quantizationResults[idx1] + ); + var supplier = factory.getInt4VectorScorerSupplier(similarityType, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx1); + assertFloatEquals(expected, scorer.score(idx0), DELTA); + } + } + } + + public void testRandomScorerMMap() throws IOException { + try (Directory dir = new MMapDirectory(createTempDir("testRandomScorerMMap"))) { + testRandomScorerImpl(dir, FLOAT_ARRAY_RANDOM_FUNC); + } + } + + public void 
testRandomScorerNIO() throws IOException { + try (Directory dir = new NIOFSDirectory(createTempDir("testRandomScorerNIO"))) { + testRandomScorerImpl(dir, FLOAT_ARRAY_RANDOM_FUNC); + } + } + + public void testRandomScorerMax() throws IOException { + try (Directory dir = new MMapDirectory(createTempDir("testRandomScorerMax"))) { + testRandomScorerImpl(dir, FLOAT_ARRAY_MAX_FUNC); + } + } + + public void testRandomScorerChunkSizeSmall() throws IOException { + long maxChunkSize = randomLongBetween(32, 128); + logger.info("maxChunkSize=" + maxChunkSize); + try (Directory dir = new MMapDirectory(createTempDir("testRandomScorerChunkSizeSmall"), maxChunkSize)) { + testRandomScorerImpl(dir, FLOAT_ARRAY_RANDOM_FUNC); + } + } + + void testRandomScorerImpl(Directory dir, IntFunction floatArraySupplier) throws IOException { + assumeTrue("scorer only supported on JDK 22+", SUPPORTS_HEAP_SEGMENTS); + assumeTrue(notSupportedMsg(), supported()); + var factory = AbstractVectorTestCase.factory.get(); + + var scalarQuantizer = scalarQuantizer(similarityType.function()); + var encoding = Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.PACKED_NIBBLE; + final int dims = randomIntBetween(1, 2048) * 2; + final int size = randomIntBetween(2, 100); + final float[] centroid = new float[dims]; + for (int i = 0; i < dims; i++) { + centroid[i] = randomFloat(); + } + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final float[][] vectors = new float[size][]; + final byte[][] packedVectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + byte[] scratch = new byte[encoding.getDiscreteDimensions(dims)]; + + String fileName = "testRandom-" + similarityType + "-" + dims + ".vex"; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + vectors[i] = floatArraySupplier.apply(dims); + 
corrections[i] = scalarQuantizer.scalarQuantize(vectors[i], scratch, (byte) 4, centroid); + packedVectors[i] = packNibbles(Arrays.copyOf(scratch, dims)); + writePackedVectorWithCorrection(out, packedVectors[i], corrections[i]); + } + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = randomIntBetween(0, size - 1); + var values = createDenseInt4VectorValues(dims, size, centroid, centroidDP, in, similarityType.function()); + + var expected = luceneScore( + similarityType, + packedVectors[idx0], + packedVectors[idx1], + dims, + centroidDP, + corrections[idx0], + corrections[idx1] + ); + byte[] unpackedQuery = unpackNibbles(packedVectors[idx0], dims); + var scorer = factory.getInt4VectorScorer( + similarityType.function(), + values, + unpackedQuery, + corrections[idx0].lowerInterval(), + corrections[idx0].upperInterval(), + corrections[idx0].additionalCorrection(), + corrections[idx0].quantizedComponentSum() + ).get(); + assertFloatEquals(expected, scorer.score(idx1), DELTA); + } + } + } + + public void testRandomSlice() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + testRandomSliceImpl(30, 64, 1, BYTE_ARRAY_RANDOM_INT4_FUNC); + } + + void testRandomSliceImpl(int dims, long maxChunkSize, int initialPadding, IntFunction packedByteArraySupplier) + throws IOException { + var factory = AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testRandomSliceImpl"), maxChunkSize)) { + for (int times = 0; times < TIMES; times++) { + final int size = randomIntBetween(2, 100); + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final byte[][] packedVectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + + String 
fileName = "testRandomSliceImpl-" + times + "-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + byte[] ba = new byte[initialPadding]; + out.writeBytes(ba, 0, ba.length); + for (int i = 0; i < size; i++) { + var packed = packedByteArraySupplier.apply(dims); + var correction = randomCorrectionPacked(packed); + writePackedVectorWithCorrection(out, packed, correction); + packedVectors[i] = packed; + corrections[i] = correction; + } + } + try ( + var outter = dir.openInput(fileName, IOContext.DEFAULT); + var in = outter.slice("slice", initialPadding, outter.length() - initialPadding) + ) { + for (int itrs = 0; itrs < TIMES / 10; itrs++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = randomIntBetween(0, size - 1); + var values = createDenseInt4VectorValues(dims, size, centroid, centroidDP, in, similarityType.function()); + float expected = luceneScore( + similarityType, + packedVectors[idx0], + packedVectors[idx1], + dims, + centroidDP, + corrections[idx0], + corrections[idx1] + ); + var supplier = factory.getInt4VectorScorerSupplier(similarityType, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx1); + assertFloatEquals(expected, scorer.score(idx0), DELTA); + } + } + } + } + } + + @Nightly + public void testLarge() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testLarge"))) { + final int dims = 8192; + final int size = 262144; + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + + String fileName = "testLarge-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, 
IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var packed = vector(i, dims); + var correction = randomCorrectionPacked(packed); + writePackedVectorWithCorrection(out, packed, correction); + corrections[i] = correction; + } + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = size - 1; + var values = createDenseInt4VectorValues(dims, size, centroid, centroidDP, in, similarityType.function()); + float expected = luceneScore( + similarityType, + vector(idx0, dims), + vector(idx1, dims), + dims, + centroidDP, + corrections[idx0], + corrections[idx1] + ); + var supplier = factory.getInt4VectorScorerSupplier(similarityType, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx1); + assertFloatEquals(expected, scorer.score(idx0), DELTA); + } + } + } + } + + public void testDatasetGreaterThanChunkSize() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testDatasetGreaterThanChunkSize"), 8192)) { + final int dims = 1024; + final int size = 128; + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final byte[][] packedVectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + + String fileName = "testDatasetGreaterThanChunkSize-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var packed = vector(i, dims); + var correction = randomCorrectionPacked(packed); + writePackedVectorWithCorrection(out, packed, correction); + packedVectors[i] = packed; + corrections[i] = correction; + } + } + try (IndexInput in = 
dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = size - 1; + var values = createDenseInt4VectorValues(dims, size, centroid, centroidDP, in, similarityType.function()); + float expected = luceneScore( + similarityType, + packedVectors[idx0], + packedVectors[idx1], + dims, + centroidDP, + corrections[idx0], + corrections[idx1] + ); + var supplier = factory.getInt4VectorScorerSupplier(similarityType, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx1); + assertFloatEquals(expected, scorer.score(idx0), DELTA); + } + } + } + } + + public void testBulkMMap() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + try (Directory dir = new MMapDirectory(createTempDir("testBulkMMap"))) { + testBulkImpl(dir); + } + } + + public void testBulkNIO() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + try (Directory dir = new NIOFSDirectory(createTempDir("testBulkNIO"))) { + testBulkImpl(dir); + } + } + + void testBulkImpl(Directory dir) throws IOException { + var factory = AbstractVectorTestCase.factory.get(); + + final int dims = 1024; + final int size = randomIntBetween(1, 102); + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final byte[][] packedVectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + String fileName = "testBulk-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var packed = vector(i, dims); + var correction = randomCorrectionPacked(packed); + writePackedVectorWithCorrection(out, packed, correction); + packedVectors[i] = packed; + corrections[i] = correction; + } + } + + List ids = IntStream.range(0, 
size).boxed().collect(Collectors.toList()); + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int[] nodes = shuffledList(ids).stream().mapToInt(i -> i).toArray(); + QuantizedByteVectorValues values = createDenseInt4VectorValues( + dims, + size, + centroid, + centroidDP, + in, + similarityType.function() + ); + float[] expected = new float[nodes.length]; + float[] scores = new float[nodes.length]; + var referenceScorer = luceneScoreSupplier(values, similarityType.function()).scorer(); + referenceScorer.setScoringOrdinal(idx0); + referenceScorer.bulkScore(nodes, expected, nodes.length); + var supplier = factory.getInt4VectorScorerSupplier(similarityType, in, values).orElseThrow(); + var testScorer = supplier.scorer(); + testScorer.setScoringOrdinal(idx0); + testScorer.bulkScore(nodes, scores, nodes.length); + assertFloatArrayEquals(expected, scores, BULK_DELTA); + } + } + } + + public void testBulkWithDatasetGreaterThanChunkSize() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = AbstractVectorTestCase.factory.get(); + + final int dims = 1024; + final int size = 128; + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final byte[][] packedVectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + try (Directory dir = new MMapDirectory(createTempDir("testBulkWithDatasetGreaterThanChunkSize"), 8192)) { + String fileName = "testBulkWithDatasetGreaterThanChunkSize-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var packed = vector(i, dims); + var correction = randomCorrectionPacked(packed); + writePackedVectorWithCorrection(out, packed, 
correction); + packedVectors[i] = packed; + corrections[i] = correction; + } + } + + List ids = IntStream.range(0, size).boxed().collect(Collectors.toList()); + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int[] nodes = shuffledList(ids).stream().mapToInt(i -> i).toArray(); + QuantizedByteVectorValues values = createDenseInt4VectorValues( + dims, + size, + centroid, + centroidDP, + in, + similarityType.function() + ); + float[] expected = new float[nodes.length]; + float[] scores = new float[nodes.length]; + var referenceScorer = luceneScoreSupplier(values, similarityType.function()).scorer(); + referenceScorer.setScoringOrdinal(idx0); + referenceScorer.bulkScore(nodes, expected, nodes.length); + var supplier = factory.getInt4VectorScorerSupplier(similarityType, in, values).orElseThrow(); + var testScorer = supplier.scorer(); + testScorer.setScoringOrdinal(idx0); + testScorer.bulkScore(nodes, scores, nodes.length); + assertFloatArrayEquals(expected, scores, BULK_DELTA); + } + } + } + } + + public void testBulkScorerMMap() throws IOException { + assumeTrue("scorer only supported on JDK 22+", SUPPORTS_HEAP_SEGMENTS); + assumeTrue(notSupportedMsg(), supported()); + try (Directory dir = new MMapDirectory(createTempDir("testBulkScorerMMap"))) { + testBulkScorerImpl(dir); + } + } + + public void testBulkScorerNIO() throws IOException { + assumeTrue("scorer only supported on JDK 22+", SUPPORTS_HEAP_SEGMENTS); + assumeTrue(notSupportedMsg(), supported()); + try (Directory dir = new NIOFSDirectory(createTempDir("testBulkScorerNIO"))) { + testBulkScorerImpl(dir); + } + } + + void testBulkScorerImpl(Directory dir) throws IOException { + var factory = AbstractVectorTestCase.factory.get(); + + final int dims = 1024; + final int size = randomIntBetween(2, 100); + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = 
VectorUtil.dotProduct(centroid, centroid); + final byte[][] packedVectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + String fileName = "testBulkScorer-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var packed = vector(i, dims); + var correction = randomCorrectionPacked(packed); + writePackedVectorWithCorrection(out, packed, correction); + packedVectors[i] = packed; + corrections[i] = correction; + } + } + + List ids = IntStream.range(0, size).boxed().collect(Collectors.toList()); + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int queryIdx = randomIntBetween(0, size - 1); + int[] nodes = shuffledList(ids).stream().mapToInt(i -> i).toArray(); + QuantizedByteVectorValues values = createDenseInt4VectorValues( + dims, + size, + centroid, + centroidDP, + in, + similarityType.function() + ); + float[] expected = new float[nodes.length]; + float[] scores = new float[nodes.length]; + var referenceScorer = luceneScoreSupplier(values, similarityType.function()).scorer(); + referenceScorer.setScoringOrdinal(queryIdx); + referenceScorer.bulkScore(nodes, expected, nodes.length); + + byte[] unpackedQuery = unpackNibbles(packedVectors[queryIdx], dims); + var scorer = factory.getInt4VectorScorer( + similarityType.function(), + values, + unpackedQuery, + corrections[queryIdx].lowerInterval(), + corrections[queryIdx].upperInterval(), + corrections[queryIdx].additionalCorrection(), + corrections[queryIdx].quantizedComponentSum() + ).get(); + scorer.bulkScore(nodes, scores, nodes.length); + assertFloatArrayEquals(expected, scores, BULK_DELTA); + } + } + } + + public void testScorerSupplierSequentialOrdinals() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = 
AbstractVectorTestCase.factory.get(); + + final int dims = 128; + final int size = 10; + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final byte[][] packedVectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + try (Directory dir = new MMapDirectory(createTempDir("testSequentialOrdinals"))) { + String fileName = "testSequentialOrdinals-" + dims; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var packed = vector(i, dims); + var correction = randomCorrectionPacked(packed); + writePackedVectorWithCorrection(out, packed, correction); + packedVectors[i] = packed; + corrections[i] = correction; + } + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + var values = createDenseInt4VectorValues(dims, size, centroid, centroidDP, in, similarityType.function()); + var supplier = factory.getInt4VectorScorerSupplier(similarityType, in, values).get(); + var scorer = supplier.scorer(); + for (int queryOrd = 0; queryOrd < size; queryOrd++) { + scorer.setScoringOrdinal(queryOrd); + for (int targetOrd = 0; targetOrd < size; targetOrd++) { + float expected = luceneScore( + similarityType, + packedVectors[queryOrd], + packedVectors[targetOrd], + dims, + centroidDP, + corrections[queryOrd], + corrections[targetOrd] + ); + assertFloatEquals(expected, scorer.score(targetOrd), BULK_DELTA); + } + } + } + } + } + + public void testInvalidOrdinal() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = AbstractVectorTestCase.factory.get(); + + final int dims = 32; + final int size = 2; + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + try (Directory dir = new MMapDirectory(createTempDir("testInvalidOrdinal"))) { + 
String fileName = "testInvalidOrdinal-" + dims; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var packed = vector(i, dims); + writePackedVectorWithCorrection(out, packed, randomCorrectionPacked(packed)); + } + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + var values = createDenseInt4VectorValues(dims, size, centroid, centroidDP, in, similarityType.function()); + var supplier = factory.getInt4VectorScorerSupplier(similarityType, in, values).get(); + var scorer = supplier.scorer(); + expectThrows(IllegalArgumentException.class, () -> scorer.setScoringOrdinal(-1)); + expectThrows(IllegalArgumentException.class, () -> scorer.setScoringOrdinal(size)); + + if (SUPPORTS_HEAP_SEGMENTS) { + byte[] packed0 = vector(0, dims); + byte[] unpackedQuery = unpackNibbles(packed0, dims); + var correction = randomCorrectionPacked(packed0); + var qScorer = factory.getInt4VectorScorer( + similarityType.function(), + values, + unpackedQuery, + correction.lowerInterval(), + correction.upperInterval(), + correction.additionalCorrection(), + correction.quantizedComponentSum() + ).get(); + expectThrows(IllegalArgumentException.class, () -> qScorer.score(-1)); + expectThrows(IllegalArgumentException.class, () -> qScorer.score(size)); + } + } + } + } + + public void testScoreBeforeSetOrdinal() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = AbstractVectorTestCase.factory.get(); + + final int dims = 32; + final int size = 2; + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + try (Directory dir = new MMapDirectory(createTempDir("testScoreBeforeSetOrdinal"))) { + String fileName = "testScoreBeforeSetOrdinal-" + dims; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var packed = vector(i, dims); + 
writePackedVectorWithCorrection(out, packed, randomCorrectionPacked(packed)); + } + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + var values = createDenseInt4VectorValues(dims, size, centroid, centroidDP, in, similarityType.function()); + var supplier = factory.getInt4VectorScorerSupplier(similarityType, in, values).get(); + var scorer = supplier.scorer(); + expectThrows(IllegalStateException.class, () -> scorer.score(0)); + } + } + } + + public void testRace() throws Exception { + assumeTrue(notSupportedMsg(), supported()); + var factory = AbstractVectorTestCase.factory.get(); + + final long maxChunkSize = 32; + final int dims = 34; + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + byte[] unpacked1 = new byte[dims]; + byte[] unpacked2 = new byte[dims]; + IntStream.range(0, dims).forEach(i -> unpacked1[i] = 1); + IntStream.range(0, dims).forEach(i -> unpacked2[i] = 2); + byte[] packed1 = packNibbles(unpacked1); + byte[] packed2 = packNibbles(unpacked2); + var correction1 = randomCorrectionPacked(packed1); + var correction2 = randomCorrectionPacked(packed2); + try (Directory dir = new MMapDirectory(createTempDir("testRace"), maxChunkSize)) { + String fileName = "testRace-" + dims; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + writePackedVectorWithCorrection(out, packed1, correction1); + writePackedVectorWithCorrection(out, packed1, correction1); + writePackedVectorWithCorrection(out, packed2, correction2); + writePackedVectorWithCorrection(out, packed2, correction2); + } + var expectedScore1 = luceneScore(similarityType, packed1, packed1, dims, centroidDP, correction1, correction1); + var expectedScore2 = luceneScore(similarityType, packed2, packed2, dims, centroidDP, correction2, correction2); + + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + var values = createDenseInt4VectorValues(dims, 4, centroid, 
centroidDP, in, similarityType.function()); + var scoreSupplier = factory.getInt4VectorScorerSupplier(similarityType, in, values).get(); + var tasks = List.>>of( + new ScoreCallable(scoreSupplier.copy().scorer(), 0, 1, expectedScore1), + new ScoreCallable(scoreSupplier.copy().scorer(), 2, 3, expectedScore2) + ); + var executor = Executors.newFixedThreadPool(2); + var results = executor.invokeAll(tasks); + executor.shutdown(); + assertTrue(executor.awaitTermination(60, TimeUnit.SECONDS)); + assertThat(results.stream().filter(Predicate.not(Future::isDone)).count(), equalTo(0L)); + for (var res : results) { + assertThat("Unexpected exception" + res.get(), res.get(), isEmpty()); + } + } + } + } + + static class ScoreCallable implements Callable> { + + final UpdateableRandomVectorScorer scorer; + final int ord; + final float expectedScore; + + ScoreCallable(UpdateableRandomVectorScorer scorer, int queryOrd, int ord, float expectedScore) { + try { + this.scorer = scorer; + this.scorer.setScoringOrdinal(queryOrd); + this.ord = ord; + this.expectedScore = expectedScore; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Optional call() { + try { + for (int i = 0; i < 100; i++) { + assertFloatEquals(expectedScore, scorer.score(ord), DELTA); + } + } catch (Throwable t) { + return Optional.of(t); + } + return Optional.empty(); + } + } + + static int componentSumUnpacked(byte[] packed) { + byte[] unpacked = unpackNibbles(packed, packed.length * 2); + int sum = 0; + for (byte value : unpacked) { + sum += Byte.toUnsignedInt(value) & 0x0F; + } + return sum; + } + + private static OptimizedScalarQuantizer.QuantizationResult randomCorrectionPacked(byte[] packed) { + int componentSum = componentSumUnpacked(packed); + float lowerInterval = randomFloat(); + float upperInterval = lowerInterval + randomFloat(); + return new OptimizedScalarQuantizer.QuantizationResult(lowerInterval, upperInterval, randomFloat(), componentSum); + } + + public float 
luceneScore( + VectorSimilarityType similarityFunc, + byte[] packedA, + byte[] packedB, + int dims, + float centroidDP, + OptimizedScalarQuantizer.QuantizationResult aCorrection, + OptimizedScalarQuantizer.QuantizationResult bCorrection + ) { + OSQScorer scorer = OSQScorer.fromSimilarity(similarityFunc); + return scorer.score(packedA, packedB, dims, centroidDP, aCorrection, bCorrection); + } + + private abstract static class OSQScorer { + static OSQScorer fromSimilarity(VectorSimilarityType sim) { + return switch (sim) { + case DOT_PRODUCT -> new DotProductOSQScorer(); + case MAXIMUM_INNER_PRODUCT -> new MaxInnerProductOSQScorer(); + case EUCLIDEAN -> new EuclideanOSQScorer(); + default -> throw new IllegalArgumentException("Unsupported similarity: " + sim); + }; + } + + final float score( + byte[] packedA, + byte[] packedB, + int dims, + float centroidDP, + OptimizedScalarQuantizer.QuantizationResult aCorrection, + OptimizedScalarQuantizer.QuantizationResult bCorrection + ) { + byte[] unpackedB = unpackNibbles(packedB, dims); + float rawDot = VectorUtil.int4DotProductSinglePacked(unpackedB, packedA); + float ax = aCorrection.lowerInterval(); + float lx = (aCorrection.upperInterval() - ax) * LIMIT_SCALE; + float ay = bCorrection.lowerInterval(); + float ly = (bCorrection.upperInterval() - ay) * LIMIT_SCALE; + float y1 = bCorrection.quantizedComponentSum(); + float x1 = aCorrection.quantizedComponentSum(); + float score = ax * ay * dims + ay * lx * x1 + ax * ly * y1 + lx * ly * rawDot; + return scaleScore(score, aCorrection.additionalCorrection(), bCorrection.additionalCorrection(), centroidDP); + } + + abstract float scaleScore(float score, float aCorrection, float bCorrection, float centroidDP); + + private static class DotProductOSQScorer extends OSQScorer { + @Override + float scaleScore(float score, float aCorrection, float bCorrection, float centroidDP) { + score += aCorrection + bCorrection - centroidDP; + score = Math.clamp(score, -1, 1); + return 
VectorUtil.normalizeToUnitInterval(score); + } + } + + private static class MaxInnerProductOSQScorer extends OSQScorer { + @Override + float scaleScore(float score, float aCorrection, float bCorrection, float centroidDP) { + score += aCorrection + bCorrection - centroidDP; + return VectorUtil.scaleMaxInnerProductScore(score); + } + } + + private static class EuclideanOSQScorer extends OSQScorer { + @Override + float scaleScore(float score, float aCorrection, float bCorrection, float centroidDP) { + score = aCorrection + bCorrection - 2 * score; + return VectorUtil.normalizeDistanceToUnitInterval(Math.max(score, 0f)); + } + } + } + + static void assertFloatArrayEquals(float[] expected, float[] actual, float delta) { + assertThat(actual.length, equalTo(expected.length)); + for (int i = 0; i < expected.length; i++) { + assertEquals("differed at element [" + i + "]", expected[i], actual[i], Math.abs(expected[i]) * delta + delta); + } + } + + static void assertFloatEquals(float expected, float actual, float delta) { + assertEquals(expected, actual, Math.abs(expected) * delta + delta); + } + + static RandomVectorScorerSupplier luceneScoreSupplier(QuantizedByteVectorValues values, VectorSimilarityFunction sim) + throws IOException { + return new Lucene104ScalarQuantizedVectorScorer(null).getRandomVectorScorerSupplier(sim, values); + } + + static byte[] vector(int ord, int dims) { + var random = new Random(Objects.hash(ord, dims)); + byte[] unpacked = new byte[dims]; + for (int i = 0; i < dims; i++) { + unpacked[i] = (byte) RandomNumbers.randomIntBetween(random, MIN_INT4_VALUE, MAX_INT4_VALUE); + } + return packNibbles(unpacked); + } + + static IntFunction BYTE_ARRAY_RANDOM_INT4_FUNC = dims -> { + byte[] unpacked = new byte[dims]; + for (int i = 0; i < dims; i++) { + unpacked[i] = (byte) randomIntBetween(MIN_INT4_VALUE, MAX_INT4_VALUE); + } + return packNibbles(unpacked); + }; + + static IntFunction BYTE_ARRAY_MAX_INT4_FUNC = dims -> { + byte[] unpacked = new byte[dims]; + 
Arrays.fill(unpacked, MAX_INT4_VALUE); + return packNibbles(unpacked); + }; + + static IntFunction BYTE_ARRAY_MIN_INT4_FUNC = dims -> { + byte[] unpacked = new byte[dims]; + Arrays.fill(unpacked, MIN_INT4_VALUE); + return packNibbles(unpacked); + }; + + static final int TIMES = 100; + + @ParametersFactory + public static Iterable parametersFactory() { + return List.of(new Object[] { DOT_PRODUCT }, new Object[] { EUCLIDEAN }, new Object[] { MAXIMUM_INNER_PRODUCT }); + } +} diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7SQVectorScorerFactoryTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7SQVectorScorerFactoryTests.java index 790657e7740c8..953e27bdc9f93 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7SQVectorScorerFactoryTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7SQVectorScorerFactoryTests.java @@ -108,7 +108,7 @@ void testSimpleImpl(long maxChunkSize) throws IOException { scorer.setScoringOrdinal(0); assertThat(scorer.score(1), equalTo(expected)); - if (supportsHeapSegments()) { + if (SUPPORTS_HEAP_SEGMENTS) { var qScorer = factory.getInt7SQVectorScorer(sim.function(), values, query1).get(); assertThat(qScorer.score(1), equalTo(expected)); } @@ -230,11 +230,11 @@ void testRandomSupplier(long maxChunkSize, IntFunction byteArraySupplier } public void testRandomScorer() throws IOException { - testRandomScorerImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, Int7SQVectorScorerFactoryTests.FLOAT_ARRAY_RANDOM_FUNC); + testRandomScorerImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, FLOAT_ARRAY_RANDOM_FUNC); } public void testRandomScorerMax() throws IOException { - testRandomScorerImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, Int7SQVectorScorerFactoryTests.FLOAT_ARRAY_MAX_FUNC); + testRandomScorerImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, FLOAT_ARRAY_MAX_FUNC); } public void testRandomScorerChunkSizeSmall() throws IOException { @@ -244,7 +244,7 @@ public void testRandomScorerChunkSizeSmall() 
throws IOException { } void testRandomScorerImpl(long maxChunkSize, IntFunction floatArraySupplier) throws IOException { - assumeTrue("scorer only supported on JDK 22+", Runtime.version().feature() >= 22); + assumeTrue("scorer only supported on JDK 22+", SUPPORTS_HEAP_SEGMENTS); assumeTrue(notSupportedMsg(), supported()); var factory = AbstractVectorTestCase.factory.get(); var scalarQuantizer = new ScalarQuantizer(0.1f, 0.9f, (byte) 7); @@ -259,7 +259,7 @@ void testRandomScorerImpl(long maxChunkSize, IntFunction floatArraySupp final byte[][] qVectors = new byte[size][]; final float[] corrections = new float[size]; - float delta = 1e-6f * dims; + float delta = DELTA * dims; String fileName = "testRandom-" + sim + "-" + dims + ".vex"; logger.info("Testing " + fileName); @@ -457,7 +457,7 @@ public void testBulk() throws IOException { var testScorer = supplier.scorer(); testScorer.setScoringOrdinal(idx0); testScorer.bulkScore(nodes, scores, nodes.length); - assertArrayEquals(expected, scores, 1e-6f); + assertArrayEquals(expected, scores, DELTA); } } } @@ -507,7 +507,7 @@ public void testBulkWithDatasetGreaterThanChunkSize() throws IOException { var testScorer = supplier.scorer(); testScorer.setScoringOrdinal(idx0); testScorer.bulkScore(nodes, scores, nodes.length); - assertArrayEquals(expected, scores, 1e-6f); + assertArrayEquals(expected, scores, DELTA); } } } @@ -625,20 +625,6 @@ static byte[] vector(int ord, int dims) { return ba; } - static IntFunction FLOAT_ARRAY_RANDOM_FUNC = size -> { - float[] fa = new float[size]; - for (int i = 0; i < size; i++) { - fa[i] = randomFloat(); - } - return fa; - }; - - static IntFunction FLOAT_ARRAY_MAX_FUNC = size -> { - float[] fa = new float[size]; - Arrays.fill(fa, Float.MAX_VALUE); - return fa; - }; - static IntFunction BYTE_ARRAY_RANDOM_INT7_FUNC = size -> { byte[] ba = new byte[size]; randomBytesBetween(ba, MIN_INT7_VALUE, MAX_INT7_VALUE); diff --git 
a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7uOSQVectorScorerFactoryTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7uOSQVectorScorerFactoryTests.java index 04fb45f1f22dd..8b4ae8640a1d8 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7uOSQVectorScorerFactoryTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7uOSQVectorScorerFactoryTests.java @@ -119,13 +119,13 @@ void testSimpleImpl(long maxChunkSize) throws IOException { var luceneSupplier = luceneScoreSupplier(values, sim.function()).scorer(); luceneSupplier.setScoringOrdinal(1); - assertFloatEquals(expected, luceneSupplier.score(0), 1e-6f); + assertFloatEquals(expected, luceneSupplier.score(0), DELTA); var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).get(); var scorer = supplier.scorer(); scorer.setScoringOrdinal(1); - assertFloatEquals(expected, scorer.score(0), 1e-6f); + assertFloatEquals(expected, scorer.score(0), DELTA); - if (supportsHeapSegments()) { + if (SUPPORTS_HEAP_SEGMENTS) { var qScorer = factory.getInt7uOSQVectorScorer( sim.function(), values, @@ -135,7 +135,7 @@ void testSimpleImpl(long maxChunkSize) throws IOException { vec2Correction.additionalCorrection(), vec2Correction.quantizedComponentSum() ).get(); - assertFloatEquals(expected, qScorer.score(0), 1e-6f); + assertFloatEquals(expected, qScorer.score(0), DELTA); } } } @@ -221,7 +221,7 @@ void testRandomSupplier(long maxChunkSize, IntFunction byteArraySupplier var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).get(); var scorer = supplier.scorer(); scorer.setScoringOrdinal(idx1); - assertFloatEquals(expected, scorer.score(idx0), 1e-6f); + assertFloatEquals(expected, scorer.score(idx0), DELTA); } } } @@ -229,17 +229,11 @@ void testRandomSupplier(long maxChunkSize, IntFunction byteArraySupplier } public void testRandomScorer() throws IOException { - testRandomScorerImpl( - MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, - 
org.elasticsearch.simdvec.Int7SQVectorScorerFactoryTests.FLOAT_ARRAY_RANDOM_FUNC - ); + testRandomScorerImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, FLOAT_ARRAY_RANDOM_FUNC); } public void testRandomScorerMax() throws IOException { - testRandomScorerImpl( - MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, - org.elasticsearch.simdvec.Int7SQVectorScorerFactoryTests.FLOAT_ARRAY_MAX_FUNC - ); + testRandomScorerImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, FLOAT_ARRAY_MAX_FUNC); } public void testRandomScorerChunkSizeSmall() throws IOException { @@ -249,7 +243,7 @@ public void testRandomScorerChunkSizeSmall() throws IOException { } void testRandomScorerImpl(long maxChunkSize, IntFunction floatArraySupplier) throws IOException { - assumeTrue("scorer only supported on JDK 22+", Runtime.version().feature() >= 22); + assumeTrue("scorer only supported on JDK 22+", SUPPORTS_HEAP_SEGMENTS); assumeTrue(notSupportedMsg(), supported()); var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); @@ -297,7 +291,7 @@ void testRandomScorerImpl(long maxChunkSize, IntFunction floatArraySupp corrections[idx0].additionalCorrection(), corrections[idx0].quantizedComponentSum() ).get(); - assertFloatEquals(expected, scorer.score(idx1), 1e-6f); + assertFloatEquals(expected, scorer.score(idx1), DELTA); } } } @@ -353,7 +347,7 @@ void testRandomSliceImpl(int dims, long maxChunkSize, int initialPadding, IntFun var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).get(); var scorer = supplier.scorer(); scorer.setScoringOrdinal(idx1); - assertFloatEquals(expected, scorer.score(idx0), 1e-6f); + assertFloatEquals(expected, scorer.score(idx0), DELTA); } } } @@ -401,7 +395,7 @@ public void testLarge() throws IOException { var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).get(); var scorer = supplier.scorer(); scorer.setScoringOrdinal(idx1); - assertFloatEquals(expected, scorer.score(idx0), 1e-6f); + assertFloatEquals(expected, scorer.score(idx0), DELTA); } } 
} @@ -442,7 +436,7 @@ public void testDatasetGreaterThanChunkSize() throws IOException { var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).get(); var scorer = supplier.scorer(); scorer.setScoringOrdinal(idx1); - assertFloatEquals(expected, scorer.score(idx0), 1e-6f); + assertFloatEquals(expected, scorer.score(idx0), DELTA); } } } @@ -489,9 +483,7 @@ public void testBulk() throws IOException { var testScorer = supplier.scorer(); testScorer.setScoringOrdinal(idx0); testScorer.bulkScore(nodes, scores, nodes.length); - // applying the corrections in even a slightly different order can impact the score - // account for this during bulk scoring - assertFloatArrayEquals(expected, scores, 2e-5f); + assertFloatArrayEquals(expected, scores, BULK_DELTA); } } } @@ -538,7 +530,7 @@ public void testBulkWithDatasetGreaterThanChunkSize() throws IOException { var testScorer = supplier.scorer(); testScorer.setScoringOrdinal(idx0); testScorer.bulkScore(nodes, scores, nodes.length); - assertFloatArrayEquals(expected, scores, 1e-6f); + assertFloatArrayEquals(expected, scores, DELTA); } } } @@ -617,7 +609,7 @@ static class ScoreCallable implements Callable> { public Optional call() { try { for (int i = 0; i < 100; i++) { - assertFloatEquals(expectedScore, scorer.score(ord), 1e-6f); + assertFloatEquals(expectedScore, scorer.score(ord), DELTA); } } catch (Throwable t) { return Optional.of(t); diff --git a/libs/simdvec/src/testFixtures/java/org/elasticsearch/simdvec/internal/vectorization/VectorScorerTestUtils.java b/libs/simdvec/src/testFixtures/java/org/elasticsearch/simdvec/internal/vectorization/VectorScorerTestUtils.java index 5b6b52a102c0a..e1b07ab803cf5 100644 --- a/libs/simdvec/src/testFixtures/java/org/elasticsearch/simdvec/internal/vectorization/VectorScorerTestUtils.java +++ b/libs/simdvec/src/testFixtures/java/org/elasticsearch/simdvec/internal/vectorization/VectorScorerTestUtils.java @@ -9,15 +9,21 @@ package 
org.elasticsearch.simdvec.internal.vectorization; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat; +import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.index.codec.vectors.BQVectorUtils; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat; import org.elasticsearch.simdvec.ESVectorUtil; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Random; public class VectorScorerTestUtils { @@ -213,53 +219,153 @@ public static void randomVector(Random random, float[] vector, VectorSimilarityF public static void randomInt4Bytes(Random random, byte[] bytes) { for (int i = 0, len = bytes.length; i < len;) { - bytes[i++] = (byte) random.nextInt(0, 0x10); + bytes[i++] = (byte) random.nextInt(0, 16); } } - /** - * Packs unpacked int4 values (one value per byte) into Lucene nibble-packed format (two values per byte) - * written by {@code Lucene104ScalarQuantizedVectorsWriter} (ScalarEncoding#PACKED_NIBBLE format). - *

- * The unpacked input comes from {@link OptimizedScalarQuantizer#scalarQuantize}, which quantizes a float - * vector into one byte per element in natural order: unpacked = [v0, v1, v2, ..., v_{N-1}] where N = dims. - *

- * The packed format pairs elements that are packedLength ({@param unpacked} length / 2) apart. For example, - * with dims=8, unpacked.length is 8 and packedLength is 4: - * - {@code packed[0] = (v0 << 4) | v4} - * - {@code packed[1] = (v1 << 4) | v5} - * - {@code packed[2] = (v2 << 4) | v6} - * - {@code packed[3] = (v3 << 4) | v7} - *

- * Or, visually, - * UNPACKED (8 bytes, natural vector order, one 4-bit value per byte): - * index: 0 1 2 3 4 5 6 7 - * [v0] [v1] [v2] [v3] [v4] [v5] [v6] [v7] - * PACKED (4 bytes, on disk, two 4-bit values per byte): - * index: 0 1 2 3 - * [v0 | v4] [v1 | v5] [v2 | v6] [v3 | v7] - * hi lo hi lo hi lo hi lo - * 7..4 3..0 7..4 3..0 7..4 3..0 7..4 3..0 - */ - public static byte[] packNibbles(byte[] unpacked) { - int packedLength = unpacked.length / 2; - byte[] packed = new byte[packedLength]; - for (int i = 0; i < packedLength; i++) { - packed[i] = (byte) ((unpacked[i] << 4) | (unpacked[i + packedLength] & 0x0F)); - } - return packed; + public static void writePackedVectorWithCorrection( + IndexOutput out, + byte[] packed, + org.apache.lucene.util.quantization.OptimizedScalarQuantizer.QuantizationResult correction + ) throws IOException { + out.writeBytes(packed, 0, packed.length); + out.writeInt(Float.floatToIntBits(correction.lowerInterval())); + out.writeInt(Float.floatToIntBits(correction.upperInterval())); + out.writeInt(Float.floatToIntBits(correction.additionalCorrection())); + out.writeInt(correction.quantizedComponentSum()); } /** - * Unpacks "nibble-packed" int4 values (two values per byte) into a byte[] (one value per byte) + * Creates a disk-backed {@link QuantizedByteVectorValues} for int4 (PACKED_NIBBLE) vectors. + * The data must have been written via {@link #writePackedVectorWithCorrection}. 
*/ - public static byte[] unpackNibbles(byte[] packed, int dims) { - byte[] unpacked = new byte[dims]; - int packedLen = packed.length; - for (int i = 0; i < packedLen; i++) { - unpacked[i] = (byte) ((packed[i] & 0xFF) >> 4); - unpacked[i + packedLen] = (byte) (packed[i] & 0x0F); + public static QuantizedByteVectorValues createDenseInt4VectorValues( + int dims, + int size, + float[] centroid, + float centroidDp, + IndexInput in, + VectorSimilarityFunction sim + ) throws IOException { + var slice = in.slice("values", 0, in.length()); + return new DenseOffHeapInt4VectorValues(dims, size, sim, slice, centroid, centroidDp); + } + + @SuppressForbidden(reason = "require usage of OptimizedScalarQuantizer") + private static org.apache.lucene.util.quantization.OptimizedScalarQuantizer luceneScalarQuantizer(VectorSimilarityFunction sim) { + return new org.apache.lucene.util.quantization.OptimizedScalarQuantizer(sim); + } + + private static class DenseOffHeapInt4VectorValues extends QuantizedByteVectorValues { + final int dimension; + final int size; + final VectorSimilarityFunction similarityFunction; + + final IndexInput slice; + final byte[] vectorValue; + final ByteBuffer byteBuffer; + final int byteSize; + private int lastOrd = -1; + final float[] correctiveValues; + int quantizedComponentSum; + final float[] centroid; + final float centroidDp; + + DenseOffHeapInt4VectorValues( + int dimension, + int size, + VectorSimilarityFunction similarityFunction, + IndexInput slice, + float[] centroid, + float centroidDp + ) { + this.dimension = dimension; + this.size = size; + this.similarityFunction = similarityFunction; + this.slice = slice; + this.centroid = centroid; + this.centroidDp = centroidDp; + this.correctiveValues = new float[3]; + this.byteSize = dimension / 2 + (Float.BYTES * 3) + Integer.BYTES; + this.byteBuffer = ByteBuffer.allocate(dimension / 2); + this.vectorValue = byteBuffer.array(); + } + + @Override + public IndexInput getSlice() { + return slice; + } + + 
@Override + public org.apache.lucene.util.quantization.OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int vectorOrd) + throws IOException { + if (lastOrd != vectorOrd) { + slice.seek((long) vectorOrd * byteSize); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), vectorValue.length); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = slice.readInt(); + lastOrd = vectorOrd; + } + return new org.apache.lucene.util.quantization.OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], + correctiveValues[1], + correctiveValues[2], + quantizedComponentSum + ); + } + + @Override + public org.apache.lucene.util.quantization.OptimizedScalarQuantizer getQuantizer() { + return luceneScalarQuantizer(similarityFunction); + } + + @Override + public Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding getScalarEncoding() { + return Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.PACKED_NIBBLE; + } + + @Override + public float[] getCentroid() { + return centroid; + } + + @Override + public float getCentroidDP() { + return centroidDp; + } + + @Override + public VectorScorer scorer(float[] query) { + assert false; + return null; + } + + @Override + public byte[] vectorValue(int ord) throws IOException { + if (lastOrd == ord) { + return vectorValue; + } + slice.seek((long) ord * byteSize); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), vectorValue.length); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = slice.readInt(); + lastOrd = ord; + return vectorValue; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; + } + + @Override + public QuantizedByteVectorValues copy() throws IOException { + return new DenseOffHeapInt4VectorValues(dimension, size, similarityFunction, slice.clone(), centroid, centroidDp); } - return unpacked; } }