Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
2712841
Int4 to bit1 dot product functions
ldematte Dec 18, 2025
085b4ec
Wiring it up for DiskBBQ
ldematte Dec 18, 2025
3ba7513
Parametrize VectorScorerOSQBenchmark
ldematte Dec 18, 2025
564b1fc
Merge remote-tracking branch 'upstream/main' into simd/int4-bit1-dot
ldematte Jan 7, 2026
098e70e
Fixes post-merge
ldematte Jan 7, 2026
ec0029d
Copy/paste AMD native implementation(s)
ldematte Jan 8, 2026
34b81c2
Merge + Fixes
ldematte Jan 8, 2026
b07929e
Renaming, fix parameter ordering
ldematte Jan 8, 2026
9e29360
Merge remote-tracking branch 'upstream/main' into simd/int4-bit1-dot
ldematte Jan 9, 2026
c81782a
Merge remote-tracking branch 'upstream/main' into simd/int4-bit1-dot
ldematte Jan 9, 2026
352154a
AVX2 simple optimization
ldematte Jan 9, 2026
2c6dcff
AVX2 more advanced optimization
ldematte Jan 9, 2026
bf4554c
Spotelss + enable native
ldematte Jan 9, 2026
99cef84
Merge branch 'simd/int4-bit1-dot' of github.com:ldematte/elasticsearc…
ldematte Jan 9, 2026
6a8c9f8
Small avx2 improvement
ldematte Jan 9, 2026
4b511ff
Merge branch 'simd/int4-bit1-dot' of github.com:ldematte/elasticsearc…
ldematte Jan 9, 2026
dda3336
fix
ldematte Jan 9, 2026
563525b
AVX-512 optimization
ldematte Jan 9, 2026
f693d3d
Merge branch 'simd/int4-bit1-dot' of github.com:ldematte/elasticsearc…
ldematte Jan 9, 2026
45d88e2
multi-mh for MemorySegment optimization
ldematte Jan 13, 2026
3ff66f1
Merge branch 'simd/int4-bit1-dot' of github.com:ldematte/elasticsearc…
ldematte Jan 13, 2026
4b9562b
AVX-512 optimizations: masking, bulk inline, more prefetch/unrolling …
ldematte Jan 13, 2026
2dd13d9
revert AVX-512 masking, other small fixes
ldematte Jan 13, 2026
96bd88b
Merge remote-tracking branch 'upstream/main' into simd/int4-bit1-dot
ldematte Jan 13, 2026
71cde54
AVX2 tweak: more 4x prefetching (instead of 2x)
ldematte Jan 14, 2026
ab869a4
Restrict vec_caps=2 on x64 to processors supporting AVX512-VNNI and -…
ldematte Jan 15, 2026
09757fc
Merge remote-tracking branch 'upstream/main' into simd/int4-bit1-dot
ldematte Jan 16, 2026
cce3188
Add JDKVectorLibraryInt4Tests
ldematte Jan 16, 2026
062233a
[CI] Auto commit changes from spotless
Jan 16, 2026
e423a11
Merge remote-tracking branch 'upstream/main' into simd/int4-bit1-dot
ldematte Jan 16, 2026
ce11bee
spotless
ldematte Jan 16, 2026
4b758e7
Merge branch 'simd/int4-bit1-dot' of github.com:ldematte/elasticsearc…
ldematte Jan 16, 2026
cf04b0e
update simdvec library version after publish
ldematte Jan 16, 2026
70a8aa1
Update docs/changelog/140264.yaml
ldematte Jan 16, 2026
a6cf988
Merge branch 'simd/int4-bit1-dot' of github.com:ldematte/elasticsearc…
ldematte Jan 16, 2026
6eeb9b4
update changelog
ldematte Jan 16, 2026
9b41ed7
Merge remote-tracking branch 'upstream/main' into simd/int4-bit1-dot
ldematte Jan 16, 2026
f4e69dc
Update docs/changelog/140264.yaml
ldematte Jan 19, 2026
d90eae3
Merge remote-tracking branch 'upstream/main' into simd/int4-bit1-dot
ldematte Jan 19, 2026
1b06db9
Merge branch 'simd/int4-bit1-dot' of github.com:ldematte/elasticsearc…
ldematte Jan 19, 2026
da40176
Merge remote-tracking branch 'upstream/main' into simd/int4-bit1-dot
ldematte Jan 19, 2026
517b59f
ARM Neon optimized version
ldematte Jan 20, 2026
2fbd6c0
Bump vec version after publish
ldematte Jan 20, 2026
be9576c
spotless
ldematte Jan 20, 2026
8b04e31
Merge branch 'main' into simd/int4-bit1-dot
ldematte Jan 20, 2026
15b4189
PR comments
ldematte Jan 21, 2026
e006aa7
Merge branch 'simd/int4-bit1-dot' of github.com:ldematte/elasticsearc…
ldematte Jan 21, 2026
e8bfc54
Merge remote-tracking branch 'upstream/main' into simd/int4-bit1-dot
ldematte Jan 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/140264.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 140264
summary: New optimized (native) functions for BBQ Int4 to 1-bit scoring
area: Vector Search
type: enhancement
issues:
- 128523
2 changes: 1 addition & 1 deletion libs/native/libraries/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ configurations {
}

var zstdVersion = "1.5.7"
var vecVersion = "1.0.22"
var vecVersion = "1.0.24"

repositories {
exclusiveContent {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,55 @@ public interface VectorSimilarityFunctions {
*/
MethodHandle dotProductHandle7uBulkWithOffsets();

/**
* Produces a method handle returning the dot product of an int4 (half-byte) vector and
* a bit vector (one bit per element).
*
* <p> The type of the method handle will have {@code long} as return type. The type of
* its first and second arguments will be {@code MemorySegment}, whose contents are the
* vector data bytes. The type of the third argument is an {@code int}, representing the
* length of the vector data.
*/
MethodHandle dotProductHandleI1I4();

/**
* Produces a method handle which computes the dot product of several vectors.
* This bulk operation can be used to compute the dot product between a
* single int4 query vector and a number of bit vectors (one bit per element).
*
* <p> The type of the method handle will have {@code void} as return type. The type of
* its first and second arguments will be {@code MemorySegment}; the former contains the
* vector data bytes for several vectors, while the latter contains just a single vector. The
* type of the third argument is an int, representing the dimensions of each vector. The
* type of the fourth argument is an int, representing the number of vectors in the
* first argument. The type of the final argument is a MemorySegment, into which the
* computed dot product float values will be stored.
*
* <p> NOTE(review): {@code JdkVectorLibrary} names the third argument
* {@code datasetVectorLengthInBytes} and bounds-checks it as a byte count; confirm whether
* "dimensions" above should read "length in bytes".
*/
MethodHandle dotProductHandleI1I4Bulk();

/**
* Produces a method handle which computes the dot product of several vectors.
* This bulk operation can be used to compute the dot product between a single int4 query
* vector and a subset of vectors from a dataset (array of 1-bit vectors). Each
* vector to include in the operation is identified by an offset inside the dataset.
*
* <p> The type of the method handle will have {@code void} as return type. The type of
* its arguments will be:
* <ol>
* <li>a {@code MemorySegment} containing the vector data bytes for several vectors;
* in other words, a contiguous array of vectors</li>
* <li>a {@code MemorySegment} containing the vector data bytes for a single ("query") vector</li>
* <li>an {@code int}, representing the dimensions of each vector</li>
* <li>an {@code int}, representing the width (in bytes) of each vector. Or, in other words,
* the distance in bytes between two vectors inside the first param's {@code MemorySegment}</li>
* <li>a {@code MemorySegment} containing the indices of the vectors inside the first param's array
* on which we'll compute the dot product</li>
* <li>an {@code int}, representing the number of vectors for which we'll compute the dot product
* (which is equal to the size - in number of elements - of the {@code MemorySegment}s passed as
* the 5th and 7th arguments)</li>
* <li>a {@code MemorySegment}, into which the computed dot product float values will be stored</li>
* </ol>
*/
MethodHandle dotProductHandleI1I4BulkWithOffsets();

/**
* Produces a method handle returning the square distance of byte (unsigned int7) vectors.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.nio.channels.FileChannel;
import java.util.Objects;

import static java.lang.foreign.ValueLayout.ADDRESS;
import static java.lang.foreign.ValueLayout.JAVA_FLOAT;
import static java.lang.foreign.ValueLayout.JAVA_INT;
import static java.lang.foreign.ValueLayout.JAVA_LONG;
import static org.elasticsearch.nativeaccess.jdk.LinkerHelper.downcallHandle;
import static org.elasticsearch.nativeaccess.jdk.LinkerHelper.functionAddressOrNull;

Expand All @@ -39,6 +39,10 @@ public final class JdkVectorLibrary implements VectorLibrary {
static final MethodHandle dot7uBulk$mh;
static final MethodHandle dot7uBulkWithOffsets$mh;

static final MethodHandle doti1i4$mh;
static final MethodHandle doti1i4Bulk$mh;
static final MethodHandle doti1i4BulkWithOffsets$mh;

static final MethodHandle sqr7u$mh;
static final MethodHandle sqr7uBulk$mh;
static final MethodHandle sqr7uBulkWithOffsets$mh;
Expand Down Expand Up @@ -93,6 +97,7 @@ private static MethodHandle bindFunction(String functionName, int capability, Fu
logger.info("vec_caps=" + caps);
if (caps > 0) {
FunctionDescriptor intSingle = FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT);
FunctionDescriptor longSingle = FunctionDescriptor.of(JAVA_LONG, ADDRESS, ADDRESS, JAVA_INT);
FunctionDescriptor floatSingle = FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT);
FunctionDescriptor bulk = FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS);
FunctionDescriptor bulkOffsets = FunctionDescriptor.ofVoid(
Expand All @@ -109,6 +114,10 @@ private static MethodHandle bindFunction(String functionName, int capability, Fu
dot7uBulk$mh = bindFunction("vec_dot7u_bulk", caps, bulk);
dot7uBulkWithOffsets$mh = bindFunction("vec_dot7u_bulk_offsets", caps, bulkOffsets);

doti1i4$mh = bindFunction("vec_dot_int1_int4", caps, longSingle);
doti1i4Bulk$mh = bindFunction("vec_dot_int1_int4_bulk", caps, bulk);
doti1i4BulkWithOffsets$mh = bindFunction("vec_dot_int1_int4_bulk_offsets", caps, bulkOffsets);

sqr7u$mh = bindFunction("vec_sqr7u", caps, intSingle);
sqr7uBulk$mh = bindFunction("vec_sqr7u_bulk", caps, bulk);
sqr7uBulkWithOffsets$mh = bindFunction("vec_sqr7u_bulk_offsets", caps, bulkOffsets);
Expand All @@ -131,6 +140,9 @@ private static MethodHandle bindFunction(String functionName, int capability, Fu
dot7u$mh = null;
dot7uBulk$mh = null;
dot7uBulkWithOffsets$mh = null;
doti1i4$mh = null;
doti1i4Bulk$mh = null;
doti1i4BulkWithOffsets$mh = null;
sqr7u$mh = null;
sqr7uBulk$mh = null;
sqr7uBulkWithOffsets$mh = null;
Expand Down Expand Up @@ -163,7 +175,8 @@ private static final class JdkVectorSimilarityFunctions implements VectorSimilar
* <p>
* Vector data is consumed by native functions directly via a pointer to contiguous memory, represented in FFI by
* {@link MemorySegment}s, which safely encapsulate a memory location, off-heap or on-heap.
* We mainly use <b>shared</b> MemorySegments for off-heap vectors (via {@link Arena#ofShared} or via {@link FileChannel#map}).
* We mainly use <b>shared</b> MemorySegments for off-heap vectors (via {@link Arena#ofShared} or via
* {@link java.nio.channels.FileChannel#map}).
* <p>
* Shared MemorySegments have a built-in check for liveness when accessed by native functions, implemented by JIT adding some
* additional instructions before/after the native function is actually called.
Expand Down Expand Up @@ -267,6 +280,44 @@ static void dotProduct7uBulkWithOffsets(
dot7uBulkWithOffsets(a, b, length, pitch, offsets, count, result);
}

/**
 * Computes the dot product of a given int4 vector with a given bit vector (1 bit per element).
 *
 * <p>Both segments are bounds-checked before the native function is called. The sizes are
 * compared as {@code long}s: narrowing {@link MemorySegment#byteSize()} to {@code int} would
 * wrap around for segments of 2 GiB or more and make the check meaningless.
 *
 * @param a      address of the bit vector; must hold at least {@code length} bytes
 * @param query  address of the int4 vector; must hold at least {@code 4 * length} bytes
 * @param length the vector length handed to the native function
 * @return the dot product computed by the native function
 * @throws IndexOutOfBoundsException if {@code length} is negative or either segment is too small
 */
static long dotProductI1I4(MemorySegment a, MemorySegment query, int length) {
    // Same check order as before (query first, then a), but entirely in long arithmetic.
    Objects.checkFromIndexSize(0L, length * 4L, query.byteSize());
    Objects.checkFromIndexSize(0L, (long) length, a.byteSize());
    return callSingleDistanceLong(doti1i4$mh, a, query, length);
}

/**
 * Computes the dot product of a single int4 query vector with {@code count} bit vectors
 * (1 bit per element) stored back to back in {@code dataset}.
 *
 * <p>All three segments are bounds-checked before the native function is called. The required
 * sizes are computed as {@code long}s so that {@code datasetVectorLengthInBytes * count} and
 * {@code count * Float.BYTES} cannot overflow, and {@link MemorySegment#byteSize()} is not
 * narrowed to {@code int} (which would wrap for segments of 2 GiB or more).
 *
 * @param dataset                    address of the bit vectors; must hold at least
 *                                   {@code datasetVectorLengthInBytes * count} bytes
 * @param query                      address of the int4 vector; must hold at least
 *                                   {@code 4 * datasetVectorLengthInBytes} bytes
 * @param datasetVectorLengthInBytes the length in bytes of each bit vector in {@code dataset}
 * @param count                      the number of vectors in {@code dataset}
 * @param result                     receives the scores; must hold at least
 *                                   {@code count * Float.BYTES} bytes
 * @throws IndexOutOfBoundsException if a required size is negative or a segment is too small
 */
static void dotProductI1I4Bulk(
    MemorySegment dataset,
    MemorySegment query,
    int datasetVectorLengthInBytes,
    int count,
    MemorySegment result
) {
    // Widen before multiplying: the int product can overflow and slip past the check.
    Objects.checkFromIndexSize(0L, (long) datasetVectorLengthInBytes * count, dataset.byteSize());
    Objects.checkFromIndexSize(0L, datasetVectorLengthInBytes * 4L, query.byteSize());
    Objects.checkFromIndexSize(0L, (long) count * Float.BYTES, result.byteSize());
    doti1i4Bulk(dataset, query, datasetVectorLengthInBytes, count, result);
}

/**
* Bulk int4-to-bit dot product over a subset of the vectors in {@code a}, selected through
* {@code offsets}.
*
* <p>This method forwards its arguments unchanged to {@code doti1i4BulkWithOffsets}; unlike
* {@code dotProductI1I4Bulk}, it performs no bounds checks of its own.
*
* @param a       segment holding the contiguous array of bit vectors
* @param b       segment holding the single int4 ("query") vector
* @param length  forwarded as the third native argument (documented on the interface as the
*                dimensions of each vector)
* @param pitch   forwarded as the fourth native argument (documented on the interface as the
*                distance in bytes between two vectors inside {@code a})
* @param offsets segment identifying which vectors of {@code a} to score
* @param count   the number of vectors to score
* @param result  segment into which the computed values are stored
*/
static void dotProductI1I4BulkWithOffsets(
MemorySegment a,
MemorySegment b,
int length,
int pitch,
MemorySegment offsets,
int count,
MemorySegment result
) {
doti1i4BulkWithOffsets(a, b, length, pitch, offsets, count, result);
}

/**
* Computes the square distance of given unsigned int7 byte vectors.
*
Expand Down Expand Up @@ -399,6 +450,30 @@ private static void dot7uBulkWithOffsets(
}
}

/**
* Invokes the native bulk int1/int4 dot product through {@code doti1i4Bulk$mh}.
*
* <p>Arguments are forwarded unchanged; no validation happens here. The call uses
* {@code invokeExact}, so the argument types and the (void) return type at this call site
* must match the method handle's type exactly.
*
* @param a      first native argument (the caller in this class passes the bit-vector dataset)
* @param query  second native argument (the int4 query vector)
* @param length third native argument
* @param count  fourth native argument (the number of vectors)
* @param result fifth native argument (the destination for the results)
* @throws AssertionError wrapping any {@code Throwable} raised by the invocation
*/
private static void doti1i4Bulk(MemorySegment a, MemorySegment query, int length, int count, MemorySegment result) {
try {
doti1i4Bulk$mh.invokeExact(a, query, length, count, result);
} catch (Throwable t) {
// invokeExact is declared to throw Throwable; rethrow unchecked, keeping the cause
throw new AssertionError(t);
}
}

/**
* Invokes the native bulk-with-offsets int1/int4 dot product through
* {@code doti1i4BulkWithOffsets$mh}.
*
* <p>Arguments are forwarded unchanged; no validation happens here. The call uses
* {@code invokeExact}, so the argument types and the (void) return type at this call site
* must match the method handle's type exactly.
*
* @param a       first native argument (the array of bit vectors)
* @param query   second native argument (the int4 query vector)
* @param length  third native argument
* @param pitch   fourth native argument
* @param offsets fifth native argument (identifies the vectors to score)
* @param count   sixth native argument (the number of vectors to score)
* @param result  seventh native argument (the destination for the results)
* @throws AssertionError wrapping any {@code Throwable} raised by the invocation
*/
private static void doti1i4BulkWithOffsets(
MemorySegment a,
MemorySegment query,
int length,
int pitch,
MemorySegment offsets,
int count,
MemorySegment result
) {
try {
doti1i4BulkWithOffsets$mh.invokeExact(a, query, length, pitch, offsets, count, result);
} catch (Throwable t) {
// invokeExact is declared to throw Throwable; rethrow unchecked, keeping the cause
throw new AssertionError(t);
}
}

private static void sqr7uBulk(MemorySegment a, MemorySegment b, int length, int count, MemorySegment result) {
try {
sqr7uBulk$mh.invokeExact(a, b, length, count, result);
Expand Down Expand Up @@ -474,6 +549,11 @@ private static void sqrf32BulkWithOffsets(
static final MethodHandle DOT_HANDLE_7U;
static final MethodHandle DOT_HANDLE_7U_BULK;
static final MethodHandle DOT_HANDLE_7U_BULK_WITH_OFFSETS;

static final MethodHandle DOT_HANDLE_I1I4;
static final MethodHandle DOT_HANDLE_I1I4_BULK;
static final MethodHandle DOT_HANDLE_I1I4_BULK_WITH_OFFSETS;

static final MethodHandle SQR_HANDLE_7U;
static final MethodHandle SQR_HANDLE_7U_BULK;
static final MethodHandle SQR_HANDLE_7U_BULK_WITH_OFFSETS;
Expand Down Expand Up @@ -525,6 +605,18 @@ private static void sqrf32BulkWithOffsets(
bulkOffsetScorer
);

DOT_HANDLE_I1I4 = lookup.findStatic(
JdkVectorSimilarityFunctions.class,
"dotProductI1I4",
MethodType.methodType(long.class, MemorySegment.class, MemorySegment.class, int.class)
);
DOT_HANDLE_I1I4_BULK = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProductI1I4Bulk", bulkScorer);
DOT_HANDLE_I1I4_BULK_WITH_OFFSETS = lookup.findStatic(
JdkVectorSimilarityFunctions.class,
"dotProductI1I4BulkWithOffsets",
bulkOffsetScorer
);

DOT_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProductF32", singleFloatScorer);
DOT_HANDLE_FLOAT32_BULK = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProductF32Bulk", bulkScorer);
DOT_HANDLE_FLOAT32_BULK_WITH_OFFSETS = lookup.findStatic(
Expand Down Expand Up @@ -560,6 +652,21 @@ public MethodHandle dotProductHandle7uBulkWithOffsets() {
return DOT_HANDLE_7U_BULK_WITH_OFFSETS;
}

// Returns the handle looked up via findStatic for JdkVectorSimilarityFunctions#dotProductI1I4
// (type: long (MemorySegment, MemorySegment, int)).
@Override
public MethodHandle dotProductHandleI1I4() {
return DOT_HANDLE_I1I4;
}

// Returns the handle looked up via findStatic for JdkVectorSimilarityFunctions#dotProductI1I4Bulk,
// using the bulkScorer method type.
@Override
public MethodHandle dotProductHandleI1I4Bulk() {
return DOT_HANDLE_I1I4_BULK;
}

// Returns the handle looked up via findStatic for
// JdkVectorSimilarityFunctions#dotProductI1I4BulkWithOffsets, using the bulkOffsetScorer method type.
@Override
public MethodHandle dotProductHandleI1I4BulkWithOffsets() {
return DOT_HANDLE_I1I4_BULK_WITH_OFFSETS;
}

@Override
public MethodHandle squareDistanceHandle7u() {
return SQR_HANDLE_7U;
Expand Down
Loading