From 0e526cebe2c89dbc2fe6e5401fe8de1776d67461 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Fri, 12 Sep 2025 10:23:18 +0100 Subject: [PATCH 01/15] Native OSQ scoring --- .../benchmark/vector/OSQScorerBenchmark.java | 48 +- libs/native/libraries/build.gradle | 2 +- .../VectorSimilarityFunctions.java | 4 + .../nativeaccess/jdk/JdkVectorLibrary.java | 68 +++ libs/simdvec/native/src/vec/c/aarch64/vec.c | 78 +++ libs/simdvec/native/src/vec/c/amd64/vec.c | 8 + libs/simdvec/native/src/vec/headers/vec.h | 4 + .../simdvec/internal/Similarities.java | 30 + .../MemorySegmentES91OSQVectorsScorer.java | 511 +----------------- ...morySegmentES91PanamaOSQVectorsScorer.java | 434 +++++++++++++++ .../OnHeapES91OSQVectorsScorer.java | 487 +---------------- .../OnHeapES91PanamaOSQVectorsScorer.java | 426 +++++++++++++++ .../MemorySegmentES91OSQVectorsScorer.java | 89 +++ .../OnHeapES91OSQVectorsScorer.java | 91 ++++ 14 files changed, 1269 insertions(+), 1011 deletions(-) create mode 100644 libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91PanamaOSQVectorsScorer.java create mode 100644 libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/OnHeapES91PanamaOSQVectorsScorer.java create mode 100644 libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java create mode 100644 libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/vectorization/OnHeapES91OSQVectorsScorer.java diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java index 187c72b9bb347..70c19abc0a495 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java @@ -134,10 +134,10 @@ public void teardown() throws IOException { 
IOUtils.close(dirMmap, inMmap, dirNiofs, inNiofs); } - @Benchmark - public void scoreFromMemorySegmentOnlyVectorMmapScalar(Blackhole bh) throws IOException { - scoreFromMemorySegmentOnlyVector(bh, inMmap, scorerMmap); - } + // @Benchmark + // public void scoreFromMemorySegmentOnlyVectorMmapScalar(Blackhole bh) throws IOException { + // scoreFromMemorySegmentOnlyVector(bh, inMmap, scorerMmap); + // } @Benchmark @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) @@ -145,10 +145,10 @@ public void scoreFromMemorySegmentOnlyVectorMmapVect(Blackhole bh) throws IOExce scoreFromMemorySegmentOnlyVector(bh, inMmap, scorerMmap); } - @Benchmark - public void scoreFromMemorySegmentOnlyVectorNiofsScalar(Blackhole bh) throws IOException { - scoreFromMemorySegmentOnlyVector(bh, inNiofs, scorerNfios); - } + // @Benchmark + // public void scoreFromMemorySegmentOnlyVectorNiofsScalar(Blackhole bh) throws IOException { + // scoreFromMemorySegmentOnlyVector(bh, inNiofs, scorerNfios); + // } @Benchmark @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) @@ -181,10 +181,10 @@ private void scoreFromMemorySegmentOnlyVector(Blackhole bh, IndexInput in, ES91O } } - @Benchmark - public void scoreFromMemorySegmentOnlyVectorBulkMmapScalar(Blackhole bh) throws IOException { - scoreFromMemorySegmentOnlyVectorBulk(bh, inMmap, scorerMmap); - } + // @Benchmark + // public void scoreFromMemorySegmentOnlyVectorBulkMmapScalar(Blackhole bh) throws IOException { + // scoreFromMemorySegmentOnlyVectorBulk(bh, inMmap, scorerMmap); + // } @Benchmark @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) @@ -192,10 +192,10 @@ public void scoreFromMemorySegmentOnlyVectorBulkMmapVect(Blackhole bh) throws IO scoreFromMemorySegmentOnlyVectorBulk(bh, inMmap, scorerMmap); } - @Benchmark - public void scoreFromMemorySegmentOnlyVectorBulkNiofsScalar(Blackhole bh) throws IOException { - scoreFromMemorySegmentOnlyVectorBulk(bh, inNiofs, scorerNfios); - } + // @Benchmark + // public 
void scoreFromMemorySegmentOnlyVectorBulkNiofsScalar(Blackhole bh) throws IOException { + // scoreFromMemorySegmentOnlyVectorBulk(bh, inNiofs, scorerNfios); + // } @Benchmark @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) @@ -230,10 +230,10 @@ private void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh, IndexInput in, E } } - @Benchmark - public void scoreFromMemorySegmentAllBulkMmapScalar(Blackhole bh) throws IOException { - scoreFromMemorySegmentAllBulk(bh, inMmap, scorerMmap); - } + // @Benchmark + // public void scoreFromMemorySegmentAllBulkMmapScalar(Blackhole bh) throws IOException { + // scoreFromMemorySegmentAllBulk(bh, inMmap, scorerMmap); + // } @Benchmark @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) @@ -241,10 +241,10 @@ public void scoreFromMemorySegmentAllBulkMmapVect(Blackhole bh) throws IOExcepti scoreFromMemorySegmentAllBulk(bh, inMmap, scorerMmap); } - @Benchmark - public void scoreFromMemorySegmentAllBulkNiofsScalar(Blackhole bh) throws IOException { - scoreFromMemorySegmentAllBulk(bh, inNiofs, scorerNfios); - } + // @Benchmark + // public void scoreFromMemorySegmentAllBulkNiofsScalar(Blackhole bh) throws IOException { + // scoreFromMemorySegmentAllBulk(bh, inNiofs, scorerNfios); + // } @Benchmark @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) diff --git a/libs/native/libraries/build.gradle b/libs/native/libraries/build.gradle index 4d94ad6e20c73..0e11406cdc548 100644 --- a/libs/native/libraries/build.gradle +++ b/libs/native/libraries/build.gradle @@ -52,7 +52,7 @@ dependencies { libs "org.elasticsearch:zstd:${zstdVersion}:linux-aarch64" libs "org.elasticsearch:zstd:${zstdVersion}:linux-x86-64" libs "org.elasticsearch:zstd:${zstdVersion}:windows-x86-64" - libs "org.elasticsearch:vec:${vecVersion}@zip" // temporarily comment this out, if testing a locally built native lib + // libs "org.elasticsearch:vec:${vecVersion}@zip" // temporarily comment this out, if testing a locally built native 
lib } def extractLibs = tasks.register('extractLibs', Copy) { diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java index 4d3f6bc5b2c79..41177b8072c43 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java @@ -70,4 +70,8 @@ public interface VectorSimilarityFunctions { * 4-byte float32 elements. */ MethodHandle squareDistanceHandleFloat32(); + + MethodHandle int4BitDotProductHandle(); + + MethodHandle int4BitDotProductBulkHandle(); } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java index 2c429283d64ef..caf69219e6a79 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java @@ -25,6 +25,7 @@ import static java.lang.foreign.ValueLayout.ADDRESS; import static java.lang.foreign.ValueLayout.JAVA_FLOAT; import static java.lang.foreign.ValueLayout.JAVA_INT; +import static java.lang.foreign.ValueLayout.JAVA_LONG; import static org.elasticsearch.nativeaccess.jdk.LinkerHelper.downcallHandle; public final class JdkVectorLibrary implements VectorLibrary { @@ -36,6 +37,8 @@ public final class JdkVectorLibrary implements VectorLibrary { static final MethodHandle cosf32$mh; static final MethodHandle dotf32$mh; static final MethodHandle sqrf32$mh; + static final MethodHandle int4Bit$mh; + static final MethodHandle int4BitBulk$mh; public static final JdkVectorSimilarityFunctions INSTANCE; @@ -100,6 +103,16 @@ public final class JdkVectorLibrary implements VectorLibrary { LinkerHelperUtil.critical() ); } + int4Bit$mh = downcallHandle( + "int4Bit", + 
FunctionDescriptor.of(JAVA_LONG, ADDRESS, ADDRESS, JAVA_LONG, JAVA_INT), + LinkerHelperUtil.critical() + ); + int4BitBulk$mh = downcallHandle( + "int4BitBulk", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_LONG, ADDRESS, JAVA_INT, JAVA_INT), + LinkerHelperUtil.critical() + ); INSTANCE = new JdkVectorSimilarityFunctions(); } else { if (caps < 0) { @@ -112,6 +125,8 @@ public final class JdkVectorLibrary implements VectorLibrary { cosf32$mh = null; dotf32$mh = null; sqrf32$mh = null; + int4Bit$mh = null; + int4BitBulk$mh = null; INSTANCE = null; } } catch (Throwable t) { @@ -142,6 +157,34 @@ static int dotProduct7u(MemorySegment a, MemorySegment b, int length) { return dot7u(a, b, length); } + static long int4BitDotProd(MemorySegment a, MemorySegment b, long offset, int length) { + if (a.byteSize() != 4L * length) { + throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + 4L * length); + } + return int4Bit(a, b, offset, length); + } + + private static long int4Bit(MemorySegment a, MemorySegment b, long offset, int length) { + try { + return (long) JdkVectorLibrary.int4Bit$mh.invokeExact(a, b, offset, length); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + static int int4BitDotProdBulk(MemorySegment a, MemorySegment b, long offset, MemorySegment s, int count, int length) { + assert length >= 0; + return int4BitBulk(a, b, offset, s, count, length); + } + + private static int int4BitBulk(MemorySegment a, MemorySegment b, long offset, MemorySegment s, int count, int length) { + try { + return (int) JdkVectorLibrary.int4BitBulk$mh.invokeExact(a, b, offset, s, count, length); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + /** * Computes the square distance of given unsigned int7 byte vectors. 
* @@ -247,6 +290,8 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) { static final MethodHandle COS_HANDLE_FLOAT32; static final MethodHandle DOT_HANDLE_FLOAT32; static final MethodHandle SQR_HANDLE_FLOAT32; + static final MethodHandle DOT_HANDLE_4BIT; + static final MethodHandle DOT_HANDLE_4BIT_BULK; static { try { @@ -259,6 +304,19 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) { COS_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "cosineF32", mt); DOT_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProductF32", mt); SQR_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistanceF32", mt); + mt = MethodType.methodType(long.class, MemorySegment.class, MemorySegment.class, long.class, int.class); + DOT_HANDLE_4BIT = lookup.findStatic(JdkVectorSimilarityFunctions.class, "int4BitDotProd", mt); + mt = MethodType.methodType( + int.class, + MemorySegment.class, + MemorySegment.class, + long.class, + MemorySegment.class, + int.class, + int.class + ); + + DOT_HANDLE_4BIT_BULK = lookup.findStatic(JdkVectorSimilarityFunctions.class, "int4BitDotProdBulk", mt); } catch (NoSuchMethodException | IllegalAccessException e) { throw new RuntimeException(e); } @@ -288,5 +346,15 @@ public MethodHandle dotProductHandleFloat32() { public MethodHandle squareDistanceHandleFloat32() { return SQR_HANDLE_FLOAT32; } + + @Override + public MethodHandle int4BitDotProductHandle() { + return DOT_HANDLE_4BIT; + } + + @Override + public MethodHandle int4BitDotProductBulkHandle() { + return DOT_HANDLE_4BIT_BULK; + } } } diff --git a/libs/simdvec/native/src/vec/c/aarch64/vec.c b/libs/simdvec/native/src/vec/c/aarch64/vec.c index f3eb7f51ee5d1..c09f42d876305 100644 --- a/libs/simdvec/native/src/vec/c/aarch64/vec.c +++ b/libs/simdvec/native/src/vec/c/aarch64/vec.c @@ -299,3 +299,81 @@ EXPORT float sqrf32(const float *a, const float *b, size_t elementCount) { return 
result; } + +EXPORT int64_t int4Bit(uint8_t* query, uint8_t* doc, int64_t offset, int length) { + const size_t stride = (length / 8) * 8; + uint64_t dot_q0 = 0; + uint64_t dot_q1 = 0; + uint64_t dot_q2 = 0; + uint64_t dot_q3 = 0; + const uint8_t* doc_idx = doc + offset; + const uint8_t* query_j0 = query; + const uint8_t* query_j1 = query + length; + const uint8_t* query_j2 = query + 2 * length; + const uint8_t* query_j3 = query + 3 * length; + int i = 0; + for (; i < stride; i += 8) { + const uint64_t qv0 = *(const uint64_t*)(query_j0 + i); + const uint64_t qv1 = *(const uint64_t*)(query_j1 + i); + const uint64_t qv2 = *(const uint64_t*)(query_j2 + i); + const uint64_t qv3 = *(const uint64_t*)(query_j3 + i); + const uint64_t yv = *(const uint64_t*)(doc_idx + i); + dot_q0 += __builtin_popcountll(qv0 & yv); + dot_q1 += __builtin_popcountll(qv1 & yv); + dot_q2 += __builtin_popcountll(qv2 & yv); + dot_q3 += __builtin_popcountll(qv3 & yv); + } + for (; i < length; i++) { + const uint8_t qv0 = *(query_j0 + i); + const uint8_t qv1 = *(query_j1 + i); + const uint8_t qv2 = *(query_j2 + i); + const uint8_t qv3 = *(query_j3 + i); + const uint8_t yv = *(doc_idx + i); + dot_q0 += __builtin_popcountll(qv0 & yv); + dot_q1 += __builtin_popcountll(qv1 & yv); + dot_q2 += __builtin_popcountll(qv2 & yv); + dot_q3 += __builtin_popcountll(qv3 & yv); + } + return dot_q0 + (dot_q1 << 1) + (dot_q2 << 2) + (dot_q3 << 3); +} + +EXPORT int32_t int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float32_t* scores, int count, int length) { + const size_t stride = (length / 8) * 8; + const uint8_t* query_j0 = query; + const uint8_t* query_j1 = query + length; + const uint8_t* query_j2 = query + 2 * length; + const uint8_t* query_j3 = query + 3 * length; + // assumption that the query bits are 4, and doc bits are singular + for (size_t idx = 0; idx < count; idx++) { + uint64_t dot_q0 = 0; + uint64_t dot_q1 = 0; + uint64_t dot_q2 = 0; + uint64_t dot_q3 = 0; + const uint8_t* doc_idx = doc + 
offset + idx * length; + int i = 0; + for (; i < stride; i += 8) { + const uint64_t qv0 = *(const uint64_t*)(query_j0 + i); + const uint64_t qv1 = *(const uint64_t*)(query_j1 + i); + const uint64_t qv2 = *(const uint64_t*)(query_j2 + i); + const uint64_t qv3 = *(const uint64_t*)(query_j3 + i); + const uint64_t yv = *(const uint64_t*)(doc_idx + i); + dot_q0 += __builtin_popcountll(qv0 & yv); + dot_q1 += __builtin_popcountll(qv1 & yv); + dot_q2 += __builtin_popcountll(qv2 & yv); + dot_q3 += __builtin_popcountll(qv3 & yv); + } + for (; i < length; i++) { + const uint8_t qv0 = *(query_j0 + i); + const uint8_t qv1 = *(query_j1 + i); + const uint8_t qv2 = *(query_j2 + i); + const uint8_t qv3 = *(query_j3 + i); + const uint8_t yv = *(doc_idx + i); + dot_q0 += __builtin_popcountll(qv0 & yv); + dot_q1 += __builtin_popcountll(qv1 & yv); + dot_q2 += __builtin_popcountll(qv2 & yv); + dot_q3 += __builtin_popcountll(qv3 & yv); + } + scores[idx] = (float32_t)(dot_q0 + (dot_q1 << 1) + (dot_q2 << 2) + (dot_q3 << 3)); + } + return count; +} diff --git a/libs/simdvec/native/src/vec/c/amd64/vec.c b/libs/simdvec/native/src/vec/c/amd64/vec.c index c6b9154b60660..6d3765c0e9caf 100644 --- a/libs/simdvec/native/src/vec/c/amd64/vec.c +++ b/libs/simdvec/native/src/vec/c/amd64/vec.c @@ -346,3 +346,11 @@ EXPORT float sqrf32(const float *a, const float *b, size_t elementCount) { return result; } + +EXPORT int64_t int4Bit(uint8_t* query, uint8_t* doc, int64_t offset, int length) { + return 0; +} + +EXPORT int32_t int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float32_t* scores, size_t count, size_t dims) { + return 0; +} diff --git a/libs/simdvec/native/src/vec/headers/vec.h b/libs/simdvec/native/src/vec/headers/vec.h index 733aea3165659..132a21c12ffa2 100644 --- a/libs/simdvec/native/src/vec/headers/vec.h +++ b/libs/simdvec/native/src/vec/headers/vec.h @@ -27,3 +27,7 @@ EXPORT float dotf32(const float *a, const float *b, size_t elementCount); EXPORT float sqrf32(const float *a, const 
float *b, size_t elementCount); +EXPORT int64_t int4Bit(uint8_t* query, uint8_t* doc, int64_t offset, int length); + +EXPORT int32_t int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float32_t* scores, int count, int dims); + diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java index 482bbc8d8cabe..51a7354d2be61 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java @@ -23,6 +23,8 @@ public class Similarities { static final MethodHandle DOT_PRODUCT_7U = DISTANCE_FUNCS.dotProductHandle7u(); static final MethodHandle SQUARE_DISTANCE_7U = DISTANCE_FUNCS.squareDistanceHandle7u(); + static final MethodHandle INT4_BIT_DP = DISTANCE_FUNCS.int4BitDotProductHandle(); + static final MethodHandle INT4_BIT_DP_BULK = DISTANCE_FUNCS.int4BitDotProductBulkHandle(); static int dotProduct7u(MemorySegment a, MemorySegment b, int length) { try { @@ -51,4 +53,32 @@ static int squareDistance7u(MemorySegment a, MemorySegment b, int length) { } } } + + public static long int4BitDotProduct(MemorySegment a, MemorySegment b, long offset, int length) { + try { + return (long) INT4_BIT_DP.invokeExact(a, b, offset, length); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + public static int int4BitDotProductBulk(MemorySegment a, MemorySegment b, long offset, MemorySegment scores, int count, int length) { + try { + return (int) INT4_BIT_DP_BULK.invokeExact(a, b, offset, scores, count, length); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } } diff --git 
a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java index 1844b4cbd398f..968d203e8b8e5 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java @@ -8,347 +8,27 @@ */ package org.elasticsearch.simdvec.internal.vectorization; -import jdk.incubator.vector.ByteVector; -import jdk.incubator.vector.FloatVector; -import jdk.incubator.vector.IntVector; -import jdk.incubator.vector.LongVector; -import jdk.incubator.vector.ShortVector; -import jdk.incubator.vector.VectorOperators; -import jdk.incubator.vector.VectorSpecies; - import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; -import org.apache.lucene.util.BitUtil; -import org.apache.lucene.util.VectorUtil; -import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import java.io.IOException; import java.lang.foreign.MemorySegment; -import java.nio.ByteOrder; - -import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; -import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; - -/** Panamized scorer for quantized vectors stored as a {@link MemorySegment}. 
*/ -public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScorer { - - private static final VectorSpecies INT_SPECIES_128 = IntVector.SPECIES_128; - - private static final VectorSpecies LONG_SPECIES_128 = LongVector.SPECIES_128; - private static final VectorSpecies LONG_SPECIES_256 = LongVector.SPECIES_256; - - private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; - private static final VectorSpecies BYTE_SPECIES_256 = ByteVector.SPECIES_256; - private static final VectorSpecies SHORT_SPECIES_128 = ShortVector.SPECIES_128; - private static final VectorSpecies SHORT_SPECIES_256 = ShortVector.SPECIES_256; - - private static final VectorSpecies FLOAT_SPECIES_128 = FloatVector.SPECIES_128; - private static final VectorSpecies FLOAT_SPECIES_256 = FloatVector.SPECIES_256; - - private final MemorySegment memorySegment; +/** Panamaized scorer for quantized vectors stored as a {@link MemorySegment}. */ +public final class MemorySegmentES91OSQVectorsScorer extends MemorySegmentES91PanamaOSQVectorsScorer { public MemorySegmentES91OSQVectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) { - super(in, dimensions); - this.memorySegment = memorySegment; + super(in, dimensions, memorySegment); } @Override public long quantizeScore(byte[] q) throws IOException { - assert q.length == length * 4; - // 128 / 8 == 16 - if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { - if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { - return quantizeScore256(q); - } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { - return quantizeScore128(q); - } - } - return super.quantizeScore(q); - } - - private long quantizeScore256(byte[] q) throws IOException { - long subRet0 = 0; - long subRet1 = 0; - long subRet2 = 0; - long subRet3 = 0; - int i = 0; - long offset = in.getFilePointer(); - if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { - int limit = ByteVector.SPECIES_256.loopBound(length); 
- var sum0 = LongVector.zero(LONG_SPECIES_256); - var sum1 = LongVector.zero(LONG_SPECIES_256); - var sum2 = LongVector.zero(LONG_SPECIES_256); - var sum3 = LongVector.zero(LONG_SPECIES_256); - for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) { - var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); - var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs(); - var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs(); - var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs(); - var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); - sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); - } - subRet0 += sum0.reduceLanes(VectorOperators.ADD); - subRet1 += sum1.reduceLanes(VectorOperators.ADD); - subRet2 += sum2.reduceLanes(VectorOperators.ADD); - subRet3 += sum3.reduceLanes(VectorOperators.ADD); - } - - if (length - i >= ByteVector.SPECIES_128.vectorByteSize()) { - var sum0 = LongVector.zero(LONG_SPECIES_128); - var sum1 = LongVector.zero(LONG_SPECIES_128); - var sum2 = LongVector.zero(LONG_SPECIES_128); - var sum3 = LongVector.zero(LONG_SPECIES_128); - int limit = ByteVector.SPECIES_128.loopBound(length); - for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) { - var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); - var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs(); - var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs(); - var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 
3).reinterpretAsLongs(); - var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); - sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); - } - subRet0 += sum0.reduceLanes(VectorOperators.ADD); - subRet1 += sum1.reduceLanes(VectorOperators.ADD); - subRet2 += sum2.reduceLanes(VectorOperators.ADD); - subRet3 += sum3.reduceLanes(VectorOperators.ADD); - } - // process scalar tail - in.seek(offset); - for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = in.readLong(); - subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); - subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); - subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); - subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); - } - for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = in.readInt(); - subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); - subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); - subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); - subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); - } - for (; i < length; i++) { - int dValue = in.readByte() & 0xFF; - subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); - subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); - subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); - subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); - } - return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); - } - - private long 
quantizeScore128(byte[] q) throws IOException { - long subRet0 = 0; - long subRet1 = 0; - long subRet2 = 0; - long subRet3 = 0; - int i = 0; - long offset = in.getFilePointer(); - - var sum0 = IntVector.zero(INT_SPECIES_128); - var sum1 = IntVector.zero(INT_SPECIES_128); - var sum2 = IntVector.zero(INT_SPECIES_128); - var sum3 = IntVector.zero(INT_SPECIES_128); - int limit = ByteVector.SPECIES_128.loopBound(length); - for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) { - var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); - var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); - var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts(); - var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts(); - var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsInts(); - sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT)); - sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT)); - sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT)); - sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT)); - } - subRet0 += sum0.reduceLanes(VectorOperators.ADD); - subRet1 += sum1.reduceLanes(VectorOperators.ADD); - subRet2 += sum2.reduceLanes(VectorOperators.ADD); - subRet3 += sum3.reduceLanes(VectorOperators.ADD); - // process scalar tail - in.seek(offset); - for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = in.readLong(); - subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); - subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); - subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); - subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); - } - for (final int upperBound = length & 
-Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = in.readInt(); - subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); - subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); - subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); - subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); - } - for (; i < length; i++) { - int dValue = in.readByte() & 0xFF; - subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); - subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); - subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); - subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); - } - return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + return panamaQuantizeScore(q); } @Override public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException { - assert q.length == length * 4; - // 128 / 8 == 16 - if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { - if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { - quantizeScore256Bulk(q, count, scores); - return; - } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { - quantizeScore128Bulk(q, count, scores); - return; - } - } - super.quantizeScoreBulk(q, count, scores); - } - - private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IOException { - for (int iter = 0; iter < count; iter++) { - long subRet0 = 0; - long subRet1 = 0; - long subRet2 = 0; - long subRet3 = 0; - int i = 0; - long offset = in.getFilePointer(); - - var sum0 = IntVector.zero(INT_SPECIES_128); - var sum1 = IntVector.zero(INT_SPECIES_128); - var sum2 = IntVector.zero(INT_SPECIES_128); - var sum3 = IntVector.zero(INT_SPECIES_128); - int limit = ByteVector.SPECIES_128.loopBound(length); - for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) { - var vd = 
IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); - var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); - var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts(); - var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts(); - var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsInts(); - sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT)); - sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT)); - sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT)); - sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT)); - } - subRet0 += sum0.reduceLanes(VectorOperators.ADD); - subRet1 += sum1.reduceLanes(VectorOperators.ADD); - subRet2 += sum2.reduceLanes(VectorOperators.ADD); - subRet3 += sum3.reduceLanes(VectorOperators.ADD); - // process scalar tail - in.seek(offset); - for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = in.readLong(); - subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); - subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); - subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); - subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); - } - for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = in.readInt(); - subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); - subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); - subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); - subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); - } - for (; i < length; i++) { - int dValue = in.readByte() & 0xFF; - subRet0 += Integer.bitCount((q[i] & dValue) & 
0xFF); - subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); - subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); - subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); - } - scores[iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); - } - } - - private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IOException { - for (int iter = 0; iter < count; iter++) { - long subRet0 = 0; - long subRet1 = 0; - long subRet2 = 0; - long subRet3 = 0; - int i = 0; - long offset = in.getFilePointer(); - if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { - int limit = ByteVector.SPECIES_256.loopBound(length); - var sum0 = LongVector.zero(LONG_SPECIES_256); - var sum1 = LongVector.zero(LONG_SPECIES_256); - var sum2 = LongVector.zero(LONG_SPECIES_256); - var sum3 = LongVector.zero(LONG_SPECIES_256); - for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) { - var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); - var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs(); - var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs(); - var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs(); - var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); - sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); - } - subRet0 += sum0.reduceLanes(VectorOperators.ADD); - subRet1 += sum1.reduceLanes(VectorOperators.ADD); - subRet2 += sum2.reduceLanes(VectorOperators.ADD); - subRet3 += sum3.reduceLanes(VectorOperators.ADD); - } - - if (length - i >= ByteVector.SPECIES_128.vectorByteSize()) { - var 
sum0 = LongVector.zero(LONG_SPECIES_128); - var sum1 = LongVector.zero(LONG_SPECIES_128); - var sum2 = LongVector.zero(LONG_SPECIES_128); - var sum3 = LongVector.zero(LONG_SPECIES_128); - int limit = ByteVector.SPECIES_128.loopBound(length); - for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) { - var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); - var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs(); - var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs(); - var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs(); - var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); - sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); - } - subRet0 += sum0.reduceLanes(VectorOperators.ADD); - subRet1 += sum1.reduceLanes(VectorOperators.ADD); - subRet2 += sum2.reduceLanes(VectorOperators.ADD); - subRet3 += sum3.reduceLanes(VectorOperators.ADD); - } - // process scalar tail - in.seek(offset); - for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = in.readLong(); - subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); - subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); - subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); - subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); - } - for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = in.readInt(); - subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); - 
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); - subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); - subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); - } - for (; i < length; i++) { - int dValue = in.readByte() & 0xFF; - subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); - subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); - subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); - subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); - } - scores[iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); - } + panamaQuantizeScoreBulk(q, count, scores); } @Override @@ -362,35 +42,8 @@ public float scoreBulk( float centroidDp, float[] scores ) throws IOException { - assert q.length == length * 4; - // 128 / 8 == 16 - if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { - if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { - return score256Bulk( - q, - queryLowerInterval, - queryUpperInterval, - queryComponentSum, - queryAdditionalCorrection, - similarityFunction, - centroidDp, - scores - ); - } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { - return score128Bulk( - q, - queryLowerInterval, - queryUpperInterval, - queryComponentSum, - queryAdditionalCorrection, - similarityFunction, - centroidDp, - scores - ); - } - } - return super.scoreBulk( - q, + quantizeScoreBulk(q, BULK_SIZE, scores); + return applyCorrectionsBulk( queryLowerInterval, queryUpperInterval, queryComponentSum, @@ -400,154 +53,4 @@ public float scoreBulk( scores ); } - - private float score128Bulk( - byte[] q, - float queryLowerInterval, - float queryUpperInterval, - int queryComponentSum, - float queryAdditionalCorrection, - VectorSimilarityFunction similarityFunction, - float centroidDp, - float[] scores - ) throws IOException { - quantizeScore128Bulk(q, BULK_SIZE, scores); - int limit = 
FLOAT_SPECIES_128.loopBound(BULK_SIZE); - int i = 0; - long offset = in.getFilePointer(); - float ay = queryLowerInterval; - float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; - float y1 = queryComponentSum; - float maxScore = Float.NEGATIVE_INFINITY; - for (; i < limit; i += FLOAT_SPECIES_128.length()) { - var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); - var lx = FloatVector.fromMemorySegment( - FLOAT_SPECIES_128, - memorySegment, - offset + 4 * BULK_SIZE + i * Float.BYTES, - ByteOrder.LITTLE_ENDIAN - ).sub(ax); - var targetComponentSums = ShortVector.fromMemorySegment( - SHORT_SPECIES_128, - memorySegment, - offset + 8 * BULK_SIZE + i * Short.BYTES, - ByteOrder.LITTLE_ENDIAN - ).convert(VectorOperators.S2I, 0).reinterpretAsInts().and(0xffff).convert(VectorOperators.I2F, 0); - var additionalCorrections = FloatVector.fromMemorySegment( - FLOAT_SPECIES_128, - memorySegment, - offset + 10 * BULK_SIZE + i * Float.BYTES, - ByteOrder.LITTLE_ENDIAN - ); - var qcDist = FloatVector.fromArray(FLOAT_SPECIES_128, scores, i); - // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * - // qcDist; - var res1 = ax.mul(ay).mul(dimensions); - var res2 = lx.mul(ay).mul(targetComponentSums); - var res3 = ax.mul(ly).mul(y1); - var res4 = lx.mul(ly).mul(qcDist); - var res = res1.add(res2).add(res3).add(res4); - // For euclidean, we need to invert the score and apply the additional correction, which is - // assumed to be the squared l2norm of the centroid centered vectors. 
- if (similarityFunction == EUCLIDEAN) { - res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f); - res = FloatVector.broadcast(FLOAT_SPECIES_128, 1).div(res).max(0); - maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); - res.intoArray(scores, i); - } else { - // For cosine and max inner product, we need to apply the additional correction, which is - // assumed to be the non-centered dot-product between the vector and the centroid - res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp); - if (similarityFunction == MAXIMUM_INNER_PRODUCT) { - res.intoArray(scores, i); - // not sure how to do it better - for (int j = 0; j < FLOAT_SPECIES_128.length(); j++) { - scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]); - maxScore = Math.max(maxScore, scores[i + j]); - } - } else { - res = res.add(1f).mul(0.5f).max(0); - res.intoArray(scores, i); - maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); - } - } - } - in.seek(offset + 14L * BULK_SIZE); - return maxScore; - } - - private float score256Bulk( - byte[] q, - float queryLowerInterval, - float queryUpperInterval, - int queryComponentSum, - float queryAdditionalCorrection, - VectorSimilarityFunction similarityFunction, - float centroidDp, - float[] scores - ) throws IOException { - quantizeScore256Bulk(q, BULK_SIZE, scores); - int limit = FLOAT_SPECIES_256.loopBound(BULK_SIZE); - int i = 0; - long offset = in.getFilePointer(); - float ay = queryLowerInterval; - float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; - float y1 = queryComponentSum; - float maxScore = Float.NEGATIVE_INFINITY; - for (; i < limit; i += FLOAT_SPECIES_256.length()) { - var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); - var lx = FloatVector.fromMemorySegment( - FLOAT_SPECIES_256, - memorySegment, - offset + 4 * BULK_SIZE + i * Float.BYTES, - 
ByteOrder.LITTLE_ENDIAN - ).sub(ax); - var targetComponentSums = ShortVector.fromMemorySegment( - SHORT_SPECIES_256, - memorySegment, - offset + 8 * BULK_SIZE + i * Short.BYTES, - ByteOrder.LITTLE_ENDIAN - ).convert(VectorOperators.S2I, 0).reinterpretAsInts().and(0xffff).convert(VectorOperators.I2F, 0); - var additionalCorrections = FloatVector.fromMemorySegment( - FLOAT_SPECIES_256, - memorySegment, - offset + 10 * BULK_SIZE + i * Float.BYTES, - ByteOrder.LITTLE_ENDIAN - ); - var qcDist = FloatVector.fromArray(FLOAT_SPECIES_256, scores, i); - // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * - // qcDist; - var res1 = ax.mul(ay).mul(dimensions); - var res2 = lx.mul(ay).mul(targetComponentSums); - var res3 = ax.mul(ly).mul(y1); - var res4 = lx.mul(ly).mul(qcDist); - var res = res1.add(res2).add(res3).add(res4); - // For euclidean, we need to invert the score and apply the additional correction, which is - // assumed to be the squared l2norm of the centroid centered vectors. 
- if (similarityFunction == EUCLIDEAN) { - res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f); - res = FloatVector.broadcast(FLOAT_SPECIES_256, 1).div(res).max(0); - maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); - res.intoArray(scores, i); - } else { - // For cosine and max inner product, we need to apply the additional correction, which is - // assumed to be the non-centered dot-product between the vector and the centroid - res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp); - if (similarityFunction == MAXIMUM_INNER_PRODUCT) { - res.intoArray(scores, i); - // not sure how to do it better - for (int j = 0; j < FLOAT_SPECIES_256.length(); j++) { - scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]); - maxScore = Math.max(maxScore, scores[i + j]); - } - } else { - res = res.add(1f).mul(0.5f).max(0); - maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); - res.intoArray(scores, i); - } - } - } - in.seek(offset + 14L * BULK_SIZE); - return maxScore; - } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91PanamaOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91PanamaOSQVectorsScorer.java new file mode 100644 index 0000000000000..41cad6e189aff --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91PanamaOSQVectorsScorer.java @@ -0,0 +1,434 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.simdvec.internal.vectorization; + +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.FloatVector; +import jdk.incubator.vector.IntVector; +import jdk.incubator.vector.LongVector; +import jdk.incubator.vector.ShortVector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorShape; +import jdk.incubator.vector.VectorSpecies; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.BitUtil; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.nio.ByteOrder; + +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +/** Panamized scorer for quantized vectors stored as a {@link MemorySegment}. 
*/ +public class MemorySegmentES91PanamaOSQVectorsScorer extends ES91OSQVectorsScorer { + + private static final VectorSpecies INT_SPECIES_128 = IntVector.SPECIES_128; + + private static final VectorSpecies LONG_SPECIES_128 = LongVector.SPECIES_128; + private static final VectorSpecies LONG_SPECIES_256 = LongVector.SPECIES_256; + + private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; + private static final VectorSpecies BYTE_SPECIES_256 = ByteVector.SPECIES_256; + + private static final VectorSpecies SHORT_SPECIES_128 = ShortVector.SPECIES_128; + private static final VectorSpecies SHORT_SPECIES_256 = ShortVector.SPECIES_256; + + private static final VectorSpecies FLOAT_SPECIES_128 = FloatVector.SPECIES_128; + private static final VectorSpecies FLOAT_SPECIES_256 = FloatVector.SPECIES_256; + + private static final VectorSpecies FLOAT_SPECIES; + private static final VectorSpecies SHORT_SPECIES; + + static { + // default to platform supported bitsize + FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(PanamaESVectorUtilSupport.VECTOR_BITSIZE)); + SHORT_SPECIES = VectorSpecies.of(short.class, VectorShape.forBitSize(PanamaESVectorUtilSupport.VECTOR_BITSIZE)); + } + + protected final MemorySegment memorySegment; + + public MemorySegmentES91PanamaOSQVectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) { + super(in, dimensions); + this.memorySegment = memorySegment; + } + + protected long panamaQuantizeScore(byte[] q) throws IOException { + assert q.length == length * 4; + // 128 / 8 == 16 + if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { + return quantizeScore256(q); + } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { + return quantizeScore128(q); + } + } + return super.quantizeScore(q); + } + + private long quantizeScore256(byte[] q) throws IOException { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + 
long subRet3 = 0; + int i = 0; + long offset = in.getFilePointer(); + if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { + int limit = ByteVector.SPECIES_256.loopBound(length); + var sum0 = LongVector.zero(LONG_SPECIES_256); + var sum1 = LongVector.zero(LONG_SPECIES_256); + var sum2 = LongVector.zero(LONG_SPECIES_256); + var sum3 = LongVector.zero(LONG_SPECIES_256); + for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) { + var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs(); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + + if (length - i >= ByteVector.SPECIES_128.vectorByteSize()) { + var sum0 = LongVector.zero(LONG_SPECIES_128); + var sum1 = LongVector.zero(LONG_SPECIES_128); + var sum2 = LongVector.zero(LONG_SPECIES_128); + var sum3 = LongVector.zero(LONG_SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(length); + for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) { + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + 
length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs(); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + // process scalar tail + in.seek(offset); + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { + final long value = in.readLong(); + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); + } + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { + final int value = in.readInt(); + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); + } + for (; i < length; i++) { + int dValue = in.readByte() & 0xFF; + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); + subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * 
length] & dValue) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); + } + return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + + private long quantizeScore128(byte[] q) throws IOException { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + long offset = in.getFilePointer(); + + var sum0 = IntVector.zero(INT_SPECIES_128); + var sum1 = IntVector.zero(INT_SPECIES_128); + var sum2 = IntVector.zero(INT_SPECIES_128); + var sum3 = IntVector.zero(INT_SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(length); + for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) { + var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsInts(); + sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + // process scalar tail + in.seek(offset); + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { + final long value = in.readLong(); + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); + subRet2 += 
Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); + } + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { + final int value = in.readInt(); + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); + } + for (; i < length; i++) { + int dValue = in.readByte() & 0xFF; + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); + subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); + } + return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + + protected void panamaQuantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException { + assert q.length == length * 4; + // 128 / 8 == 16 + if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { + quantizeScore256Bulk(q, count, scores); + return; + } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { + quantizeScore128Bulk(q, count, scores); + return; + } + } + super.quantizeScoreBulk(q, count, scores); + } + + private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IOException { + for (int iter = 0; iter < count; iter++) { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + long offset = in.getFilePointer(); + + var sum0 = IntVector.zero(INT_SPECIES_128); + var sum1 = IntVector.zero(INT_SPECIES_128); + var sum2 = IntVector.zero(INT_SPECIES_128); + var sum3 = IntVector.zero(INT_SPECIES_128); + int 
limit = ByteVector.SPECIES_128.loopBound(length); + for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) { + var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsInts(); + sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + // process scalar tail + in.seek(offset); + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { + final long value = in.readLong(); + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); + } + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { + final int value = in.readInt(); + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); + subRet3 += Integer.bitCount((int) 
BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); + } + for (; i < length; i++) { + int dValue = in.readByte() & 0xFF; + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); + subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); + } + scores[iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + } + + private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IOException { + for (int iter = 0; iter < count; iter++) { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + long offset = in.getFilePointer(); + if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { + int limit = ByteVector.SPECIES_256.loopBound(length); + var sum0 = LongVector.zero(LONG_SPECIES_256); + var sum1 = LongVector.zero(LONG_SPECIES_256); + var sum2 = LongVector.zero(LONG_SPECIES_256); + var sum3 = LongVector.zero(LONG_SPECIES_256); + for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) { + var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs(); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 
+= sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + + if (length - i >= ByteVector.SPECIES_128.vectorByteSize()) { + var sum0 = LongVector.zero(LONG_SPECIES_128); + var sum1 = LongVector.zero(LONG_SPECIES_128); + var sum2 = LongVector.zero(LONG_SPECIES_128); + var sum3 = LongVector.zero(LONG_SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(length); + for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) { + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs(); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + // process scalar tail + in.seek(offset); + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { + final long value = in.readLong(); + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); + } + for (final int upperBound = length & 
-Integer.BYTES; i < upperBound; i += Integer.BYTES) { + final int value = in.readInt(); + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); + } + for (; i < length; i++) { + int dValue = in.readByte() & 0xFF; + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); + subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); + } + scores[iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + } + + protected float applyCorrectionsBulk( + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores + ) throws IOException { + int limit = FLOAT_SPECIES.loopBound(BULK_SIZE); + int i = 0; + long offset = in.getFilePointer(); + float ay = queryLowerInterval; + float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; + float y1 = queryComponentSum; + float maxScore = Float.NEGATIVE_INFINITY; + for (; i < limit; i += FLOAT_SPECIES.length()) { + var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); + var lx = FloatVector.fromMemorySegment( + FLOAT_SPECIES, + memorySegment, + offset + 4 * BULK_SIZE + i * Float.BYTES, + ByteOrder.LITTLE_ENDIAN + ).sub(ax); + var targetComponentSums = ShortVector.fromMemorySegment( + SHORT_SPECIES, + memorySegment, + offset + 8 * BULK_SIZE + i * Short.BYTES, + ByteOrder.LITTLE_ENDIAN + ).convert(VectorOperators.S2I, 0).reinterpretAsInts().and(0xffff).convert(VectorOperators.I2F, 0); + var additionalCorrections = 
FloatVector.fromMemorySegment( + FLOAT_SPECIES, + memorySegment, + offset + 10 * BULK_SIZE + i * Float.BYTES, + ByteOrder.LITTLE_ENDIAN + ); + var qcDist = FloatVector.fromArray(FLOAT_SPECIES, scores, i); + // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * + // qcDist; + var res1 = ax.mul(ay).mul(dimensions); + var res2 = lx.mul(ay).mul(targetComponentSums); + var res3 = ax.mul(ly).mul(y1); + var res4 = lx.mul(ly).mul(qcDist); + var res = res1.add(res2).add(res3).add(res4); + // For euclidean, we need to invert the score and apply the additional correction, which is + // assumed to be the squared l2norm of the centroid centered vectors. + if (similarityFunction == EUCLIDEAN) { + res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f); + res = FloatVector.broadcast(FLOAT_SPECIES, 1).div(res).max(0); + maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); + res.intoArray(scores, i); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp); + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + res.intoArray(scores, i); + // not sure how to do it better + for (int j = 0; j < FLOAT_SPECIES.length(); j++) { + scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]); + maxScore = Math.max(maxScore, scores[i + j]); + } + } else { + res = res.add(1f).mul(0.5f).max(0); + maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); + res.intoArray(scores, i); + } + } + } + in.seek(offset + 14L * BULK_SIZE); + return maxScore; + } +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/OnHeapES91OSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/OnHeapES91OSQVectorsScorer.java 
index 7a992af6b06de..8c396ef781926 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/OnHeapES91OSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/OnHeapES91OSQVectorsScorer.java @@ -8,348 +8,26 @@ */ package org.elasticsearch.simdvec.internal.vectorization; -import jdk.incubator.vector.ByteVector; -import jdk.incubator.vector.FloatVector; -import jdk.incubator.vector.IntVector; -import jdk.incubator.vector.LongVector; -import jdk.incubator.vector.VectorOperators; -import jdk.incubator.vector.VectorSpecies; - import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; -import org.apache.lucene.util.BitUtil; -import org.apache.lucene.util.VectorUtil; -import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import java.io.IOException; -import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; -import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; - /** Panamized scorer for quantized vectors stored as a {@link IndexInput}. 
*/ -public final class OnHeapES91OSQVectorsScorer extends ES91OSQVectorsScorer { - - private static final VectorSpecies INT_SPECIES_128 = IntVector.SPECIES_128; - private static final VectorSpecies INT_SPECIES_256 = IntVector.SPECIES_256; - - private static final VectorSpecies LONG_SPECIES_128 = LongVector.SPECIES_128; - private static final VectorSpecies LONG_SPECIES_256 = LongVector.SPECIES_256; - - private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; - private static final VectorSpecies BYTE_SPECIES_256 = ByteVector.SPECIES_256; - - private static final VectorSpecies FLOAT_SPECIES_128 = FloatVector.SPECIES_128; - private static final VectorSpecies FLOAT_SPECIES_256 = FloatVector.SPECIES_256; - - private final byte[] bytes; +public final class OnHeapES91OSQVectorsScorer extends OnHeapES91PanamaOSQVectorsScorer { public OnHeapES91OSQVectorsScorer(IndexInput in, int dimensions) { super(in, dimensions); - bytes = new byte[BULK_SIZE * length]; } @Override public long quantizeScore(byte[] q) throws IOException { - assert q.length == length * 4; - // 128 / 8 == 16 - if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { - if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { - return quantizeScore256(q); - } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { - return quantizeScore128(q); - } - } - return super.quantizeScore(q); - } - - private long quantizeScore256(byte[] q) throws IOException { - in.readBytes(bytes, 0, length); - long subRet0 = 0; - long subRet1 = 0; - long subRet2 = 0; - long subRet3 = 0; - int i = 0; - if (length >= BYTE_SPECIES_256.vectorByteSize() * 2) { - int limit = BYTE_SPECIES_256.loopBound(length); - var sum0 = LongVector.zero(LONG_SPECIES_256); - var sum1 = LongVector.zero(LONG_SPECIES_256); - var sum2 = LongVector.zero(LONG_SPECIES_256); - var sum3 = LongVector.zero(LONG_SPECIES_256); - for (; i < limit; i += BYTE_SPECIES_256.length()) { - var vd = ByteVector.fromArray(BYTE_SPECIES_256, 
bytes, i).reinterpretAsLongs(); - var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); - var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs(); - var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs(); - var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs(); - sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); - } - subRet0 += sum0.reduceLanes(VectorOperators.ADD); - subRet1 += sum1.reduceLanes(VectorOperators.ADD); - subRet2 += sum2.reduceLanes(VectorOperators.ADD); - subRet3 += sum3.reduceLanes(VectorOperators.ADD); - } - - if (length - i >= BYTE_SPECIES_128.vectorByteSize()) { - var sum0 = LongVector.zero(LONG_SPECIES_128); - var sum1 = LongVector.zero(LONG_SPECIES_128); - var sum2 = LongVector.zero(LONG_SPECIES_128); - var sum3 = LongVector.zero(LONG_SPECIES_128); - int limit = ByteVector.SPECIES_128.loopBound(length); - for (; i < limit; i += BYTE_SPECIES_128.length()) { - var vd = ByteVector.fromArray(BYTE_SPECIES_128, bytes, i).reinterpretAsLongs(); - var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); - var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs(); - var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs(); - var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs(); - sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); - } - subRet0 += sum0.reduceLanes(VectorOperators.ADD); - 
subRet1 += sum1.reduceLanes(VectorOperators.ADD); - subRet2 += sum2.reduceLanes(VectorOperators.ADD); - subRet3 += sum3.reduceLanes(VectorOperators.ADD); - } - // process scalar tail - for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = (long) BitUtil.VH_LE_LONG.get(bytes, i); - subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); - subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); - subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); - subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); - } - for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = (int) BitUtil.VH_LE_INT.get(bytes, i); - subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); - subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); - subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); - subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); - } - for (; i < length; i++) { - int dValue = bytes[i] & 0xFF; - subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); - subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); - subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); - subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); - } - return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); - } - - private long quantizeScore128(byte[] q) throws IOException { - in.readBytes(bytes, 0, length); - long subRet0 = 0; - long subRet1 = 0; - long subRet2 = 0; - long subRet3 = 0; - var sum0 = IntVector.zero(INT_SPECIES_128); - var sum1 = IntVector.zero(INT_SPECIES_128); - var sum2 = IntVector.zero(INT_SPECIES_128); - var sum3 = IntVector.zero(INT_SPECIES_128); - int limit = BYTE_SPECIES_128.loopBound(length); - int i = 0; - for (; i < limit; i += 
BYTE_SPECIES_128.length()) { - var vd = ByteVector.fromArray(BYTE_SPECIES_128, bytes, i).reinterpretAsInts(); - var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); - var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts(); - var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts(); - var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsInts(); - sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT)); - sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT)); - sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT)); - sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT)); - } - subRet0 += sum0.reduceLanes(VectorOperators.ADD); - subRet1 += sum1.reduceLanes(VectorOperators.ADD); - subRet2 += sum2.reduceLanes(VectorOperators.ADD); - subRet3 += sum3.reduceLanes(VectorOperators.ADD); - // process scalar tail - for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = (long) BitUtil.VH_LE_LONG.get(bytes, i); - subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); - subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); - subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); - subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); - } - for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = (int) BitUtil.VH_LE_INT.get(bytes, i); - subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); - subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); - subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); - subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); - } - for (; i < length; i++) { - int dValue = bytes[i] & 0xFF; - 
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); - subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); - subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); - subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); - } - return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + return panamaQuantizeScore(q); } @Override public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException { - assert q.length == length * 4; - // 128 / 8 == 16 - if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { - if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { - quantizeScore256Bulk(q, count, scores); - return; - } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { - quantizeScore128Bulk(q, count, scores); - return; - } - } - super.quantizeScoreBulk(q, count, scores); - } - - private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IOException { - int j = 0; - for (; j < count - 15; j += BULK_SIZE) { - in.readBytes(bytes, 0, BULK_SIZE * length); - for (int iter = 0; iter < BULK_SIZE; iter++) { - long subRet0 = 0; - long subRet1 = 0; - long subRet2 = 0; - long subRet3 = 0; - var sum0 = IntVector.zero(INT_SPECIES_128); - var sum1 = IntVector.zero(INT_SPECIES_128); - var sum2 = IntVector.zero(INT_SPECIES_128); - var sum3 = IntVector.zero(INT_SPECIES_128); - int limit = ByteVector.SPECIES_128.loopBound(length); - int i = 0; - for (; i < limit; i += ByteVector.SPECIES_128.length()) { - var vd = ByteVector.fromArray(BYTE_SPECIES_128, bytes, iter * length + i).reinterpretAsInts(); - var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); - var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts(); - var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts(); - var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsInts(); - sum0 = 
sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT)); - sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT)); - sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT)); - sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT)); - } - subRet0 += sum0.reduceLanes(VectorOperators.ADD); - subRet1 += sum1.reduceLanes(VectorOperators.ADD); - subRet2 += sum2.reduceLanes(VectorOperators.ADD); - subRet3 += sum3.reduceLanes(VectorOperators.ADD); - // process scalar tail - for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = (long) BitUtil.VH_LE_LONG.get(bytes, iter * length + i); - subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); - subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); - subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); - subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); - } - for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = (int) BitUtil.VH_LE_INT.get(bytes, iter * length + i); - subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); - subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); - subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); - subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); - } - for (; i < length; i++) { - int dValue = bytes[iter * length + i] & 0xFF; - subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); - subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); - subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); - subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); - } - scores[j + iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); - } - } - for (; j < count; j++) { - scores[j] = quantizeScore128(q); - } - } - - 
private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IOException { - int j = 0; - for (; j < count - 15; j += BULK_SIZE) { - in.readBytes(bytes, 0, BULK_SIZE * length); - for (int iter = 0; iter < BULK_SIZE; iter++) { - long subRet0 = 0; - long subRet1 = 0; - long subRet2 = 0; - long subRet3 = 0; - int i = 0; - if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { - int limit = ByteVector.SPECIES_256.loopBound(length); - var sum0 = LongVector.zero(LONG_SPECIES_256); - var sum1 = LongVector.zero(LONG_SPECIES_256); - var sum2 = LongVector.zero(LONG_SPECIES_256); - var sum3 = LongVector.zero(LONG_SPECIES_256); - for (; i < limit; i += ByteVector.SPECIES_256.length()) { - var vd = ByteVector.fromArray(BYTE_SPECIES_256, bytes, iter * length + i).reinterpretAsLongs(); - var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); - var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs(); - var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs(); - var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs(); - sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); - } - subRet0 += sum0.reduceLanes(VectorOperators.ADD); - subRet1 += sum1.reduceLanes(VectorOperators.ADD); - subRet2 += sum2.reduceLanes(VectorOperators.ADD); - subRet3 += sum3.reduceLanes(VectorOperators.ADD); - } - - if (length - i >= ByteVector.SPECIES_128.vectorByteSize()) { - var sum0 = LongVector.zero(LONG_SPECIES_128); - var sum1 = LongVector.zero(LONG_SPECIES_128); - var sum2 = LongVector.zero(LONG_SPECIES_128); - var sum3 = LongVector.zero(LONG_SPECIES_128); - int limit = ByteVector.SPECIES_128.loopBound(length); - for (; i < limit; i += 
ByteVector.SPECIES_128.length()) { - var vd = ByteVector.fromArray(BYTE_SPECIES_128, bytes, iter * length + i).reinterpretAsLongs(); - var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); - var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs(); - var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs(); - var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs(); - sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); - sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); - } - subRet0 += sum0.reduceLanes(VectorOperators.ADD); - subRet1 += sum1.reduceLanes(VectorOperators.ADD); - subRet2 += sum2.reduceLanes(VectorOperators.ADD); - subRet3 += sum3.reduceLanes(VectorOperators.ADD); - } - // process scalar tail - for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = (long) BitUtil.VH_LE_LONG.get(bytes, iter * length + i); - subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); - subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); - subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); - subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); - } - for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = (int) BitUtil.VH_LE_INT.get(bytes, i); - subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, iter * length + i) & value); - subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); - subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); - subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); - } - 
for (; i < length; i++) { - int dValue = bytes[iter * length + i] & 0xFF; - subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); - subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); - subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); - subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); - } - scores[j + iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); - } - } - for (; j < count; j++) { - scores[j] = quantizeScore256(q); - } + panamaQuantizeScoreBulk(q, count, scores); } @Override @@ -363,35 +41,8 @@ public float scoreBulk( float centroidDp, float[] scores ) throws IOException { - assert q.length == length * 4; - // 128 / 8 == 16 - if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { - if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { - return score256Bulk( - q, - queryLowerInterval, - queryUpperInterval, - queryComponentSum, - queryAdditionalCorrection, - similarityFunction, - centroidDp, - scores - ); - } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { - return score128Bulk( - q, - queryLowerInterval, - queryUpperInterval, - queryComponentSum, - queryAdditionalCorrection, - similarityFunction, - centroidDp, - scores - ); - } - } - return super.scoreBulk( - q, + quantizeScoreBulk(q, BULK_SIZE, scores); + return applyCorrectionsBulk( queryLowerInterval, queryUpperInterval, queryComponentSum, @@ -401,132 +52,4 @@ public float scoreBulk( scores ); } - - private float score128Bulk( - byte[] q, - float queryLowerInterval, - float queryUpperInterval, - int queryComponentSum, - float queryAdditionalCorrection, - VectorSimilarityFunction similarityFunction, - float centroidDp, - float[] scores - ) throws IOException { - quantizeScore128Bulk(q, BULK_SIZE, scores); - in.readFloats(lowerIntervals, 0, BULK_SIZE); - in.readFloats(upperIntervals, 0, BULK_SIZE); - for (int i = 0; i < BULK_SIZE; i++) { - targetComponentSums[i] = Short.toUnsignedInt(in.readShort()); - } - 
in.readFloats(additionalCorrections, 0, BULK_SIZE); - int limit = FLOAT_SPECIES_128.loopBound(BULK_SIZE); - int i = 0; - float ay = queryLowerInterval; - float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; - float y1 = queryComponentSum; - float maxScore = Float.NEGATIVE_INFINITY; - for (; i < limit; i += FLOAT_SPECIES_128.length()) { - var ax = FloatVector.fromArray(FLOAT_SPECIES_128, lowerIntervals, i); - var lx = FloatVector.fromArray(FLOAT_SPECIES_128, upperIntervals, i).sub(ax); - var targetComponentSumsVect = IntVector.fromArray(INT_SPECIES_128, targetComponentSums, i).convert(VectorOperators.I2F, 0); - var additionalCorrectionsVect = FloatVector.fromArray(FLOAT_SPECIES_128, additionalCorrections, i); - var qcDist = FloatVector.fromArray(FLOAT_SPECIES_128, scores, i); - // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * - // qcDist; - var res1 = ax.mul(ay).mul(dimensions); - var res2 = lx.mul(ay).mul(targetComponentSumsVect); - var res3 = ax.mul(ly).mul(y1); - var res4 = lx.mul(ly).mul(qcDist); - var res = res1.add(res2).add(res3).add(res4); - // For euclidean, we need to invert the score and apply the additional correction, which is - // assumed to be the squared l2norm of the centroid centered vectors. 
- if (similarityFunction == EUCLIDEAN) { - res = res.mul(-2).add(additionalCorrectionsVect).add(queryAdditionalCorrection).add(1f); - res = FloatVector.broadcast(FLOAT_SPECIES_128, 1).div(res).max(0); - maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); - res.intoArray(scores, i); - } else { - // For cosine and max inner product, we need to apply the additional correction, which is - // assumed to be the non-centered dot-product between the vector and the centroid - res = res.add(additionalCorrectionsVect).add(queryAdditionalCorrection).sub(centroidDp); - if (similarityFunction == MAXIMUM_INNER_PRODUCT) { - res.intoArray(scores, i); - // not sure how to do it better - for (int j = 0; j < FLOAT_SPECIES_128.length(); j++) { - scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]); - maxScore = Math.max(maxScore, scores[i + j]); - } - } else { - res = res.add(1f).mul(0.5f).max(0); - res.intoArray(scores, i); - maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); - } - } - } - return maxScore; - } - - private float score256Bulk( - byte[] q, - float queryLowerInterval, - float queryUpperInterval, - int queryComponentSum, - float queryAdditionalCorrection, - VectorSimilarityFunction similarityFunction, - float centroidDp, - float[] scores - ) throws IOException { - quantizeScore256Bulk(q, BULK_SIZE, scores); - in.readFloats(lowerIntervals, 0, BULK_SIZE); - in.readFloats(upperIntervals, 0, BULK_SIZE); - for (int i = 0; i < BULK_SIZE; i++) { - targetComponentSums[i] = Short.toUnsignedInt(in.readShort()); - } - in.readFloats(additionalCorrections, 0, BULK_SIZE); - int limit = FLOAT_SPECIES_256.loopBound(BULK_SIZE); - int i = 0; - float ay = queryLowerInterval; - float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; - float y1 = queryComponentSum; - float maxScore = Float.NEGATIVE_INFINITY; - for (; i < limit; i += FLOAT_SPECIES_256.length()) { - var ax = FloatVector.fromArray(FLOAT_SPECIES_256, lowerIntervals, i); - var lx = 
FloatVector.fromArray(FLOAT_SPECIES_256, upperIntervals, i).sub(ax); - var targetComponentSumsVect = IntVector.fromArray(INT_SPECIES_256, targetComponentSums, i).convert(VectorOperators.I2F, 0); - var additionalCorrectionsVect = FloatVector.fromArray(FLOAT_SPECIES_256, additionalCorrections, i); - var qcDist = FloatVector.fromArray(FLOAT_SPECIES_256, scores, i); - // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * - // qcDist; - var res1 = ax.mul(ay).mul(dimensions); - var res2 = lx.mul(ay).mul(targetComponentSumsVect); - var res3 = ax.mul(ly).mul(y1); - var res4 = lx.mul(ly).mul(qcDist); - var res = res1.add(res2).add(res3).add(res4); - // For euclidean, we need to invert the score and apply the additional correction, which is - // assumed to be the squared l2norm of the centroid centered vectors. - if (similarityFunction == EUCLIDEAN) { - res = res.mul(-2).add(additionalCorrectionsVect).add(queryAdditionalCorrection).add(1f); - res = FloatVector.broadcast(FLOAT_SPECIES_256, 1).div(res).max(0); - maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); - res.intoArray(scores, i); - } else { - // For cosine and max inner product, we need to apply the additional correction, which is - // assumed to be the non-centered dot-product between the vector and the centroid - res = res.add(queryAdditionalCorrection).add(additionalCorrectionsVect).sub(centroidDp); - if (similarityFunction == MAXIMUM_INNER_PRODUCT) { - res.intoArray(scores, i); - // not sure how to do it better - for (int j = 0; j < FLOAT_SPECIES_256.length(); j++) { - scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]); - maxScore = Math.max(maxScore, scores[i + j]); - } - } else { - res = res.add(1f).mul(0.5f).max(0); - maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); - res.intoArray(scores, i); - } - } - } - return maxScore; - } } diff --git 
a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/OnHeapES91PanamaOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/OnHeapES91PanamaOSQVectorsScorer.java new file mode 100644 index 0000000000000..257c1b435349c --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/OnHeapES91PanamaOSQVectorsScorer.java @@ -0,0 +1,426 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.simdvec.internal.vectorization; + +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.FloatVector; +import jdk.incubator.vector.IntVector; +import jdk.incubator.vector.LongVector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorShape; +import jdk.incubator.vector.VectorSpecies; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.BitUtil; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; + +import java.io.IOException; + +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +/** Panamized scorer for quantized vectors stored as a {@link IndexInput}. 
*/ +public class OnHeapES91PanamaOSQVectorsScorer extends ES91OSQVectorsScorer { + + private static final VectorSpecies INT_SPECIES_128 = IntVector.SPECIES_128; + private static final VectorSpecies INT_SPECIES_256 = IntVector.SPECIES_256; + + private static final VectorSpecies LONG_SPECIES_128 = LongVector.SPECIES_128; + private static final VectorSpecies LONG_SPECIES_256 = LongVector.SPECIES_256; + + private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; + private static final VectorSpecies BYTE_SPECIES_256 = ByteVector.SPECIES_256; + + private static final VectorSpecies FLOAT_SPECIES_128 = FloatVector.SPECIES_128; + private static final VectorSpecies FLOAT_SPECIES_256 = FloatVector.SPECIES_256; + + private static final VectorSpecies FLOAT_SPECIES; + private static final VectorSpecies INT_SPECIES; + + static { + // default to platform supported bitsize + FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(PanamaESVectorUtilSupport.VECTOR_BITSIZE)); + INT_SPECIES = VectorSpecies.of(int.class, VectorShape.forBitSize(PanamaESVectorUtilSupport.VECTOR_BITSIZE)); + } + + protected final byte[] bytes; + + public OnHeapES91PanamaOSQVectorsScorer(IndexInput in, int dimensions) { + super(in, dimensions); + bytes = new byte[BULK_SIZE * length]; + } + + public long panamaQuantizeScore(byte[] q) throws IOException { + assert q.length == length * 4; + // 128 / 8 == 16 + if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { + return quantizeScore256(q); + } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { + return quantizeScore128(q); + } + } + return super.quantizeScore(q); + } + + private long quantizeScore256(byte[] q) throws IOException { + in.readBytes(bytes, 0, length); + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + if (length >= BYTE_SPECIES_256.vectorByteSize() * 2) { + int limit = 
BYTE_SPECIES_256.loopBound(length); + var sum0 = LongVector.zero(LONG_SPECIES_256); + var sum1 = LongVector.zero(LONG_SPECIES_256); + var sum2 = LongVector.zero(LONG_SPECIES_256); + var sum3 = LongVector.zero(LONG_SPECIES_256); + for (; i < limit; i += BYTE_SPECIES_256.length()) { + var vd = ByteVector.fromArray(BYTE_SPECIES_256, bytes, i).reinterpretAsLongs(); + var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs(); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + + if (length - i >= BYTE_SPECIES_128.vectorByteSize()) { + var sum0 = LongVector.zero(LONG_SPECIES_128); + var sum1 = LongVector.zero(LONG_SPECIES_128); + var sum2 = LongVector.zero(LONG_SPECIES_128); + var sum3 = LongVector.zero(LONG_SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(length); + for (; i < limit; i += BYTE_SPECIES_128.length()) { + var vd = ByteVector.fromArray(BYTE_SPECIES_128, bytes, i).reinterpretAsLongs(); + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs(); + 
sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + // process scalar tail + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { + final long value = (long) BitUtil.VH_LE_LONG.get(bytes, i); + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); + } + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { + final int value = (int) BitUtil.VH_LE_INT.get(bytes, i); + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); + } + for (; i < length; i++) { + int dValue = bytes[i] & 0xFF; + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); + subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); + } + return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + + private long quantizeScore128(byte[] q) throws IOException { + in.readBytes(bytes, 0, length); + long subRet0 = 0; + long 
subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + var sum0 = IntVector.zero(INT_SPECIES_128); + var sum1 = IntVector.zero(INT_SPECIES_128); + var sum2 = IntVector.zero(INT_SPECIES_128); + var sum3 = IntVector.zero(INT_SPECIES_128); + int limit = BYTE_SPECIES_128.loopBound(length); + int i = 0; + for (; i < limit; i += BYTE_SPECIES_128.length()) { + var vd = ByteVector.fromArray(BYTE_SPECIES_128, bytes, i).reinterpretAsInts(); + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsInts(); + sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + // process scalar tail + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { + final long value = (long) BitUtil.VH_LE_LONG.get(bytes, i); + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); + } + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { + final int value = (int) BitUtil.VH_LE_INT.get(bytes, i); + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); + 
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); + } + for (; i < length; i++) { + int dValue = bytes[i] & 0xFF; + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); + subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); + } + return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + + public void panamaQuantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException { + assert q.length == length * 4; + // 128 / 8 == 16 + if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { + quantizeScore256Bulk(q, count, scores); + return; + } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { + quantizeScore128Bulk(q, count, scores); + return; + } + } + super.quantizeScoreBulk(q, count, scores); + } + + private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IOException { + int j = 0; + for (; j < count - 15; j += BULK_SIZE) { + in.readBytes(bytes, 0, BULK_SIZE * length); + for (int iter = 0; iter < BULK_SIZE; iter++) { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + var sum0 = IntVector.zero(INT_SPECIES_128); + var sum1 = IntVector.zero(INT_SPECIES_128); + var sum2 = IntVector.zero(INT_SPECIES_128); + var sum3 = IntVector.zero(INT_SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(length); + int i = 0; + for (; i < limit; i += ByteVector.SPECIES_128.length()) { + var vd = ByteVector.fromArray(BYTE_SPECIES_128, bytes, iter * length + i).reinterpretAsInts(); + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); + var vq1 = 
ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsInts(); + sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + // process scalar tail + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { + final long value = (long) BitUtil.VH_LE_LONG.get(bytes, iter * length + i); + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); + } + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { + final int value = (int) BitUtil.VH_LE_INT.get(bytes, iter * length + i); + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); + } + for (; i < length; i++) { + int dValue = bytes[iter * length + i] & 0xFF; + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); + subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 
* length] & dValue) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); + } + scores[j + iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + } + for (; j < count; j++) { + scores[j] = quantizeScore128(q); + } + } + + private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IOException { + int j = 0; + for (; j < count - 15; j += BULK_SIZE) { + in.readBytes(bytes, 0, BULK_SIZE * length); + for (int iter = 0; iter < BULK_SIZE; iter++) { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { + int limit = ByteVector.SPECIES_256.loopBound(length); + var sum0 = LongVector.zero(LONG_SPECIES_256); + var sum1 = LongVector.zero(LONG_SPECIES_256); + var sum2 = LongVector.zero(LONG_SPECIES_256); + var sum3 = LongVector.zero(LONG_SPECIES_256); + for (; i < limit; i += ByteVector.SPECIES_256.length()) { + var vd = ByteVector.fromArray(BYTE_SPECIES_256, bytes, iter * length + i).reinterpretAsLongs(); + var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs(); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + + if (length - i >= ByteVector.SPECIES_128.vectorByteSize()) { + var sum0 = 
LongVector.zero(LONG_SPECIES_128); + var sum1 = LongVector.zero(LONG_SPECIES_128); + var sum2 = LongVector.zero(LONG_SPECIES_128); + var sum3 = LongVector.zero(LONG_SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(length); + for (; i < limit; i += ByteVector.SPECIES_128.length()) { + var vd = ByteVector.fromArray(BYTE_SPECIES_128, bytes, iter * length + i).reinterpretAsLongs(); + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs(); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + // process scalar tail + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { + final long value = (long) BitUtil.VH_LE_LONG.get(bytes, iter * length + i); + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); + } + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { + final int value = (int) BitUtil.VH_LE_INT.get(bytes, iter * length + i); + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); + 
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); + } + for (; i < length; i++) { + int dValue = bytes[iter * length + i] & 0xFF; + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); + subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); + } + scores[j + iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + } + for (; j < count; j++) { + scores[j] = quantizeScore256(q); + } + } + + protected float applyCorrectionsBulk( + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores + ) throws IOException { + in.readFloats(lowerIntervals, 0, BULK_SIZE); + in.readFloats(upperIntervals, 0, BULK_SIZE); + for (int i = 0; i < BULK_SIZE; i++) { + targetComponentSums[i] = Short.toUnsignedInt(in.readShort()); + } + in.readFloats(additionalCorrections, 0, BULK_SIZE); + int limit = FLOAT_SPECIES.loopBound(BULK_SIZE); + int i = 0; + long offset = in.getFilePointer(); + float ay = queryLowerInterval; + float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; + float y1 = queryComponentSum; + float maxScore = Float.NEGATIVE_INFINITY; + for (; i < limit; i += FLOAT_SPECIES.length()) { + var ax = FloatVector.fromArray(FLOAT_SPECIES, lowerIntervals, i); + var lx = FloatVector.fromArray(FLOAT_SPECIES, upperIntervals, i).sub(ax); + var targetComponentSumsVec = IntVector.fromArray(INT_SPECIES, targetComponentSums, i).convert(VectorOperators.I2F, 0); + var additionalCorrectionsVec = FloatVector.fromArray(FLOAT_SPECIES, additionalCorrections, i); + var qcDist = 
FloatVector.fromArray(FLOAT_SPECIES, scores, i); + // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * + // qcDist; + var res1 = ax.mul(ay).mul(dimensions); + var res2 = lx.mul(ay).mul(targetComponentSumsVec); + var res3 = ax.mul(ly).mul(y1); + var res4 = lx.mul(ly).mul(qcDist); + var res = res1.add(res2).add(res3).add(res4); + // For euclidean, we need to invert the score and apply the additional correction, which is + // assumed to be the squared l2norm of the centroid centered vectors. + if (similarityFunction == EUCLIDEAN) { + res = res.mul(-2).add(additionalCorrectionsVec).add(queryAdditionalCorrection).add(1f); + res = FloatVector.broadcast(FLOAT_SPECIES, 1).div(res).max(0); + maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); + res.intoArray(scores, i); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + res = res.add(queryAdditionalCorrection).add(additionalCorrectionsVec).sub(centroidDp); + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + res.intoArray(scores, i); + // not sure how to do it better + for (int j = 0; j < FLOAT_SPECIES.length(); j++) { + scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]); + maxScore = Math.max(maxScore, scores[i + j]); + } + } else { + res = res.add(1f).mul(0.5f).max(0); + maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); + res.intoArray(scores, i); + } + } + } + return maxScore; + } + +} diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java new file mode 100644 index 0000000000000..4471b1f759811 --- /dev/null +++ 
b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java @@ -0,0 +1,89 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.simdvec.internal.vectorization; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.IndexInput; +import org.elasticsearch.nativeaccess.NativeAccess; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; + +import static org.elasticsearch.simdvec.internal.Similarities.int4BitDotProduct; +import static org.elasticsearch.simdvec.internal.Similarities.int4BitDotProductBulk; + +/** Native scorer for quantized vectors stored as an {@link IndexInput}. 
*/ +public final class MemorySegmentES91OSQVectorsScorer extends MemorySegmentES91PanamaOSQVectorsScorer { + + private static final boolean NATIVE_SUPPORTED = NativeAccess.instance().getVectorSimilarityFunctions().isPresent(); + + public MemorySegmentES91OSQVectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) { + super(in, dimensions, memorySegment); + } + + @Override + public long quantizeScore(byte[] q) throws IOException { + assert q.length == length * 4; + if (NATIVE_SUPPORTED) { + return nativeQuantizeScore(q); + } else { + return panamaQuantizeScore(q); + } + } + + private long nativeQuantizeScore(byte[] q) throws IOException { + long initialOffset = in.getFilePointer(); + MemorySegment query = MemorySegment.ofArray(q); + long qScore = int4BitDotProduct(query, memorySegment, initialOffset, length); + in.skipBytes(length); + return qScore; + } + + @Override + public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException { + assert q.length == length * 4; + // 128 / 8 == 16 + if (NATIVE_SUPPORTED) { + nativeQuantizeScoreBulk(q, count, scores); + } else { + panamaQuantizeScoreBulk(q, count, scores); + } + } + + private void nativeQuantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException { + long initialOffset = in.getFilePointer(); + MemorySegment query = MemorySegment.ofArray(q); + MemorySegment scoresSegment = MemorySegment.ofArray(scores); + int4BitDotProductBulk(query, memorySegment, initialOffset, scoresSegment, count, length); + in.skipBytes(count * length); + } + + @Override + public float scoreBulk( + byte[] q, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores + ) throws IOException { + quantizeScoreBulk(q, BULK_SIZE, scores); + return applyCorrectionsBulk( + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + 
queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores + ); + } +} diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/vectorization/OnHeapES91OSQVectorsScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/vectorization/OnHeapES91OSQVectorsScorer.java new file mode 100644 index 0000000000000..0edcb1cefb9ae --- /dev/null +++ b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/vectorization/OnHeapES91OSQVectorsScorer.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.simdvec.internal.vectorization; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.IndexInput; +import org.elasticsearch.nativeaccess.NativeAccess; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; + +import static org.elasticsearch.simdvec.internal.Similarities.int4BitDotProduct; +import static org.elasticsearch.simdvec.internal.Similarities.int4BitDotProductBulk; + +/** Native scorer (with a Panama fallback) for quantized vectors stored as an {@link IndexInput}. 
*/ +public final class OnHeapES91OSQVectorsScorer extends OnHeapES91PanamaOSQVectorsScorer { + + private static final boolean NATIVE_SUPPORTED = NativeAccess.instance().getVectorSimilarityFunctions().isPresent(); + + public OnHeapES91OSQVectorsScorer(IndexInput in, int dimensions) { + super(in, dimensions); + } + + @Override + public long quantizeScore(byte[] q) throws IOException { + if (NATIVE_SUPPORTED) { + return nativeQuantizeScore(q); + } else { + return panamaQuantizeScore(q); + } + } + + private long nativeQuantizeScore(byte[] q) throws IOException { + in.readBytes(bytes, 0, length); + MemorySegment query = MemorySegment.ofArray(q); + MemorySegment memorySegment = MemorySegment.ofArray(bytes).asSlice(0, length); + return int4BitDotProduct(query, memorySegment, 0L, length); + } + + @Override + public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException { + if (NATIVE_SUPPORTED) { + nativeQuantizeScoreBulk(q, count, scores); + } else { + panamaQuantizeScoreBulk(q, count, scores); + } + } + + private void nativeQuantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException { + int j = 0; + MemorySegment scoresSegment = MemorySegment.ofArray(scores); + for (; j < count - 15; j += BULK_SIZE) { + in.readBytes(bytes, 0, BULK_SIZE * length); + MemorySegment query = MemorySegment.ofArray(q); + MemorySegment memorySegment = MemorySegment.ofArray(bytes); + int4BitDotProductBulk(query, memorySegment, 0L, scoresSegment.asSlice((long) j * Float.BYTES), BULK_SIZE, length); + } + for (; j < count; j++) { + scores[j] = quantizeScore(q); + } + } + + @Override + public float scoreBulk( + byte[] q, + float queryLowerInterval, + float queryUpperInterval, + 
queryComponentSum, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores + ); + } +} From ebd61aa8404ecdbbec4cfc60637f24336708d9f4 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Fri, 12 Sep 2025 10:34:18 +0100 Subject: [PATCH 02/15] Update docs/changelog/134623.yaml --- docs/changelog/134623.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/134623.yaml diff --git a/docs/changelog/134623.yaml b/docs/changelog/134623.yaml new file mode 100644 index 0000000000000..f7b402d1dc899 --- /dev/null +++ b/docs/changelog/134623.yaml @@ -0,0 +1,5 @@ +pr: 134623 +summary: Native OSQ scoring +area: Vector Search +type: enhancement +issues: [] From 203f9d629d9a5caa142e1564062d3234ff349ead Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Fri, 12 Sep 2025 17:21:45 +0100 Subject: [PATCH 03/15] Add test and return void for bulk --- .../nativeaccess/jdk/JdkVectorLibrary.java | 12 +- .../jdk/JDKVectorLibraryInt4BitTests.java | 118 ++++++++++++++++++ libs/simdvec/native/src/vec/c/aarch64/vec.c | 3 +- libs/simdvec/native/src/vec/c/amd64/vec.c | 2 +- libs/simdvec/native/src/vec/headers/vec.h | 2 +- .../simdvec/internal/Similarities.java | 4 +- 6 files changed, 129 insertions(+), 12 deletions(-) create mode 100644 libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt4BitTests.java diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java index caf69219e6a79..2dd1e585e46a3 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java @@ -110,7 +110,7 @@ public final class JdkVectorLibrary implements VectorLibrary { ); int4BitBulk$mh = downcallHandle( "int4BitBulk", - FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_LONG, ADDRESS, JAVA_INT, JAVA_INT), + 
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_LONG, ADDRESS, JAVA_INT, JAVA_INT), LinkerHelperUtil.critical() ); INSTANCE = new JdkVectorSimilarityFunctions(); @@ -172,14 +172,14 @@ private static long int4Bit(MemorySegment a, MemorySegment b, long offset, int l } } - static int int4BitDotProdBulk(MemorySegment a, MemorySegment b, long offset, MemorySegment s, int count, int length) { + static void int4BitDotProdBulk(MemorySegment a, MemorySegment b, long offset, MemorySegment s, int count, int length) { assert length >= 0; - return int4BitBulk(a, b, offset, s, count, length); + int4BitBulk(a, b, offset, s, count, length); } - private static int int4BitBulk(MemorySegment a, MemorySegment b, long offset, MemorySegment s, int count, int length) { + private static void int4BitBulk(MemorySegment a, MemorySegment b, long offset, MemorySegment s, int count, int length) { try { - return (int) JdkVectorLibrary.int4BitBulk$mh.invokeExact(a, b, offset, s, count, length); + JdkVectorLibrary.int4BitBulk$mh.invokeExact(a, b, offset, s, count, length); } catch (Throwable t) { throw new AssertionError(t); } @@ -307,7 +307,7 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) { mt = MethodType.methodType(long.class, MemorySegment.class, MemorySegment.class, long.class, int.class); DOT_HANDLE_4BIT = lookup.findStatic(JdkVectorSimilarityFunctions.class, "int4BitDotProd", mt); mt = MethodType.methodType( - int.class, + void.class, MemorySegment.class, MemorySegment.class, long.class, diff --git a/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt4BitTests.java b/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt4BitTests.java new file mode 100644 index 0000000000000..8a4dcc445f25e --- /dev/null +++ b/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt4BitTests.java @@ -0,0 +1,118 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.nativeaccess.jdk; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.nativeaccess.VectorSimilarityFunctionsTests; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.lang.foreign.MemorySegment; + +public class JDKVectorLibraryInt4BitTests extends VectorSimilarityFunctionsTests { + + public JDKVectorLibraryInt4BitTests(int size) { + super(size); + } + + @BeforeClass + public static void beforeClass() { + VectorSimilarityFunctionsTests.setup(); + } + + @AfterClass + public static void afterClass() { + VectorSimilarityFunctionsTests.cleanup(); + } + + @ParametersFactory + public static Iterable parametersFactory() { + return VectorSimilarityFunctionsTests.parametersFactory(); + } + + private static int discretize(int value, int bucket) { + return ((value + (bucket - 1)) / bucket) * bucket; + } + + public void testInt4Bin() { + assumeTrue(notSupportedMsg(), supported()); + final int length = discretize(size, 64) / 8; + final int numVecs = randomIntBetween(2, 101); + var values = new byte[numVecs][length]; + var segment = arena.allocate((long) numVecs * length); + for (int i = 0; i < numVecs; i++) { + random().nextBytes(values[i]); + MemorySegment.copy(MemorySegment.ofArray(values[i]), 0L, segment, (long) i * length, length); + } + + final int loopTimes = 1000; + byte[] query = new byte[4 * length]; + float[] scores = new float[numVecs]; + float[] scoresExpected = new float[numVecs]; + var querySegment = arena.allocate(4L * length); + for (int i = 0; i < loopTimes; i++) { 
+ int ord = randomInt(numVecs - 1); + long offset = (long) ord * length; + random().nextBytes(query); + MemorySegment.copy(MemorySegment.ofArray(query), 0L, querySegment, 0, 4 * length); + for (int j = 0; j < numVecs; j++) { + scoresExpected[j] = int4BitScalar(query, values[j], length); + } + assertEquals(scoresExpected[ord], (float) int4Bit(querySegment, segment, offset, length), 0.0f); + int4BitBulk(querySegment, segment, 0L, MemorySegment.ofArray(scores), numVecs, length); + assertArrayEquals(scoresExpected, scores, 0.0f); + } + } + + long int4Bit(MemorySegment a, MemorySegment b, long offset, int length) { + try { + return (long) getVectorDistance().int4BitDotProductHandle().invokeExact(a, b, offset, length); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + void int4BitBulk(MemorySegment a, MemorySegment b, long offset, MemorySegment scores, int count, int length) { + try { + getVectorDistance().int4BitDotProductBulkHandle().invokeExact(a, b, offset, scores, count, length); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + /** Computes the dot product of the given vectors a and b. 
*/ + static long int4BitScalar(byte[] a, byte[] b, int length) { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + for (int r = 0; r < length; r++) { + final byte value = b[r]; + subRet0 += Integer.bitCount((a[r] & value) & 0xFF); + subRet1 += Integer.bitCount((a[r + length] & value) & 0xFF); + subRet2 += Integer.bitCount((a[r + 2 * length] & value) & 0xFF); + subRet3 += Integer.bitCount((a[r + 3 * length] & value) & 0xFF); + } + return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } +} diff --git a/libs/simdvec/native/src/vec/c/aarch64/vec.c b/libs/simdvec/native/src/vec/c/aarch64/vec.c index c09f42d876305..355a28ab28eb3 100644 --- a/libs/simdvec/native/src/vec/c/aarch64/vec.c +++ b/libs/simdvec/native/src/vec/c/aarch64/vec.c @@ -337,7 +337,7 @@ EXPORT int64_t int4Bit(uint8_t* query, uint8_t* doc, int64_t offset, int length) return dot_q0 + (dot_q1 << 1) + (dot_q2 << 2) + (dot_q3 << 3); } -EXPORT int32_t int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float32_t* scores, int count, int length) { +EXPORT void int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float32_t* scores, int count, int length) { const size_t stride = (length / 8) * 8; const uint8_t* query_j0 = query; const uint8_t* query_j1 = query + length; @@ -375,5 +375,4 @@ EXPORT int32_t int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float32 } scores[idx] = (float32_t)(dot_q0 + (dot_q1 << 1) + (dot_q2 << 2) + (dot_q3 << 3)); } - return count; } diff --git a/libs/simdvec/native/src/vec/c/amd64/vec.c b/libs/simdvec/native/src/vec/c/amd64/vec.c index 6d3765c0e9caf..8114ff63f1b87 100644 --- a/libs/simdvec/native/src/vec/c/amd64/vec.c +++ b/libs/simdvec/native/src/vec/c/amd64/vec.c @@ -351,6 +351,6 @@ EXPORT int64_t int4Bit(uint8_t* query, uint8_t* doc, int64_t offset, int length) return 0; } -EXPORT int32_t int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float32_t* scores, size_t count, size_t dims) { +EXPORT void 
int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float32_t* scores, size_t count, size_t dims) { return 0; } diff --git a/libs/simdvec/native/src/vec/headers/vec.h b/libs/simdvec/native/src/vec/headers/vec.h index 132a21c12ffa2..1a88bfad0fc0c 100644 --- a/libs/simdvec/native/src/vec/headers/vec.h +++ b/libs/simdvec/native/src/vec/headers/vec.h @@ -29,5 +29,5 @@ EXPORT float sqrf32(const float *a, const float *b, size_t elementCount); EXPORT int64_t int4Bit(uint8_t* query, uint8_t* doc, int64_t offset, int length); -EXPORT int32_t int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float32_t* scores, int count, int dims); +EXPORT void int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float32_t* scores, int count, int dims); diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java index 51a7354d2be61..9f2390b6c1094 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java @@ -68,9 +68,9 @@ public static long int4BitDotProduct(MemorySegment a, MemorySegment b, long offs } } - public static int int4BitDotProductBulk(MemorySegment a, MemorySegment b, long offset, MemorySegment scores, int count, int length) { + public static void int4BitDotProductBulk(MemorySegment a, MemorySegment b, long offset, MemorySegment scores, int count, int length) { try { - return (int) INT4_BIT_DP_BULK.invokeExact(a, b, offset, scores, count, length); + INT4_BIT_DP_BULK.invokeExact(a, b, offset, scores, count, length); } catch (Throwable e) { if (e instanceof Error err) { throw err; From 1993bef5abf12847b04cfeb5eb67d5b04562bf0a Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Mon, 15 Sep 2025 10:34:55 +0100 Subject: [PATCH 04/15] iter --- libs/native/libraries/build.gradle | 2 +- libs/simdvec/native/src/vec/c/amd64/vec.c | 6 
++++-- .../vectorization/MemorySegmentES91OSQVectorsScorer.java | 9 ++++++++- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/libs/native/libraries/build.gradle b/libs/native/libraries/build.gradle index 0e11406cdc548..4d94ad6e20c73 100644 --- a/libs/native/libraries/build.gradle +++ b/libs/native/libraries/build.gradle @@ -52,7 +52,7 @@ dependencies { libs "org.elasticsearch:zstd:${zstdVersion}:linux-aarch64" libs "org.elasticsearch:zstd:${zstdVersion}:linux-x86-64" libs "org.elasticsearch:zstd:${zstdVersion}:windows-x86-64" - // libs "org.elasticsearch:vec:${vecVersion}@zip" // temporarily comment this out, if testing a locally built native lib + libs "org.elasticsearch:vec:${vecVersion}@zip" // temporarily comment this out, if testing a locally built native lib } def extractLibs = tasks.register('extractLibs', Copy) { diff --git a/libs/simdvec/native/src/vec/c/amd64/vec.c b/libs/simdvec/native/src/vec/c/amd64/vec.c index 8114ff63f1b87..eb58cf430432e 100644 --- a/libs/simdvec/native/src/vec/c/amd64/vec.c +++ b/libs/simdvec/native/src/vec/c/amd64/vec.c @@ -348,9 +348,11 @@ EXPORT float sqrf32(const float *a, const float *b, size_t elementCount) { } EXPORT int64_t int4Bit(uint8_t* query, uint8_t* doc, int64_t offset, int length) { - return 0; + // signal to the caller this is not supported + return -1; } EXPORT void int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float32_t* scores, size_t count, size_t dims) { - return 0; + // signal to the caller this is not supported + scores[0] = -1; } diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java index 4471b1f759811..f81413641aafd 100644 --- a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java +++ 
b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java @@ -41,6 +41,9 @@ private long nativeQuantizeScore(byte[] q) throws IOException { long initialOffset = in.getFilePointer(); MemorySegment query = MemorySegment.ofArray(q); long qScore = int4BitDotProduct(query, memorySegment, initialOffset, length); + if (qScore == -1) { + return panamaQuantizeScore(q); + } in.skipBytes(length); return qScore; } @@ -61,7 +64,11 @@ private void nativeQuantizeScoreBulk(byte[] q, int count, float[] scores) throws MemorySegment query = MemorySegment.ofArray(q); MemorySegment scoresSegment = MemorySegment.ofArray(scores); int4BitDotProductBulk(query, memorySegment, initialOffset, scoresSegment, count, length); - in.skipBytes(count * length); + if (scores[0] == -1) { + panamaQuantizeScoreBulk(q, count, scores); + } else { + in.skipBytes(count * length); + } } @Override From b147d396d22326bb47ee8b1e3c754c281df276d9 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Mon, 15 Sep 2025 11:03:23 +0100 Subject: [PATCH 05/15] iter --- libs/simdvec/native/src/vec/c/amd64/vec.c | 2 +- libs/simdvec/native/src/vec/headers/vec.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/simdvec/native/src/vec/c/amd64/vec.c b/libs/simdvec/native/src/vec/c/amd64/vec.c index eb58cf430432e..8197484fe2a17 100644 --- a/libs/simdvec/native/src/vec/c/amd64/vec.c +++ b/libs/simdvec/native/src/vec/c/amd64/vec.c @@ -352,7 +352,7 @@ EXPORT int64_t int4Bit(uint8_t* query, uint8_t* doc, int64_t offset, int length) return -1; } -EXPORT void int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float32_t* scores, size_t count, size_t dims) { +EXPORT void int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float_t* scores, int count, int dims) { // signal to the caller this is not supported scores[0] = -1; } diff --git a/libs/simdvec/native/src/vec/headers/vec.h b/libs/simdvec/native/src/vec/headers/vec.h index 
1a88bfad0fc0c..dbc3a62d1368a 100644 --- a/libs/simdvec/native/src/vec/headers/vec.h +++ b/libs/simdvec/native/src/vec/headers/vec.h @@ -29,5 +29,5 @@ EXPORT float sqrf32(const float *a, const float *b, size_t elementCount); EXPORT int64_t int4Bit(uint8_t* query, uint8_t* doc, int64_t offset, int length); -EXPORT void int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float32_t* scores, int count, int dims); +EXPORT void int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float_t* scores, int count, int dims); From 68a3d9b331bbcff6803ad3c1fce9e41215522c02 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Mon, 15 Sep 2025 11:13:40 +0100 Subject: [PATCH 06/15] iter --- .../nativeaccess/VectorSimilarityFunctionsTests.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java b/libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java index 3d8433bf36487..12dbd407b1982 100644 --- a/libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java +++ b/libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java @@ -74,7 +74,8 @@ public boolean supported() { && ((arch.equals("aarch64") && (osName.startsWith("Mac") || osName.equals("Linux"))) || (arch.equals("amd64") && osName.equals("Linux")))) { assertThat(vectorSimilarityFunctions, isPresent()); - return true; + // only implemented in this architecture + return arch.equals("aarch64") && (osName.startsWith("Mac")); } else { assertThat(vectorSimilarityFunctions, not(isPresent())); return false; From da9aa0e292329aa75d0060f0772b08db57287836 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Mon, 15 Sep 2025 11:18:54 +0100 Subject: [PATCH 07/15] iter --- .../VectorSimilarityFunctionsTests.java | 3 +-- .../jdk/JDKVectorLibraryInt4BitTests.java | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) 
diff --git a/libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java b/libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java index 12dbd407b1982..3d8433bf36487 100644 --- a/libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java +++ b/libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java @@ -74,8 +74,7 @@ public boolean supported() { && ((arch.equals("aarch64") && (osName.startsWith("Mac") || osName.equals("Linux"))) || (arch.equals("amd64") && osName.equals("Linux")))) { assertThat(vectorSimilarityFunctions, isPresent()); - // only implemented in this architecture - return arch.equals("aarch64") && (osName.startsWith("Mac")); + return true; } else { assertThat(vectorSimilarityFunctions, not(isPresent())); return false; diff --git a/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt4BitTests.java b/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt4BitTests.java index 8a4dcc445f25e..c25fc2e31f5ee 100644 --- a/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt4BitTests.java +++ b/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt4BitTests.java @@ -17,6 +17,9 @@ import java.lang.foreign.MemorySegment; +import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent; +import static org.hamcrest.Matchers.not; + public class JDKVectorLibraryInt4BitTests extends VectorSimilarityFunctionsTests { public JDKVectorLibraryInt4BitTests(int size) { @@ -38,6 +41,18 @@ public static Iterable parametersFactory() { return VectorSimilarityFunctionsTests.parametersFactory(); } + + @Override + public boolean supported() { + if (super.supported()) { + var arch = System.getProperty("os.arch"); + var osName = System.getProperty("os.name"); + // only implemented in this architecture + return 
arch.equals("aarch64") && (osName.startsWith("Mac")); + } + return false; + } + private static int discretize(int value, int bucket) { return ((value + (bucket - 1)) / bucket) * bucket; } From 2e1b33e6bfa03a2ffca23ba5a9b75e2c792b7c81 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Mon, 15 Sep 2025 11:20:49 +0100 Subject: [PATCH 08/15] iter --- .../nativeaccess/jdk/JDKVectorLibraryInt4BitTests.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt4BitTests.java b/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt4BitTests.java index c25fc2e31f5ee..6ca9708783810 100644 --- a/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt4BitTests.java +++ b/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt4BitTests.java @@ -17,9 +17,6 @@ import java.lang.foreign.MemorySegment; -import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent; -import static org.hamcrest.Matchers.not; - public class JDKVectorLibraryInt4BitTests extends VectorSimilarityFunctionsTests { public JDKVectorLibraryInt4BitTests(int size) { @@ -41,7 +38,6 @@ public static Iterable parametersFactory() { return VectorSimilarityFunctionsTests.parametersFactory(); } - @Override public boolean supported() { if (super.supported()) { From c16a04ed16ee9fb2bfdb1597926bdd674b584255 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Mon, 15 Sep 2025 12:29:38 +0100 Subject: [PATCH 09/15] iter --- libs/simdvec/native/src/vec/c/aarch64/vec.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/simdvec/native/src/vec/c/aarch64/vec.c b/libs/simdvec/native/src/vec/c/aarch64/vec.c index 355a28ab28eb3..a7b23a6137764 100644 --- a/libs/simdvec/native/src/vec/c/aarch64/vec.c +++ b/libs/simdvec/native/src/vec/c/aarch64/vec.c @@ -337,7 +337,7 @@ EXPORT int64_t int4Bit(uint8_t* query, uint8_t* doc, int64_t offset, int length) return 
dot_q0 + (dot_q1 << 1) + (dot_q2 << 2) + (dot_q3 << 3); } -EXPORT void int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float32_t* scores, int count, int length) { +EXPORT void int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float_t* scores, int count, int length) { const size_t stride = (length / 8) * 8; const uint8_t* query_j0 = query; const uint8_t* query_j1 = query + length; From 0c20cf2da4694507b5fb9aa400bbb719ff73c4ed Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Mon, 15 Sep 2025 13:04:15 +0100 Subject: [PATCH 10/15] better this way --- .../MemorySegmentES91OSQVectorsScorer.java | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java index f81413641aafd..15844e15d8540 100644 --- a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java +++ b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java @@ -21,7 +21,17 @@ /** Native scorer for quantized vectors stored as an {@link IndexInput}. 
*/ public final class MemorySegmentES91OSQVectorsScorer extends MemorySegmentES91PanamaOSQVectorsScorer { - private static final boolean NATIVE_SUPPORTED = NativeAccess.instance().getVectorSimilarityFunctions().isPresent(); + private static final boolean NATIVE_SUPPORTED; + static { + boolean nativeSupported = NativeAccess.instance().getVectorSimilarityFunctions().isPresent(); + if (nativeSupported) { + MemorySegment query = MemorySegment.ofArray(new byte[4]); + MemorySegment doc = MemorySegment.ofArray(new byte[1]); + long qScore = int4BitDotProduct(query, doc, 0L, 1); + nativeSupported = qScore != -1; + } + NATIVE_SUPPORTED = nativeSupported; + } public MemorySegmentES91OSQVectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) { super(in, dimensions, memorySegment); @@ -41,9 +51,6 @@ private long nativeQuantizeScore(byte[] q) throws IOException { long initialOffset = in.getFilePointer(); MemorySegment query = MemorySegment.ofArray(q); long qScore = int4BitDotProduct(query, memorySegment, initialOffset, length); - if (qScore == -1) { - return panamaQuantizeScore(q); - } in.skipBytes(length); return qScore; } @@ -64,11 +71,7 @@ private void nativeQuantizeScoreBulk(byte[] q, int count, float[] scores) throws MemorySegment query = MemorySegment.ofArray(q); MemorySegment scoresSegment = MemorySegment.ofArray(scores); int4BitDotProductBulk(query, memorySegment, initialOffset, scoresSegment, count, length); - if (scores[0] == -1) { - panamaQuantizeScoreBulk(q, count, scores); - } else { - in.skipBytes(count * length); - } + in.skipBytes(count * length); } @Override From cc9d3db93350b0fd82f6babc785900006cee9259 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Thu, 25 Sep 2025 12:02:07 +0200 Subject: [PATCH 11/15] iter --- .../benchmark/vector/OSQScorerBenchmark.java | 30 ++-- libs/native/libraries/build.gradle | 2 +- libs/simdvec/native/src/vec/c/aarch64/vec.c | 142 +++++++++++------- .../ES91OSQVectorScorerTests.java | 6 +- 4 files changed, 103 
insertions(+), 77 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java index 70c19abc0a495..52ff1772af5dc 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java @@ -150,11 +150,11 @@ public void scoreFromMemorySegmentOnlyVectorMmapVect(Blackhole bh) throws IOExce // scoreFromMemorySegmentOnlyVector(bh, inNiofs, scorerNfios); // } - @Benchmark - @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public void scoreFromMemorySegmentOnlyVectorNiofsVect(Blackhole bh) throws IOException { - scoreFromMemorySegmentOnlyVector(bh, inNiofs, scorerNfios); - } +// @Benchmark +// @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) +// public void scoreFromMemorySegmentOnlyVectorNiofsVect(Blackhole bh) throws IOException { +// scoreFromMemorySegmentOnlyVector(bh, inNiofs, scorerNfios); +// } private void scoreFromMemorySegmentOnlyVector(Blackhole bh, IndexInput in, ES91OSQVectorsScorer scorer) throws IOException { for (int j = 0; j < numQueries; j++) { @@ -197,11 +197,11 @@ public void scoreFromMemorySegmentOnlyVectorBulkMmapVect(Blackhole bh) throws IO // scoreFromMemorySegmentOnlyVectorBulk(bh, inNiofs, scorerNfios); // } - @Benchmark - @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public void scoreFromMemorySegmentOnlyVectorBulkNiofsVect(Blackhole bh) throws IOException { - scoreFromMemorySegmentOnlyVectorBulk(bh, inNiofs, scorerNfios); - } +// @Benchmark +// @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) +// public void scoreFromMemorySegmentOnlyVectorBulkNiofsVect(Blackhole bh) throws IOException { +// scoreFromMemorySegmentOnlyVectorBulk(bh, inNiofs, scorerNfios); +// } private void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh, 
IndexInput in, ES91OSQVectorsScorer scorer) throws IOException { for (int j = 0; j < numQueries; j++) { @@ -246,11 +246,11 @@ public void scoreFromMemorySegmentAllBulkMmapVect(Blackhole bh) throws IOExcepti // scoreFromMemorySegmentAllBulk(bh, inNiofs, scorerNfios); // } - @Benchmark - @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public void scoreFromMemorySegmentAllBulkNiofsVect(Blackhole bh) throws IOException { - scoreFromMemorySegmentAllBulk(bh, inNiofs, scorerNfios); - } +// @Benchmark +// @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) +// public void scoreFromMemorySegmentAllBulkNiofsVect(Blackhole bh) throws IOException { +// scoreFromMemorySegmentAllBulk(bh, inNiofs, scorerNfios); +// } private void scoreFromMemorySegmentAllBulk(Blackhole bh, IndexInput in, ES91OSQVectorsScorer scorer) throws IOException { for (int j = 0; j < numQueries; j++) { diff --git a/libs/native/libraries/build.gradle b/libs/native/libraries/build.gradle index 4d94ad6e20c73..01498a65136db 100644 --- a/libs/native/libraries/build.gradle +++ b/libs/native/libraries/build.gradle @@ -52,7 +52,7 @@ dependencies { libs "org.elasticsearch:zstd:${zstdVersion}:linux-aarch64" libs "org.elasticsearch:zstd:${zstdVersion}:linux-x86-64" libs "org.elasticsearch:zstd:${zstdVersion}:windows-x86-64" - libs "org.elasticsearch:vec:${vecVersion}@zip" // temporarily comment this out, if testing a locally built native lib + //libs "org.elasticsearch:vec:${vecVersion}@zip" // temporarily comment this out, if testing a locally built native lib } def extractLibs = tasks.register('extractLibs', Copy) { diff --git a/libs/simdvec/native/src/vec/c/aarch64/vec.c b/libs/simdvec/native/src/vec/c/aarch64/vec.c index a7b23a6137764..56d4c335fe066 100644 --- a/libs/simdvec/native/src/vec/c/aarch64/vec.c +++ b/libs/simdvec/native/src/vec/c/aarch64/vec.c @@ -300,79 +300,105 @@ EXPORT float sqrf32(const float *a, const float *b, size_t elementCount) { return result; } -EXPORT int64_t 
int4Bit(uint8_t* query, uint8_t* doc, int64_t offset, int length) { - const size_t stride = (length / 8) * 8; - uint64_t dot_q0 = 0; - uint64_t dot_q1 = 0; - uint64_t dot_q2 = 0; - uint64_t dot_q3 = 0; - const uint8_t* doc_idx = doc + offset; - const uint8_t* query_j0 = query; - const uint8_t* query_j1 = query + length; - const uint8_t* query_j2 = query + 2 * length; - const uint8_t* query_j3 = query + 3 * length; - int i = 0; - for (; i < stride; i += 8) { +static inline int64_t int4Bit_inner(uint8_t* query, uint8_t* doc, int64_t offset, int length) { + uint64_t sum0 = 0; + uint64_t sum1 = 0; + uint64_t sum2 = 0; + uint64_t sum3 = 0; + uint64_t chunk_size = 16; + + const uint8_t* doc_idx = doc + offset; + const uint8_t* query_j0 = query; + const uint8_t* query_j1 = query + length; + const uint8_t* query_j2 = query + 2 * length; + const uint8_t* query_j3 = query + 3 * length; + + int i = 0; + + if (length >= chunk_size) + { + uint64_t iters = length / chunk_size; + uint64x2_t sumP0 = vcombine_u64(vcreate_u64(0), vcreate_u64(0)); + uint64x2_t sumP1 = sumP0; + uint64x2_t sumP2 = sumP0; + uint64x2_t sumP3 = sumP0; + uint8x16_t zero = vcombine_u8(vcreate_u8(0), vcreate_u8(0)); + int j = 0; + + do + { + uint8x16_t qDot0 = zero; + uint8x16_t qDot1 = zero; + uint8x16_t qDot2 = zero; + uint8x16_t qDot3 = zero; + + /* + * After every 31 iterations we need to add the + * temporary sums (qDot0, qDot1, qDot2, qDot3) to the total sum. + * We must ensure that the temporary sums <= 255 + * and 31 * 8 bits = 248 which is OK. + */ + uint64_t limit = (j + 31 < iters) ? 
j + 31 : iters; + + for (; j < limit; j++, i+= chunk_size) + { + const uint8x16_t qv0 = vld1q_u8(query_j0 + i); + const uint8x16_t qv1 = vld1q_u8(query_j1 + i); + const uint8x16_t qv2 = vld1q_u8(query_j2 + i); + const uint8x16_t qv3 = vld1q_u8(query_j3 + i); + const uint8x16_t yv = vld1q_u8(doc_idx + i); + + qDot0 = vaddq_u8(qDot0, vcntq_u8(vandq_u8(qv0,yv))); + qDot1 = vaddq_u8(qDot1, vcntq_u8(vandq_u8(qv1,yv))); + qDot2 = vaddq_u8(qDot2, vcntq_u8(vandq_u8(qv2,yv))); + qDot3 = vaddq_u8(qDot3, vcntq_u8(vandq_u8(qv3,yv))); + } + + sumP0 = vpadalq_u32(sumP0, vpaddlq_u16(vpaddlq_u8(qDot0))); + sumP1 = vpadalq_u32(sumP1, vpaddlq_u16(vpaddlq_u8(qDot1))); + sumP2 = vpadalq_u32(sumP2, vpaddlq_u16(vpaddlq_u8(qDot2))); + sumP3 = vpadalq_u32(sumP3, vpaddlq_u16(vpaddlq_u8(qDot3))); + } + while (j < iters); + + sum0 += sumP0[0] + sumP0[1]; + sum1 += sumP1[0] + sumP1[1]; + sum2 += sumP2[0] + sumP2[1]; + sum3 += sumP3[0] + sumP3[1]; + } + + for (; i < length - 7; i += 8) { const uint64_t qv0 = *(const uint64_t*)(query_j0 + i); const uint64_t qv1 = *(const uint64_t*)(query_j1 + i); const uint64_t qv2 = *(const uint64_t*)(query_j2 + i); const uint64_t qv3 = *(const uint64_t*)(query_j3 + i); const uint64_t yv = *(const uint64_t*)(doc_idx + i); - dot_q0 += __builtin_popcountll(qv0 & yv); - dot_q1 += __builtin_popcountll(qv1 & yv); - dot_q2 += __builtin_popcountll(qv2 & yv); - dot_q3 += __builtin_popcountll(qv3 & yv); + sum0 += __builtin_popcountll(qv0 & yv); + sum1 += __builtin_popcountll(qv1 & yv); + sum2 += __builtin_popcountll(qv2 & yv); + sum3 += __builtin_popcountll(qv3 & yv); } + for (; i < length; i++) { const uint8_t qv0 = *(query_j0 + i); const uint8_t qv1 = *(query_j1 + i); const uint8_t qv2 = *(query_j2 + i); const uint8_t qv3 = *(query_j3 + i); - const uint8_t yv = *(doc_idx + i); - dot_q0 += __builtin_popcountll(qv0 & yv); - dot_q1 += __builtin_popcountll(qv1 & yv); - dot_q2 += __builtin_popcountll(qv2 & yv); - dot_q3 += __builtin_popcountll(qv3 & yv); + const uint8_t 
yv = *(doc_idx + i); + sum0 += __builtin_popcountll(qv0 & yv); + sum1 += __builtin_popcountll(qv1 & yv); + sum2 += __builtin_popcountll(qv2 & yv); + sum3 += __builtin_popcountll(qv3 & yv); } - return dot_q0 + (dot_q1 << 1) + (dot_q2 << 2) + (dot_q3 << 3); + return sum0 + (sum1 << 1) + (sum2 << 2) + (sum3 << 3); +} + +EXPORT int64_t int4Bit(uint8_t* query, uint8_t* doc, int64_t offset, int length) { + return int4Bit_inner(query, doc, offset, length); } EXPORT void int4BitBulk(uint8_t* query, uint8_t* doc, int64_t offset, float_t* scores, int count, int length) { - const size_t stride = (length / 8) * 8; - const uint8_t* query_j0 = query; - const uint8_t* query_j1 = query + length; - const uint8_t* query_j2 = query + 2 * length; - const uint8_t* query_j3 = query + 3 * length; - // assumption that the query bits are 4, and doc bits are singular for (size_t idx = 0; idx < count; idx++) { - uint64_t dot_q0 = 0; - uint64_t dot_q1 = 0; - uint64_t dot_q2 = 0; - uint64_t dot_q3 = 0; - const uint8_t* doc_idx = doc + offset + idx * length; - int i = 0; - for (; i < stride; i += 8) { - const uint64_t qv0 = *(const uint64_t*)(query_j0 + i); - const uint64_t qv1 = *(const uint64_t*)(query_j1 + i); - const uint64_t qv2 = *(const uint64_t*)(query_j2 + i); - const uint64_t qv3 = *(const uint64_t*)(query_j3 + i); - const uint64_t yv = *(const uint64_t*)(doc_idx + i); - dot_q0 += __builtin_popcountll(qv0 & yv); - dot_q1 += __builtin_popcountll(qv1 & yv); - dot_q2 += __builtin_popcountll(qv2 & yv); - dot_q3 += __builtin_popcountll(qv3 & yv); - } - for (; i < length; i++) { - const uint8_t qv0 = *(query_j0 + i); - const uint8_t qv1 = *(query_j1 + i); - const uint8_t qv2 = *(query_j2 + i); - const uint8_t qv3 = *(query_j3 + i); - const uint8_t yv = *(doc_idx + i); - dot_q0 += __builtin_popcountll(qv0 & yv); - dot_q1 += __builtin_popcountll(qv1 & yv); - dot_q2 += __builtin_popcountll(qv2 & yv); - dot_q3 += __builtin_popcountll(qv3 & yv); - } - scores[idx] = (float32_t)(dot_q0 + (dot_q1 
<< 1) + (dot_q2 << 2) + (dot_q3 << 3)); + scores[idx] = int4Bit_inner(query, doc, offset + idx * length, length); } } diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java index 4a16f03c66c95..6f5130c2df0ca 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java @@ -28,7 +28,7 @@ public class ES91OSQVectorScorerTests extends BaseVectorizationTests { public void testQuantizeScore() throws Exception { - final int dimensions = random().nextInt(1, 2000); + final int dimensions = random().nextInt(1, 10_000); final int length = BQVectorUtils.discretize(dimensions, 64) / 8; final int numVectors = random().nextInt(1, 100); final byte[] vector = new byte[length]; @@ -58,7 +58,7 @@ public void testQuantizeScore() throws Exception { } public void testScore() throws Exception { - final int maxDims = random().nextInt(1, 1000) * 2; + final int maxDims = random().nextInt(1, 10_000) * 2; final int dimensions = random().nextInt(1, maxDims); final int length = BQVectorUtils.discretize(dimensions, 64) / 8; final int numVectors = random().nextInt(10, 50); @@ -157,7 +157,7 @@ public void testScore() throws Exception { } public void testScoreBulk() throws Exception { - final int maxDims = random().nextInt(1, 1000) * 2; + final int maxDims = random().nextInt(1, 10_000) * 2; final int dimensions = random().nextInt(1, maxDims); final int length = BQVectorUtils.discretize(dimensions, 64) / 8; final int numVectors = ES91OSQVectorsScorer.BULK_SIZE * random().nextInt(1, 10); From 274e9ba25f70149c26d341e4d7ae41b414d18a1d Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 25 Sep 2025 10:13:36 +0000 Subject: [PATCH 12/15] [CI] Update transport 
version definitions --- server/src/main/resources/transport/upper_bounds/8.18.csv | 2 +- server/src/main/resources/transport/upper_bounds/8.19.csv | 2 +- server/src/main/resources/transport/upper_bounds/9.0.csv | 2 +- server/src/main/resources/transport/upper_bounds/9.1.csv | 2 +- server/src/main/resources/transport/upper_bounds/9.2.csv | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/server/src/main/resources/transport/upper_bounds/8.18.csv b/server/src/main/resources/transport/upper_bounds/8.18.csv index 4eb5140004ea6..266bfbbd3bf78 100644 --- a/server/src/main/resources/transport/upper_bounds/8.18.csv +++ b/server/src/main/resources/transport/upper_bounds/8.18.csv @@ -1 +1 @@ -initial_elasticsearch_8_18_6,8840008 +transform_check_for_dangling_tasks,8840011 diff --git a/server/src/main/resources/transport/upper_bounds/8.19.csv b/server/src/main/resources/transport/upper_bounds/8.19.csv index 476468b203875..3600b3f8c633a 100644 --- a/server/src/main/resources/transport/upper_bounds/8.19.csv +++ b/server/src/main/resources/transport/upper_bounds/8.19.csv @@ -1 +1 @@ -initial_elasticsearch_8_19_3,8841067 +transform_check_for_dangling_tasks,8841070 diff --git a/server/src/main/resources/transport/upper_bounds/9.0.csv b/server/src/main/resources/transport/upper_bounds/9.0.csv index f8f50cc6d7839..c11e6837bb813 100644 --- a/server/src/main/resources/transport/upper_bounds/9.0.csv +++ b/server/src/main/resources/transport/upper_bounds/9.0.csv @@ -1 +1 @@ -initial_elasticsearch_9_0_6,9000015 +transform_check_for_dangling_tasks,9000018 diff --git a/server/src/main/resources/transport/upper_bounds/9.1.csv b/server/src/main/resources/transport/upper_bounds/9.1.csv index 5a65f2e578156..80b97d85f7511 100644 --- a/server/src/main/resources/transport/upper_bounds/9.1.csv +++ b/server/src/main/resources/transport/upper_bounds/9.1.csv @@ -1 +1 @@ -initial_elasticsearch_9_1_4,9112007 +transform_check_for_dangling_tasks,9112009 diff --git 
a/server/src/main/resources/transport/upper_bounds/9.2.csv b/server/src/main/resources/transport/upper_bounds/9.2.csv index 49360d5e62d69..2c15e0254cbe8 100644 --- a/server/src/main/resources/transport/upper_bounds/9.2.csv +++ b/server/src/main/resources/transport/upper_bounds/9.2.csv @@ -1 +1 @@ -inference_api_eis_diagnostics,9156000 +transform_check_for_dangling_tasks,9170000 From fbee06a9650efddcd8d4dde793182ecc9312896d Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 25 Sep 2025 10:13:54 +0000 Subject: [PATCH 13/15] [CI] Auto commit changes from spotless --- .../benchmark/vector/OSQScorerBenchmark.java | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java index 52ff1772af5dc..17929d94353b9 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java @@ -150,11 +150,11 @@ public void scoreFromMemorySegmentOnlyVectorMmapVect(Blackhole bh) throws IOExce // scoreFromMemorySegmentOnlyVector(bh, inNiofs, scorerNfios); // } -// @Benchmark -// @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) -// public void scoreFromMemorySegmentOnlyVectorNiofsVect(Blackhole bh) throws IOException { -// scoreFromMemorySegmentOnlyVector(bh, inNiofs, scorerNfios); -// } + // @Benchmark + // @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + // public void scoreFromMemorySegmentOnlyVectorNiofsVect(Blackhole bh) throws IOException { + // scoreFromMemorySegmentOnlyVector(bh, inNiofs, scorerNfios); + // } private void scoreFromMemorySegmentOnlyVector(Blackhole bh, IndexInput in, ES91OSQVectorsScorer scorer) throws IOException { for (int j = 0; j < numQueries; j++) { @@ -197,11 +197,11 @@ public void 
scoreFromMemorySegmentOnlyVectorBulkMmapVect(Blackhole bh) throws IO // scoreFromMemorySegmentOnlyVectorBulk(bh, inNiofs, scorerNfios); // } -// @Benchmark -// @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) -// public void scoreFromMemorySegmentOnlyVectorBulkNiofsVect(Blackhole bh) throws IOException { -// scoreFromMemorySegmentOnlyVectorBulk(bh, inNiofs, scorerNfios); -// } + // @Benchmark + // @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + // public void scoreFromMemorySegmentOnlyVectorBulkNiofsVect(Blackhole bh) throws IOException { + // scoreFromMemorySegmentOnlyVectorBulk(bh, inNiofs, scorerNfios); + // } private void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh, IndexInput in, ES91OSQVectorsScorer scorer) throws IOException { for (int j = 0; j < numQueries; j++) { @@ -246,11 +246,11 @@ public void scoreFromMemorySegmentAllBulkMmapVect(Blackhole bh) throws IOExcepti // scoreFromMemorySegmentAllBulk(bh, inNiofs, scorerNfios); // } -// @Benchmark -// @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) -// public void scoreFromMemorySegmentAllBulkNiofsVect(Blackhole bh) throws IOException { -// scoreFromMemorySegmentAllBulk(bh, inNiofs, scorerNfios); -// } + // @Benchmark + // @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + // public void scoreFromMemorySegmentAllBulkNiofsVect(Blackhole bh) throws IOException { + // scoreFromMemorySegmentAllBulk(bh, inNiofs, scorerNfios); + // } private void scoreFromMemorySegmentAllBulk(Blackhole bh, IndexInput in, ES91OSQVectorsScorer scorer) throws IOException { for (int j = 0; j < numQueries; j++) { From e21444f8ada9386affa8d22839b458cc89cc56b4 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 2 Oct 2025 07:41:26 +0000 Subject: [PATCH 14/15] [CI] Update transport version definitions --- server/src/main/resources/transport/upper_bounds/9.2.csv | 2 +- server/src/main/resources/transport/upper_bounds/9.3.csv | 1 + 2 files changed, 2 
insertions(+), 1 deletion(-) create mode 100644 server/src/main/resources/transport/upper_bounds/9.3.csv diff --git a/server/src/main/resources/transport/upper_bounds/9.2.csv b/server/src/main/resources/transport/upper_bounds/9.2.csv index 2c15e0254cbe8..2147eab66c207 100644 --- a/server/src/main/resources/transport/upper_bounds/9.2.csv +++ b/server/src/main/resources/transport/upper_bounds/9.2.csv @@ -1 +1 @@ -transform_check_for_dangling_tasks,9170000 +initial_9.2.0,9185000 diff --git a/server/src/main/resources/transport/upper_bounds/9.3.csv b/server/src/main/resources/transport/upper_bounds/9.3.csv new file mode 100644 index 0000000000000..2147eab66c207 --- /dev/null +++ b/server/src/main/resources/transport/upper_bounds/9.3.csv @@ -0,0 +1 @@ +initial_9.2.0,9185000 From e2876328b30b59d06e9b026813cf819e465cd86d Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 17 Dec 2025 23:21:38 +0000 Subject: [PATCH 15/15] [CI] Update transport version definitions --- server/src/main/resources/transport/upper_bounds/8.18.csv | 2 +- server/src/main/resources/transport/upper_bounds/8.19.csv | 2 +- server/src/main/resources/transport/upper_bounds/9.0.csv | 2 +- server/src/main/resources/transport/upper_bounds/9.1.csv | 2 +- server/src/main/resources/transport/upper_bounds/9.2.csv | 2 +- server/src/main/resources/transport/upper_bounds/9.3.csv | 2 +- server/src/main/resources/transport/upper_bounds/9.4.csv | 1 + 7 files changed, 7 insertions(+), 6 deletions(-) create mode 100644 server/src/main/resources/transport/upper_bounds/9.4.csv diff --git a/server/src/main/resources/transport/upper_bounds/8.18.csv b/server/src/main/resources/transport/upper_bounds/8.18.csv index 266bfbbd3bf78..515078281318d 100644 --- a/server/src/main/resources/transport/upper_bounds/8.18.csv +++ b/server/src/main/resources/transport/upper_bounds/8.18.csv @@ -1 +1 @@ -transform_check_for_dangling_tasks,8840011 +initial_8.18.9,8840012 diff --git 
a/server/src/main/resources/transport/upper_bounds/8.19.csv b/server/src/main/resources/transport/upper_bounds/8.19.csv index 3600b3f8c633a..9fb7e0040a06f 100644 --- a/server/src/main/resources/transport/upper_bounds/8.19.csv +++ b/server/src/main/resources/transport/upper_bounds/8.19.csv @@ -1 +1 @@ -transform_check_for_dangling_tasks,8841070 +jina_ai_embedding_dimensions_support_added,8841078 diff --git a/server/src/main/resources/transport/upper_bounds/9.0.csv b/server/src/main/resources/transport/upper_bounds/9.0.csv index c11e6837bb813..4c1383d7ff6ad 100644 --- a/server/src/main/resources/transport/upper_bounds/9.0.csv +++ b/server/src/main/resources/transport/upper_bounds/9.0.csv @@ -1 +1 @@ -transform_check_for_dangling_tasks,9000018 +initial_9.0.9,9000019 diff --git a/server/src/main/resources/transport/upper_bounds/9.1.csv b/server/src/main/resources/transport/upper_bounds/9.1.csv index 80b97d85f7511..00fbb31455087 100644 --- a/server/src/main/resources/transport/upper_bounds/9.1.csv +++ b/server/src/main/resources/transport/upper_bounds/9.1.csv @@ -1 +1 @@ -transform_check_for_dangling_tasks,9112009 +jina_ai_embedding_dimensions_support_added,9112017 diff --git a/server/src/main/resources/transport/upper_bounds/9.2.csv b/server/src/main/resources/transport/upper_bounds/9.2.csv index 2147eab66c207..35b85ca9f8a17 100644 --- a/server/src/main/resources/transport/upper_bounds/9.2.csv +++ b/server/src/main/resources/transport/upper_bounds/9.2.csv @@ -1 +1 @@ -initial_9.2.0,9185000 +jina_ai_embedding_dimensions_support_added,9185014 diff --git a/server/src/main/resources/transport/upper_bounds/9.3.csv b/server/src/main/resources/transport/upper_bounds/9.3.csv index 2147eab66c207..94a47e0878c87 100644 --- a/server/src/main/resources/transport/upper_bounds/9.3.csv +++ b/server/src/main/resources/transport/upper_bounds/9.3.csv @@ -1 +1 @@ -initial_9.2.0,9185000 +initial_9.3.0,9250000 diff --git a/server/src/main/resources/transport/upper_bounds/9.4.csv 
b/server/src/main/resources/transport/upper_bounds/9.4.csv new file mode 100644 index 0000000000000..94a47e0878c87 --- /dev/null +++ b/server/src/main/resources/transport/upper_bounds/9.4.csv @@ -0,0 +1 @@ +initial_9.3.0,9250000