diff --git a/docs/changelog/141598.yaml b/docs/changelog/141598.yaml new file mode 100644 index 0000000000000..b8db34e54bfd4 --- /dev/null +++ b/docs/changelog/141598.yaml @@ -0,0 +1,5 @@ +area: Vector Search +issues: [] +pr: 141598 +summary: DiskBBQ - Always block encode doc vectors +type: enhancement diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESNextOSQVectorsScorer.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESNextOSQVectorsScorer.java index 1145a0367a4af..cd3318bdaf0c7 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESNextOSQVectorsScorer.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESNextOSQVectorsScorer.java @@ -236,6 +236,31 @@ public float scoreBulk( float centroidDp, float[] scores ) throws IOException { + return scoreBulk( + q, + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + bulkSize + ); + } + + public float scoreBulk( + byte[] q, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores, + int bulkSize + ) throws IOException { + assert bulkSize <= this.bulkSize : "supplied bulkSize > this scorer's bulkSize"; quantizeScoreBulk(q, bulkSize, scores); in.readFloats(lowerIntervals, 0, bulkSize); in.readFloats(upperIntervals, 0, bulkSize); diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSBitToInt4ESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSBitToInt4ESNextOSQVectorsScorer.java index 5e9f7c6958e07..ed61d061b3700 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSBitToInt4ESNextOSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSBitToInt4ESNextOSQVectorsScorer.java @@ -392,7 +392,8 @@ public float scoreBulk( float queryAdditionalCorrection, VectorSimilarityFunction similarityFunction, float centroidDp, - float[] scores + float[] scores, + int bulkSize ) throws IOException { assert q.length == length * 4; // 128 / 8 == 16 @@ -414,7 +415,8 @@ public float scoreBulk( queryAdditionalCorrection, similarityFunction, centroidDp, - scores + scores, + bulkSize ); } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { return score128Bulk( @@ -424,7 +426,8 @@ public float scoreBulk( queryAdditionalCorrection, similarityFunction, centroidDp, - scores + scores, + bulkSize ); } } @@ -439,7 +442,8 @@ private float score128Bulk( float queryAdditionalCorrection, VectorSimilarityFunction similarityFunction, float centroidDp, - float[] scores + float[] scores, + int bulkSize ) throws IOException { int limit = FLOAT_SPECIES_128.loopBound(bulkSize); int i = 0; @@ -501,6 +505,21 @@ private float score128Bulk( } } } + if (limit < bulkSize) { + maxScore = scoreTailIndividually( + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + bulkSize, + limit, + offset, + ay, + ly, + y1, + maxScore + ); + } in.seek(offset + 16L * bulkSize); return maxScore; } @@ -512,7 +531,8 @@ private float score256Bulk( float queryAdditionalCorrection, VectorSimilarityFunction similarityFunction, float centroidDp, - float[] scores + float[] scores, + int bulkSize ) throws IOException { int limit = FLOAT_SPECIES_256.loopBound(bulkSize); int i = 0; @@ -574,7 +594,82 @@ private float score256Bulk( } } } + if (limit < bulkSize) { + maxScore = scoreTailIndividually( + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + bulkSize, + limit, + offset, + ay, + ly, + y1, + maxScore + ); + } in.seek(offset + 16L * bulkSize); return maxScore; } + + float scoreTailIndividually( + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores, + int bulkSize, + int limit, + long offset, + float ay, + float ly, + float y1, + float maxScore + ) { + for (int j = limit; j < bulkSize; j++) { + float ax = memorySegment.get( + ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN), + offset + (long) j * Float.BYTES + ); + + float lx = memorySegment.get( + ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN), + offset + 4L * bulkSize + (long) j * Float.BYTES + ) - ax; + + int targetComponentSum = memorySegment.get( + ValueLayout.JAVA_INT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN), + offset + 8L * bulkSize + (long) j * Integer.BYTES + ); + + float additionalCorrection = memorySegment.get( + ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN), + offset + 12L * bulkSize + (long) j * Float.BYTES + ); + + float qcDist = scores[j]; + + float res = ax * ay * dimensions + lx * ay * targetComponentSum + ax * ly * y1 + lx * ly * qcDist; + + if (similarityFunction == EUCLIDEAN) { + res = res * -2f + additionalCorrection + queryAdditionalCorrection + 1f; + res = Math.max(1f / res, 0f); + scores[j] = res; + maxScore = Math.max(maxScore, res); + } else { + res = res + queryAdditionalCorrection + additionalCorrection - centroidDp; + + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + res = VectorUtil.scaleMaxInnerProductScore(res); + scores[j] = res; + maxScore = Math.max(maxScore, res); + } else { + res = Math.max((res + 1f) * 0.5f, 0f); + scores[j] = res; + maxScore = Math.max(maxScore, res); + } + } + } + return maxScore; + } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSDibitToInt4ESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSDibitToInt4ESNextOSQVectorsScorer.java index 95af4c419c01b..2c983b36eecb5 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSDibitToInt4ESNextOSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSDibitToInt4ESNextOSQVectorsScorer.java @@ -284,7 +284,8 @@ public float scoreBulk( float queryAdditionalCorrection, VectorSimilarityFunction similarityFunction, float centroidDp, - float[] scores + float[] scores, + int bulkSize ) throws IOException { assert q.length == length * 2; // 128 / 8 == 16 @@ -306,7 +307,8 @@ public float scoreBulk( queryAdditionalCorrection, similarityFunction, centroidDp, - scores + scores, + bulkSize ); } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { return score128Bulk( @@ -316,7 +318,8 @@ public float scoreBulk( queryAdditionalCorrection, similarityFunction, centroidDp, - scores + scores, + bulkSize ); } } @@ -331,7 +334,8 @@ private float score128Bulk( float queryAdditionalCorrection, VectorSimilarityFunction similarityFunction, float centroidDp, - float[] scores + float[] scores, + int bulkSize ) throws IOException { int limit = FLOAT_SPECIES_128.loopBound(bulkSize); int i = 0; @@ -393,6 +397,22 @@ private float score128Bulk( } } } + if (limit < bulkSize) { + // missing vectors to score + maxScore = scoreTailIndividually( + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + bulkSize, + limit, + offset, + ay, + ly, + y1, + maxScore + ); + } in.seek(offset + 16L * bulkSize); return maxScore; } @@ -404,7 +424,8 @@ private float score256Bulk( float queryAdditionalCorrection, VectorSimilarityFunction similarityFunction, float centroidDp, - float[] scores + float[] scores, + int bulkSize ) throws IOException { int limit = FLOAT_SPECIES_256.loopBound(bulkSize); int i = 0; @@ -466,7 +487,84 @@ private float score256Bulk( } } } + if (limit < bulkSize) { + // missing vectors to score + maxScore = scoreTailIndividually( + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + bulkSize, + limit, + offset, + ay, + ly, + y1, + maxScore + ); + } in.seek(offset + 16L * bulkSize); return maxScore; } + + private float scoreTailIndividually( + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores, + int bulkSize, + int limit, + long offset, + float ay, + float ly, + float y1, + float maxScore + ) { + for (int j = limit; j < bulkSize; j++) { + float ax = memorySegment.get( + ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN), + offset + (long) j * Float.BYTES + ); + + float lx = memorySegment.get( + ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN), + offset + 4L * bulkSize + (long) j * Float.BYTES + ); + lx = (lx - ax) * TWO_BIT_SCALE; + + int targetComponentSum = memorySegment.get( + ValueLayout.JAVA_INT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN), + offset + 8L * bulkSize + (long) j * Integer.BYTES + ); + + float additionalCorrection = memorySegment.get( + ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN), + offset + 12L * bulkSize + (long) j * Float.BYTES + ); + + float qcDist = scores[j]; + + float res = ax * ay * dimensions + lx * ay * targetComponentSum + ax * ly * y1 + lx * ly * qcDist; + + if (similarityFunction == EUCLIDEAN) { + res = res * -2f + additionalCorrection + queryAdditionalCorrection + 1f; + res = Math.max(1f / res, 0f); + scores[j] = res; + maxScore = Math.max(maxScore, res); + } else { + res = res + queryAdditionalCorrection + additionalCorrection - centroidDp; + + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + res = VectorUtil.scaleMaxInnerProductScore(res); + scores[j] = res; + maxScore = Math.max(maxScore, res); + } else { + res = Math.max((res + 1f) * 0.5f, 0f); + scores[j] = res; + maxScore = Math.max(maxScore, res); + } + } + } + return maxScore; + } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSInt4SymmetricESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSInt4SymmetricESNextOSQVectorsScorer.java index 00a42dd9bd469..f340417f6cefb 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSInt4SymmetricESNextOSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSInt4SymmetricESNextOSQVectorsScorer.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; import java.nio.ByteOrder; import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; @@ -234,7 +235,8 @@ public float scoreBulk( float queryAdditionalCorrection, VectorSimilarityFunction similarityFunction, float centroidDp, - float[] scores + float[] scores, + int bulkSize ) throws IOException { assert q.length == length; // 128 / 8 == 16 @@ -248,7 +250,8 @@ public float scoreBulk( queryAdditionalCorrection, similarityFunction, centroidDp, - scores + scores, + bulkSize ); } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { return score128Bulk( @@ -259,7 +262,8 @@ public float scoreBulk( queryAdditionalCorrection, similarityFunction, centroidDp, - scores + scores, + bulkSize ); } } @@ -274,7 +278,8 @@ private float score128Bulk( float queryAdditionalCorrection, VectorSimilarityFunction similarityFunction, float centroidDp, - float[] scores + float[] scores, + int bulkSize ) throws IOException { quantizeScore128Bulk(q, bulkSize, scores); int limit = FLOAT_SPECIES_128.loopBound(bulkSize); @@ -337,6 +342,22 @@ private float score128Bulk( } } } + if (limit < bulkSize) { + // missing vectors to score + maxScore = scoreTailIndividually( + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + bulkSize, + limit, + offset, + ay, + ly, + y1, + maxScore + ); + } in.seek(offset + 16L * bulkSize); return maxScore; } @@ -349,7 +370,8 @@ private float score256Bulk( float queryAdditionalCorrection, VectorSimilarityFunction similarityFunction, float centroidDp, - float[] scores + float[] scores, + int bulkSize ) throws IOException { quantizeScore256Bulk(q, bulkSize, scores); int limit = FLOAT_SPECIES_256.loopBound(bulkSize); @@ -412,7 +434,84 @@ private float score256Bulk( } } } + if (limit < bulkSize) { + // missing vectors to score + maxScore = scoreTailIndividually( + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + bulkSize, + limit, + offset, + ay, + ly, + y1, + maxScore + ); + } in.seek(offset + 16L * bulkSize); return maxScore; } + + private float scoreTailIndividually( + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores, + int bulkSize, + int limit, + long offset, + float ay, + float ly, + float y1, + float maxScore + ) { + for (int j = limit; j < bulkSize; j++) { + float ax = memorySegment.get( + ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN), + offset + (long) j * Float.BYTES + ); + + float lx = memorySegment.get( + ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN), + offset + 4L * bulkSize + (long) j * Float.BYTES + ); + lx = (lx - ax) * FOUR_BIT_SCALE; + + int targetComponentSum = memorySegment.get( + ValueLayout.JAVA_INT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN), + offset + 8L * bulkSize + (long) j * Integer.BYTES + ); + + float additionalCorrection = memorySegment.get( + ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN), + offset + 12L * bulkSize + (long) j * Float.BYTES + ); + + float qcDist = scores[j]; + + float res = ax * ay * dimensions + lx * ay * targetComponentSum + ax * ly * y1 + lx * ly * qcDist; + + if (similarityFunction == EUCLIDEAN) { + res = res * -2f + additionalCorrection + queryAdditionalCorrection + 1f; + res = Math.max(1f / res, 0f); + scores[j] = res; + maxScore = Math.max(maxScore, res); + } else { + res = res + queryAdditionalCorrection + additionalCorrection - centroidDp; + + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + res = VectorUtil.scaleMaxInnerProductScore(res); + scores[j] = res; + maxScore = Math.max(maxScore, res); + } else { + res = Math.max((res + 1f) * 0.5f, 0f); + scores[j] = res; + maxScore = Math.max(maxScore, res); + } + } + } + return maxScore; + } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java index f1cfba5f2a4a2..a5c6f1631ea2f 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java @@ -101,6 +101,45 @@ public float scoreBulk( ); } + @Override + public float scoreBulk( + byte[] q, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores, + int bulkSize + ) throws IOException { + float score = scorer.scoreBulk( + q, + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + bulkSize + ); + if (score != Float.NEGATIVE_INFINITY) { + return score; + } + return super.scoreBulk( + q, + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + bulkSize + ); + } + abstract static sealed class MemorySegmentScorer permits MSBitToInt4ESNextOSQVectorsScorer, MSDibitToInt4ESNextOSQVectorsScorer, MSInt4SymmetricESNextOSQVectorsScorer { @@ -141,7 +180,7 @@ abstract static sealed class MemorySegmentScorer permits MSBitToInt4ESNextOSQVec abstract boolean quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException; - abstract float scoreBulk( + float scoreBulk( byte[] q, float queryLowerInterval, float queryUpperInterval, @@ -150,6 +189,30 @@ abstract float scoreBulk( VectorSimilarityFunction similarityFunction, float centroidDp, float[] scores + ) throws IOException { + return scoreBulk( + q, + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + BULK_SIZE + ); + } + + abstract float scoreBulk( + byte[] q, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores, + int bulkSize ) throws IOException; } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java index bed8dc64e3625..43e7315fc0eee 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java @@ -704,10 +704,10 @@ public int resetPostingsScorer(PostingMetadata metadata) throws IOException { return vectors; } - private float scoreIndividually() throws IOException { + private float scoreIndividually(int bulkSize) throws IOException { float maxScore = Float.NEGATIVE_INFINITY; // score individually, first the quantized byte chunk - for (int j = 0; j < BULK_SIZE; j++) { + for (int j = 0; j < bulkSize; j++) { int doc = docIdsScratch[j]; if (doc != -1) { float qcDist = osqVectorsScorer.quantizeScore(queryQuantizer.getQuantizedTarget()); @@ -717,14 +717,14 @@ private float scoreIndividually() throws IOException { } } // read in all corrections - indexInput.readFloats(correctionsLower, 0, BULK_SIZE); - indexInput.readFloats(correctionsUpper, 0, BULK_SIZE); - for (int j = 0; j < BULK_SIZE; j++) { + indexInput.readFloats(correctionsLower, 0, bulkSize); + indexInput.readFloats(correctionsUpper, 0, bulkSize); + for (int j = 0; j < bulkSize; j++) { correctionsSum[j] = indexInput.readInt(); } - indexInput.readFloats(correctionsAdd, 0, BULK_SIZE); + indexInput.readFloats(correctionsAdd, 0, bulkSize); // Now apply corrections - for (int j = 0; j < BULK_SIZE; j++) { + for (int j = 0; j < bulkSize; j++) { int doc = docIdsScratch[j]; if (doc != -1) { scores[j] = osqVectorsScorer.score( @@ -748,10 +748,10 @@ private float scoreIndividually() throws IOException { return maxScore; } - private static int docToBulkScore(int[] docIds, Bits acceptDocs) { + private static int docToBulkScore(int[] docIds, Bits acceptDocs, int bulkSize) { assert acceptDocs != null : "acceptDocs must not be null"; - int docToScore = BULK_SIZE; - for (int i = 0; i < BULK_SIZE; i++) { + int docToScore = bulkSize; + for (int i = 0; i < bulkSize; i++) { if (acceptDocs.get(docIds[i]) == false) { docIds[i] = -1; docToScore--; @@ -760,8 +760,8 @@ private static int docToBulkScore(int[] docIds, Bits acceptDocs) { return docToScore; } - private void collectBulk(KnnCollector knnCollector, float[] scores) { - for (int i = 0; i < BULK_SIZE; i++) { + private void collectBulk(KnnCollector knnCollector, float[] scores, int bulkSize) { + for (int i = 0; i < bulkSize; i++) { final int doc = docIdsScratch[i]; if (doc != -1) { knnCollector.collect(doc, scores[i]); @@ -789,7 +789,7 @@ public int visit(KnnCollector knnCollector) throws IOException { for (; i < limit; i += BULK_SIZE) { // read the doc ids readDocIds(BULK_SIZE); - final int docsToBulkScore = acceptDocs == null ? BULK_SIZE : docToBulkScore(docIdsScratch, acceptDocs); + final int docsToBulkScore = acceptDocs == null ? BULK_SIZE : docToBulkScore(docIdsScratch, acceptDocs, BULK_SIZE); if (docsToBulkScore == 0) { indexInput.skipBytes(quantizedByteLength * BULK_SIZE); continue; @@ -797,7 +797,7 @@ public int visit(KnnCollector knnCollector) throws IOException { queryQuantizer.quantizeQueryIfNecessary(); final float maxScore; if (docsToBulkScore < BULK_SIZE / 2) { - maxScore = scoreIndividually(); + maxScore = scoreIndividually(BULK_SIZE); } else { maxScore = osqVectorsScorer.scoreBulk( queryQuantizer.getQuantizedTarget(), @@ -811,40 +811,39 @@ public int visit(KnnCollector knnCollector) throws IOException { ); } if (knnCollector.minCompetitiveSimilarity() < maxScore) { - collectBulk(knnCollector, scores); + collectBulk(knnCollector, scores, BULK_SIZE); } scoredDocs += docsToBulkScore; } - // process tail - // read the doc ids + // bulk process tail if (i < vectors) { - readDocIds(vectors - i); - } - int count = 0; - for (; i < vectors; i++) { - int doc = docIdsScratch[count++]; - if (acceptDocs == null || acceptDocs.get(doc)) { - queryQuantizer.quantizeQueryIfNecessary(); - float qcDist = osqVectorsScorer.quantizeScore(queryQuantizer.getQuantizedTarget()); - indexInput.readFloats(correctiveValues, 0, 3); - final int quantizedComponentSum = indexInput.readInt(); - float score = osqVectorsScorer.score( - queryQuantizer.getQueryCorrections().lowerInterval(), - queryQuantizer.getQueryCorrections().upperInterval(), - queryQuantizer.getQueryCorrections().quantizedComponentSum(), - queryQuantizer.getQueryCorrections().additionalCorrection(), - fieldInfo.getVectorSimilarityFunction(), - centroidDp, - correctiveValues[0], - correctiveValues[1], - quantizedComponentSum, - correctiveValues[2], - qcDist - ); - scoredDocs++; - knnCollector.collect(doc, score); + int tailSize = vectors - i; + readDocIds(tailSize); + final int docsToBulkScore = acceptDocs == null ? tailSize : docToBulkScore(docIdsScratch, acceptDocs, tailSize); + if (docsToBulkScore == 0) { + indexInput.skipBytes(quantizedByteLength * tailSize); } else { - indexInput.skipBytes(quantizedByteLength); + queryQuantizer.quantizeQueryIfNecessary(); + final float maxScore; + if (docsToBulkScore < tailSize / 2) { + maxScore = scoreIndividually(tailSize); + } else { + maxScore = osqVectorsScorer.scoreBulk( + queryQuantizer.getQuantizedTarget(), + queryQuantizer.getQueryCorrections().lowerInterval(), + queryQuantizer.getQueryCorrections().upperInterval(), + queryQuantizer.getQueryCorrections().quantizedComponentSum(), + queryQuantizer.getQueryCorrections().additionalCorrection(), + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + scores, + tailSize + ); + } + if (knnCollector.minCompetitiveSimilarity() < maxScore) { + collectBulk(knnCollector, scores, tailSize); + } + scoredDocs += docsToBulkScore; } } if (scoredDocs > 0) { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsWriter.java index d79dcd5b259b4..e7ded391185a1 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsWriter.java @@ -234,7 +234,7 @@ public CentroidOffsetAndLength buildAndWritePostingsLists( // write the posting lists final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT); final PackedLongValues.Builder lengths = PackedLongValues.monotonicBuilder(PackedInts.COMPACT); - DiskBBQBulkWriter bulkWriter = DiskBBQBulkWriter.fromBitSize(quantEncoding.bits(), BULK_SIZE, postingsOutput, false, true); + DiskBBQBulkWriter bulkWriter = DiskBBQBulkWriter.fromBitSize(quantEncoding.bits(), BULK_SIZE, postingsOutput, true, true); OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors( floatVectorValues, quantEncoding, @@ -386,7 +386,7 @@ public CentroidOffsetAndLength buildAndWritePostingsLists( quantEncoding, fieldInfo.getVectorDimension() ); - DiskBBQBulkWriter bulkWriter = DiskBBQBulkWriter.fromBitSize(quantEncoding.bits(), BULK_SIZE, postingsOutput, false, true); + DiskBBQBulkWriter bulkWriter = DiskBBQBulkWriter.fromBitSize(quantEncoding.bits(), BULK_SIZE, postingsOutput, true, true); final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); // write the posting lists final int[] docIds = new int[maxPostingListSize];