diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerOSQBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerOSQBenchmark.java index fb851c3fdc88d..86cecb281f091 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerOSQBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerOSQBenchmark.java @@ -80,7 +80,7 @@ public enum VectorImplementation { @Param({ "384", "768", "1024" }) public int dims; - @Param({ "1", "2", "4" }) + @Param({ "1", "2", "4", "7" }) public byte bits; @Param @@ -137,11 +137,12 @@ static VectorData generateRandomVectorData( } int binaryQueryLength = ESNextDiskBBQVectorsFormat.QuantEncoding.fromBits(bits).getQueryPackedLength(dims); + byte queryBits = bits == 7 ? (byte) 7 : (byte) 4; VectorScorerTestUtils.OSQVectorData[] queryVectors = new VectorScorerTestUtils.OSQVectorData[numVectors]; var query = new float[dims]; for (int i = 0; i < numVectors; i++) { randomVector(random, query, similarityFunction); - queryVectors[i] = createOSQQueryData(query, centroid, quantizer, dims, (byte) 4, binaryQueryLength); + queryVectors[i] = createOSQQueryData(query, centroid, quantizer, dims, queryBits, binaryQueryLength); } return new VectorData(indexVectors, queryVectors, binaryIndexLength, VectorUtil.dotProduct(centroid, centroid)); @@ -185,6 +186,10 @@ void setup(VectorData data) throws IOException { docBits = 4; yield 4; } + case 7 -> { + docBits = 7; + yield 7; + } default -> throw new IllegalArgumentException("Unsupported bits: " + bits); }; scorer = switch (implementation) { diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESNextOSQVectorsScorer.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESNextOSQVectorsScorer.java index cd3318bdaf0c7..3ace9a12117e3 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESNextOSQVectorsScorer.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESNextOSQVectorsScorer.java @@ -46,9 +46,14 @@ public class ESNextOSQVectorsScorer { protected final float[] upperIntervals; protected final int[] targetComponentSums; protected final float[] additionalCorrections; + private final byte[] scratch; public ESNextOSQVectorsScorer(IndexInput in, byte queryBits, byte indexBits, int dimensions, int dataLength, int bulkSize) { - if (queryBits != 4 || (indexBits != 1 && indexBits != 2 && indexBits != 4)) { + if (indexBits == 7) { + if (queryBits != 7) { + throw new IllegalArgumentException("Only symmetric 7-bit query supported for 7-bit index"); + } + } else if (queryBits != 4 || (indexBits != 1 && indexBits != 2 && indexBits != 4)) { throw new IllegalArgumentException("Only asymmetric 4-bit query and 1, 2 or 4-bit index supported"); } this.in = in; @@ -61,6 +66,7 @@ public ESNextOSQVectorsScorer(IndexInput in, byte queryBits, byte indexBits, int this.targetComponentSums = new int[bulkSize]; this.additionalCorrections = new float[bulkSize]; this.bulkSize = bulkSize; + this.scratch = indexBits == 7 ? new byte[dimensions] : null; } public ESNextOSQVectorsScorer(IndexInput in, byte queryBits, byte indexBits, int dimensions, int dataLength) { @@ -88,9 +94,17 @@ public long quantizeScore(byte[] q) throws IOException { return quantized4BitScoreSymmetric(q); } } + if (indexBits == 7) { + return quantized7BitScore(q); + } throw new IllegalArgumentException("Only 1-bit index supported"); } + private long quantized7BitScore(byte[] q) throws IOException { + in.readBytes(scratch, 0, dimensions); + return VectorUtil.dotProduct(scratch, q); + } + private long quantized4BitScoreSymmetric(byte[] q) throws IOException { assert q.length == length : "length mismatch q " + q.length + " vs " + length; assert length % 4 == 0 : "length must be multiple of 4 for 4-bit index length: " + length + " dimensions: " + dimensions; @@ -174,6 +188,12 @@ public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOExce } throw new IllegalArgumentException("Only symmetric 4-bit query supported for 4-bit index"); } + if (indexBits == 7) { + for (int i = 0; i < count; i++) { + scores[i] = quantizeScore(q); + } + return; + } } /** diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSD7Q7ESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSD7Q7ESNextOSQVectorsScorer.java new file mode 100644 index 0000000000000..8ab4c52b1272e --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSD7Q7ESNextOSQVectorsScorer.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.simdvec.internal.vectorization; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.IndexInput; +import org.elasticsearch.simdvec.internal.MemorySegmentES92Int7VectorsScorer; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; + +/** Vectorized scorer for 7-bit symmetric quantized vectors stored as a {@link MemorySegment}. */ +final class MSD7Q7ESNextOSQVectorsScorer extends MemorySegmentESNextOSQVectorsScorer.MemorySegmentScorer { + + private final MemorySegmentES92Int7VectorsScorer int7Scorer; + + MSD7Q7ESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, int bulkSize, MemorySegment memorySegment) { + super(in, dimensions, dataLength, bulkSize, memorySegment); + this.int7Scorer = new MemorySegmentES92Int7VectorsScorer(in, dimensions, bulkSize, memorySegment); + } + + @Override + long quantizeScore(byte[] q) throws IOException { + return int7Scorer.int7DotProduct(q); + } + + @Override + boolean quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException { + int7Scorer.int7DotProductBulk(q, count, scores); + return true; + } + + @Override + float scoreBulk( + byte[] q, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores, + int bulkSize + ) throws IOException { + int7Scorer.scoreBulk( + q, + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + bulkSize + ); + float maxScore = Float.NEGATIVE_INFINITY; + for (int i = 0; i < bulkSize; i++) { + if (scores[i] > maxScore) { + maxScore = scores[i]; + } + } + return maxScore; + } +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java index 14e53b4a88fde..9fcf70ebb7c92 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java @@ -49,8 +49,10 @@ public MemorySegmentESNextOSQVectorsScorer( this.scorer = new MSInt4SymmetricESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize, memorySegment); } else if (queryBits == 4 && indexBits == 2) { this.scorer = new MSDibitToInt4ESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize, memorySegment); + } else if (queryBits == 7 && indexBits == 7) { + this.scorer = new MSD7Q7ESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize, memorySegment); } else { - throw new IllegalArgumentException("Only asymmetric 4-bit query and 1-bit index supported"); + throw new IllegalArgumentException("Unsupported query/index bits combination: " + queryBits + "/" + indexBits); } } @@ -147,7 +149,7 @@ public float scoreBulk( } abstract static sealed class MemorySegmentScorer permits MSBitToInt4ESNextOSQVectorsScorer, MSDibitToInt4ESNextOSQVectorsScorer, - MSInt4SymmetricESNextOSQVectorsScorer { + MSInt4SymmetricESNextOSQVectorsScorer, MSD7Q7ESNextOSQVectorsScorer { // TODO: split Panama and Native implementations static final boolean NATIVE_SUPPORTED = NativeAccess.instance().getVectorSimilarityFunctions().isPresent(); diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java index 7bc3b504281bd..a87e3626210ce 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java @@ -47,8 +47,7 @@ public ESNextOSQVectorsScorer newESNextOSQVectorsScorer( unwrappedInput = MemorySegmentAccessInputAccess.unwrap(unwrappedInput); if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS && unwrappedInput instanceof MemorySegmentAccessInput msai - && queryBits == 4 - && (indexBits == 1 || indexBits == 2 || indexBits == 4)) { + && ((queryBits == 4 && (indexBits == 1 || indexBits == 2 || indexBits == 4)) || (queryBits == 7 && indexBits == 7))) { MemorySegment ms = msai.segmentSliceOrNull(0, unwrappedInput.length()); if (ms != null) { return new MemorySegmentESNextOSQVectorsScorer(unwrappedInput, queryBits, indexBits, dimension, dataLength, bulkSize, ms); diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ESNextOSQVectorsScorerTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ESNextOSQVectorsScorerTests.java index 6c7cc576097c6..eb8bd664d0196 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ESNextOSQVectorsScorerTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ESNextOSQVectorsScorerTests.java @@ -40,8 +40,8 @@ public class ESNextOSQVectorsScorerTests extends BaseVectorizationTests { private final DirectoryType directoryType; private final byte indexBits; + private final byte queryBits; private final VectorSimilarityFunction similarityFunction; - private static final byte queryBits = 4; public enum DirectoryType { NIOFS, @@ -49,9 +49,15 @@ public enum DirectoryType { SNAP } - public ESNextOSQVectorsScorerTests(DirectoryType directoryType, byte indexBits, VectorSimilarityFunction similarityFunction) { + public ESNextOSQVectorsScorerTests( + DirectoryType directoryType, + byte indexBits, + byte queryBits, + VectorSimilarityFunction similarityFunction + ) { this.directoryType = directoryType; this.indexBits = indexBits; + this.queryBits = queryBits; this.similarityFunction = similarityFunction; } @@ -63,25 +69,26 @@ public void testQuantizeScore() throws Exception { final int length = ESNextDiskBBQVectorsFormat.QuantEncoding.fromBits(indexBits).getDocPackedLength(dimensions); final byte[] vector = new byte[length]; - final int queryBytes = length * (queryBits / indexBits); try (Directory dir = newParametrizedDirectory()) { try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) { for (int i = 0; i < numVectors; i++) { random().nextBytes(vector); + if (indexBits == 7) clampTo7Bit(vector, dimensions); out.writeBytes(vector, 0, length); } CodecUtil.writeFooter(out); } final byte[] query = new byte[queryBytes]; random().nextBytes(query); + if (indexBits == 7) clampTo7Bit(query, dimensions); try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) { // Work on a slice that has just the right number of bytes to make the test fail with an // index-out-of-bounds in case the implementation reads more than the allowed number of // padding bytes. final IndexInput slice = in.slice("test", 0, (long) length * numVectors); - final var defaultScorer = defaultProvider().newESNextOSQVectorsScorer( + final ESNextOSQVectorsScorer defaultScorer = defaultProvider().newESNextOSQVectorsScorer( slice, queryBits, indexBits, @@ -89,7 +96,7 @@ public void testQuantizeScore() throws Exception { length, ESNextOSQVectorsScorer.BULK_SIZE ); - final var panamaScorer = maybePanamaProvider().newESNextOSQVectorsScorer( + final ESNextOSQVectorsScorer panamaScorer = maybePanamaProvider().newESNextOSQVectorsScorer( in, queryBits, indexBits, @@ -135,7 +142,9 @@ public void testScore() throws Exception { final float[] query = new float[dimensions]; randomVector(random(), query, similarityFunction); - final int queryVectorPackedLengthInBytes = indexVectorPackedLengthInBytes * (queryBits / indexBits); + final int queryVectorPackedLengthInBytes = indexBits == 7 + ? ESNextDiskBBQVectorsFormat.QuantEncoding.fromBits(indexBits).getQueryPackedLength(dimensions) + : indexVectorPackedLengthInBytes * (queryBits / indexBits); var queryData = createOSQQueryData(query, centroid, quantizer, dimensions, queryBits, queryVectorPackedLengthInBytes); final float centroidDp = VectorUtil.dotProduct(centroid, centroid); @@ -250,7 +259,9 @@ private void doTestScoreBulk(int bulkSize) throws Exception { } final float[] query = new float[dimensions]; randomVector(random(), query, similarityFunction); - final int queryVectorPackedLengthInBytes = indexVectorPackedLengthInBytes * (queryBits / indexBits); + final int queryVectorPackedLengthInBytes = indexBits == 7 + ? ESNextDiskBBQVectorsFormat.QuantEncoding.fromBits(indexBits).getQueryPackedLength(dimensions) + : indexVectorPackedLengthInBytes * (queryBits / indexBits); var queryData = createOSQQueryData(query, centroid, quantizer, dimensions, queryBits, queryVectorPackedLengthInBytes); final float centroidDp = VectorUtil.dotProduct(centroid, centroid); @@ -334,6 +345,7 @@ public void testScoreBulkWithNegativeInfinityScore() throws Exception { byte[] vector = new byte[length]; for (int i = 0; i < bulkSize; i++) { random().nextBytes(vector); + if (indexBits == 7) clampTo7Bit(vector, dimensions); out.writeBytes(vector, 0, length); } // All-zero corrections: zero bytes are interpreted identically regardless of byte order @@ -344,6 +356,7 @@ public void testScoreBulkWithNegativeInfinityScore() throws Exception { byte[] query = new byte[queryBytes]; random().nextBytes(query); + if (indexBits == 7) clampTo7Bit(query, dimensions); float[] scoresDefault = new float[bulkSize]; float[] scoresPanama = new float[bulkSize]; @@ -402,6 +415,12 @@ public void testScoreBulkWithNegativeInfinityScore() throws Exception { } } + private static void clampTo7Bit(byte[] vector, int dimensions) { + for (int i = 0; i < dimensions; i++) { + vector[i] = (byte) (vector[i] & 0x7F); + } + } + private Directory newParametrizedDirectory() throws IOException { return switch (directoryType) { case NIOFS -> new NIOFSDirectory(createTempDir()); @@ -412,8 +431,14 @@ private Directory newParametrizedDirectory() throws IOException { @ParametersFactory public static Iterable parametersFactory() { - return () -> Stream.of((byte) 1, (byte) 2, (byte) 4) - .flatMap(i -> Arrays.stream(DirectoryType.values()).map(f -> List.of(f, i))) + var bitCombinations = List.of( + List.of((byte) 1, (byte) 4), + List.of((byte) 2, (byte) 4), + List.of((byte) 4, (byte) 4), + List.of((byte) 7, (byte) 7) + ); + return () -> bitCombinations.stream() + .flatMap(bits -> Arrays.stream(DirectoryType.values()).map(d -> List.of(d, bits.get(0), bits.get(1)))) .flatMap(p -> Arrays.stream(VectorSimilarityFunction.values()).map(f -> Stream.concat(p.stream(), Stream.of(f)).toArray())) .iterator(); } diff --git a/libs/simdvec/src/testFixtures/java/org/elasticsearch/simdvec/internal/vectorization/VectorScorerTestUtils.java b/libs/simdvec/src/testFixtures/java/org/elasticsearch/simdvec/internal/vectorization/VectorScorerTestUtils.java index 7ffcb22d6ddab..a98a11c69869c 100644 --- a/libs/simdvec/src/testFixtures/java/org/elasticsearch/simdvec/internal/vectorization/VectorScorerTestUtils.java +++ b/libs/simdvec/src/testFixtures/java/org/elasticsearch/simdvec/internal/vectorization/VectorScorerTestUtils.java @@ -14,7 +14,6 @@ import org.apache.lucene.util.VectorUtil; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat; -import org.elasticsearch.simdvec.ESVectorUtil; import java.io.IOException; import java.util.Random; @@ -79,7 +78,7 @@ public static OSQVectorData createOSQQueryData( centroid ); final byte[] quantizeQuery = new byte[queryVectorPackedLengthInBytes]; - ESVectorUtil.transposeHalfByte(scratch, quantizeQuery); + ESNextDiskBBQVectorsFormat.QuantEncoding.fromBits(queryBits).packQuery(scratch, quantizeQuery); return new OSQVectorData( quantizeQuery, diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java index 7d725d2222c48..7452eb328fd8c 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java @@ -404,9 +404,9 @@ public static void main(String[] args) throws Exception { private static void checkQuantizeBits(TestConfiguration args) { switch (args.indexType()) { case IVF: - if (args.quantizeBits() == null || !Set.of(1, 2, 4).contains(args.quantizeBits())) { + if (args.quantizeBits() == null || !Set.of(1, 2, 4, 7).contains(args.quantizeBits())) { throw new IllegalArgumentException( - "IVF index type only supports 1, 2 or 4 bits quantization, but got: " + args.quantizeBits() + "IVF index type only supports 1, 2, 4 or 7 bits quantization, but got: " + args.quantizeBits() ); } break; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/DiskBBQBulkWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/DiskBBQBulkWriter.java index eae350812de97..a3a7ab4663423 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/DiskBBQBulkWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/DiskBBQBulkWriter.java @@ -211,6 +211,9 @@ public void writeVectors(QuantizedVectorValues qvv, CheckedIntConsumer ONE_BIT_4BIT_QUERY; case 2 -> TWO_BIT_4BIT_QUERY; case 4 -> FOUR_BIT_SYMMETRIC; + case 7 -> SEVEN_BIT_SYMMETRIC; default -> throw new IllegalArgumentException("Unsupported bits: " + bits); }; } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/DiskBBQBulkWriterTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/DiskBBQBulkWriterTests.java new file mode 100644 index 0000000000000..556bbaba10ca2 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/DiskBBQBulkWriterTests.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.diskbbq; + +import org.apache.lucene.store.ByteBuffersDataOutput; +import org.apache.lucene.store.ByteBuffersIndexOutput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.elasticsearch.common.lucene.store.ByteArrayIndexInput; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; + +public class DiskBBQBulkWriterTests extends ESTestCase { + + private static final Integer[] VALID_BIT_SIZES = { 1, 2, 4, 7 }; + private static final Integer[] INVALID_BIT_SIZES = { 0, 3, 5, 6, 8, 16 }; + + public void testLargeBitEncodingWritesIntComponentSum() throws Exception { + assertLargeBitEncoding(7); + } + + public void testFromBitSizeValidValues() throws IOException { + int bits = randomFrom(VALID_BIT_SIZES); + try (IndexOutput out = new ByteBuffersIndexOutput(new ByteBuffersDataOutput(), "test", "test")) { + DiskBBQBulkWriter writer = DiskBBQBulkWriter.fromBitSize(bits, 32, out); + assertNotNull(writer); + } + } + + public void testFromBitSizeInvalidValues() throws IOException { + int bits = randomFrom(INVALID_BIT_SIZES); + try (IndexOutput out = new ByteBuffersIndexOutput(new ByteBuffersDataOutput(), "test", "test")) { + expectThrows(IllegalArgumentException.class, () -> DiskBBQBulkWriter.fromBitSize(bits, 32, out)); + } + } + + private void assertLargeBitEncoding(int bits) throws IOException { + int dimensions = randomIntBetween(2, 64); + int bulkSize = randomIntBetween(2, 16); + int numVectors = bulkSize + randomIntBetween(1, bulkSize - 1); // guarantees a bulk block + tail + byte[][] vectors = new byte[numVectors][dimensions]; + OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[numVectors]; + for (int i = 0; i < numVectors; i++) { + random().nextBytes(vectors[i]); + corrections[i] = new OptimizedScalarQuantizer.QuantizationResult( + randomFloat(), + randomFloat(), + randomFloat(), + randomIntBetween(0, 200_000) + ); + } + + ByteBuffersDataOutput buffer = new ByteBuffersDataOutput(); + try (IndexOutput out = new ByteBuffersIndexOutput(buffer, "diskbbq", "diskbbq")) { + DiskBBQBulkWriter writer = DiskBBQBulkWriter.fromBitSize(bits, bulkSize, out); + writer.writeVectors(new TestQuantizedVectorValues(vectors, corrections), null); + } + + try (IndexInput in = new ByteArrayIndexInput("diskbbq", buffer.toArrayCopy())) { + // bulk block: vectors then corrections (lower[], upper[], componentSum[], additional[]) + for (int i = 0; i < bulkSize; i++) { + assertVectorEquals(in, vectors[i], dimensions); + } + for (int i = 0; i < bulkSize; i++) { + assertEquals(corrections[i].lowerInterval(), readFloat(in), 0.0f); + } + for (int i = 0; i < bulkSize; i++) { + assertEquals(corrections[i].upperInterval(), readFloat(in), 0.0f); + } + for (int i = 0; i < bulkSize; i++) { + assertEquals(corrections[i].quantizedComponentSum(), in.readInt()); + } + for (int i = 0; i < bulkSize; i++) { + assertEquals(corrections[i].additionalCorrection(), readFloat(in), 0.0f); + } + // tail: each vector followed by its own correction (lower, upper, additional, componentSum) + for (int i = bulkSize; i < numVectors; i++) { + assertVectorEquals(in, vectors[i], dimensions); + assertEquals(corrections[i].lowerInterval(), readFloat(in), 0.0f); + assertEquals(corrections[i].upperInterval(), readFloat(in), 0.0f); + assertEquals(corrections[i].additionalCorrection(), readFloat(in), 0.0f); + assertEquals(corrections[i].quantizedComponentSum(), in.readInt()); + } + } + } + + private static float readFloat(IndexInput in) throws IOException { + return Float.intBitsToFloat(in.readInt()); + } + + private static void assertVectorEquals(IndexInput in, byte[] expected, int dimensions) throws IOException { + byte[] actual = new byte[dimensions]; + in.readBytes(actual, 0, dimensions); + assertArrayEquals(expected, actual); + } + + private static class TestQuantizedVectorValues implements QuantizedVectorValues { + private final byte[][] vectors; + private final OptimizedScalarQuantizer.QuantizationResult[] corrections; + private int index = -1; + + TestQuantizedVectorValues(byte[][] vectors, OptimizedScalarQuantizer.QuantizationResult[] corrections) { + this.vectors = vectors; + this.corrections = corrections; + } + + @Override + public int count() { + return vectors.length; + } + + @Override + public byte[] next() { + index++; + return vectors[index]; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrections() { + return corrections[index]; + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/QuantEncodingTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/QuantEncodingTests.java index b4f3d57f85c7d..b3fc2e3aabef3 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/QuantEncodingTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/QuantEncodingTests.java @@ -71,4 +71,13 @@ public void testHalfByteAndNibblesPackSize() { assertEquals(8, encoding.getQueryPackedLength(16)); assertEquals(8, encoding.getQueryPackedLength(16)); } + + public void testSevenBitPackSize() { + ESNextDiskBBQVectorsFormat.QuantEncoding encoding = ESNextDiskBBQVectorsFormat.QuantEncoding.SEVEN_BIT_SYMMETRIC; + assertEquals(3, encoding.getDocPackedLength(3)); + assertEquals(3, encoding.getQueryPackedLength(3)); + assertEquals(8, encoding.getDocPackedLength(8)); + assertEquals(8, encoding.getQueryPackedLength(8)); + } + }