diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerOSQBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerOSQBenchmark.java index bb0ddf3fc554a..fb851c3fdc88d 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerOSQBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerOSQBenchmark.java @@ -107,7 +107,6 @@ public enum VectorImplementation { IndexInput input; float[] scratchScores; - float[] corrections; record VectorData( VectorScorerTestUtils.OSQVectorData[] indexVectors, @@ -194,7 +193,6 @@ void setup(VectorData data) throws IOException { .newESNextOSQVectorsScorer(input, (byte) queryBits, (byte) docBits, dims, data.binaryIndexLength, BULK_SIZE); }; scratchScores = new float[BULK_SIZE]; - corrections = new float[3]; } Path createTempDirectory(String name) throws IOException { @@ -210,26 +208,37 @@ public void teardown() throws IOException { @Benchmark public float[] score() throws IOException { float[] results = new float[NUM_QUERIES * NUM_VECTORS]; + + float[] lowerIntervals = new float[BULK_SIZE]; + float[] upperIntervals = new float[BULK_SIZE]; + int[] sums = new int[BULK_SIZE]; + float[] additional = new float[BULK_SIZE]; + for (int j = 0; j < NUM_QUERIES; j++) { input.seek(0); - for (int i = 0; i < NUM_VECTORS; i++) { - float qDist = scorer.quantizeScore(binaryQueries[j].quantizedVector()); - input.readFloats(corrections, 0, corrections.length); - int addition = Short.toUnsignedInt(input.readShort()); - float score = scorer.score( - binaryQueries[j].lowerInterval(), - binaryQueries[j].upperInterval(), - binaryQueries[j].quantizedComponentSum(), - binaryQueries[j].additionalCorrection(), - similarityFunction, - centroidDp, - corrections[0], - corrections[1], - addition, - corrections[2], - qDist - ); - results[j * NUM_VECTORS + i] = score; + for (int i = 0; i < NUM_VECTORS; i += BULK_SIZE) { + scorer.quantizeScoreBulk(binaryQueries[j].quantizedVector(), BULK_SIZE, scratchScores); + input.readFloats(lowerIntervals, 0, BULK_SIZE); + input.readFloats(upperIntervals, 0, BULK_SIZE); + input.readInts(sums, 0, BULK_SIZE); + input.readFloats(additional, 0, BULK_SIZE); + + for (int b = 0; b < BULK_SIZE; b++) { + float score = scorer.score( + binaryQueries[j].lowerInterval(), + binaryQueries[j].upperInterval(), + binaryQueries[j].quantizedComponentSum(), + binaryQueries[j].additionalCorrection(), + similarityFunction, + centroidDp, + lowerIntervals[b], + upperIntervals[b], + sums[b], + additional[b], + scratchScores[b] + ); + results[j * NUM_VECTORS + i + b] = score; + } } } return results;