Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
0d88d5b
initial withByteBufferSlices
ChrisHegarty Mar 19, 2026
4444d5d
Update docs/changelog/144557.yaml
ChrisHegarty Mar 19, 2026
6bfc44d
[CI] Auto commit changes from spotless
Mar 19, 2026
a9268d0
itr
ChrisHegarty Mar 19, 2026
c8033dd
remove unnecessary path
ChrisHegarty Mar 19, 2026
92d40d7
rename gather to sparse
ChrisHegarty Mar 19, 2026
686c156
Refactor native mappers to return pointers directly
ChrisHegarty Mar 19, 2026
00f1c85
Merge branch 'main' into withByteBufferSlices
ChrisHegarty Mar 23, 2026
ee98fe0
restore
ChrisHegarty Mar 23, 2026
5d00766
bump version
ChrisHegarty Mar 23, 2026
8f2ef5e
fix and publish
ChrisHegarty Mar 23, 2026
babca93
use ValueLayout.ADDRESS - works on 32 and 64 bit platforms
ChrisHegarty Mar 23, 2026
b9d9d86
validate addresses
ChrisHegarty Mar 23, 2026
39798e4
[CI] Auto commit changes from spotless
Mar 23, 2026
eec44f3
more asserts
ChrisHegarty Mar 23, 2026
c30918a
Merge branch 'main' into withByteBufferSlices
ChrisHegarty Mar 23, 2026
63c10a3
update tests
ChrisHegarty Mar 23, 2026
100e31d
use org.hamcrest.Matchers.instanceOf
ChrisHegarty Mar 23, 2026
ce567f7
[CI] Auto commit changes from spotless
Mar 23, 2026
f27b600
int64_t* -> void* const* addresses
ChrisHegarty Mar 23, 2026
debc09f
bump library version
ChrisHegarty Mar 23, 2026
4cf3800
Merge branch 'main' into withByteBufferSlices
ChrisHegarty Mar 23, 2026
03f5eb1
Merge branch 'main' into withByteBufferSlices
ChrisHegarty Mar 24, 2026
d989532
Merge remote-tracking branch 'chegar/withByteBufferSlices' into withB…
ChrisHegarty Mar 24, 2026
28f9538
test comments
ChrisHegarty Mar 24, 2026
e440ebb
scoring with 0 numNodes
ChrisHegarty Mar 24, 2026
e47f813
checkargs
ChrisHegarty Mar 24, 2026
d4e659f
Merge branch 'main' into withByteBufferSlices
ChrisHegarty Mar 24, 2026
03448f2
itr
ChrisHegarty Mar 24, 2026
cede413
Merge branch 'main' into withByteBufferSlices
ChrisHegarty Mar 24, 2026
f75b960
Merge branch 'main' into withByteBufferSlices
ChrisHegarty Mar 25, 2026
14bfaeb
Merge branch 'main' into withByteBufferSlices
ChrisHegarty Mar 25, 2026
c7bdb13
Merge remote-tracking branch 'chegar/withByteBufferSlices' into withB…
ChrisHegarty Mar 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/144557.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
area: Vector Search
issues: []
pr: 144557
summary: Add bulk-gather native vector scoring for searchable snapshots via `DirectAccessInput`
type: enhancement
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,41 @@ public interface DirectAccessInput {
* @return {@code true} if a buffer was available and the action was invoked
*/
boolean withByteBufferSlice(long offset, long length, CheckedConsumer<ByteBuffer, IOException> action) throws IOException;

/**
* Bulk variant of {@link #withByteBufferSlice}. Resolves {@code count}
* file ranges to direct byte buffers and invokes the action while all
* buffers are valid. All ref-counting and resource management is handled
* internally.
*
* <p> The byte buffers in the array passed to the action are read-only and
* valid only for the duration of the action. Callers must not retain
* references to them after the action returns.
*
* @param offsets file byte offsets for each range
* @param length byte length of each range (same for all)
* @param count number of ranges to resolve
* @param action receives a {@code ByteBuffer[]} where entry {@code i}
* corresponds to {@code offsets[i]}
* @return {@code true} if all ranges were available and the action was
* invoked; {@code false} otherwise
*/
boolean withByteBufferSlices(long[] offsets, int length, int count, CheckedConsumer<ByteBuffer[], IOException> action)
throws IOException;

/**
* Validates the {@code offsets} and {@code count} arguments for
* {@link #withByteBufferSlices}. Throws on negative count or an
* undersized offsets array. Returns {@code true} if count is zero
* (caller should treat as a no-op), {@code false} otherwise.
*/
static boolean checkSlicesArgs(long[] offsets, int count) {
if (count < 0) {
throw new IllegalArgumentException("count must not be negative, got " + count);
}
if (offsets.length < count) {
throw new IllegalArgumentException("offsets array length " + offsets.length + " is less than count " + count);
}
return count == 0;
}
}
2 changes: 1 addition & 1 deletion libs/native/libraries/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ configurations {
}

var zstdVersion = "1.5.7"
var vecVersion = "1.0.62"
var vecVersion = "1.0.63"

repositories {
exclusiveContent {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,21 @@ enum Operation {
* <li>Score results, as 4-byte floats, in order of iteration through the offset array</li>
* </ol>
*/
BULK_OFFSETS
BULK_OFFSETS,
/**
* Scores multiple vectors against a single vector, using an array of direct memory addresses
* to locate each vector.
* <p>
* Method handle takes arguments {@code (MemorySegment, MemorySegment, int, int, MemorySegment)}:
* <ol>
* <li>Array of 8-byte longs containing the native memory address of each vector</li>
* <li>Single vector to score against</li>
* <li>Number of dimensions, or for bbq, the number of index bytes</li>
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this isn't for BBQ (yet?)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just didn't write the native code for it yet, but given how this is progressing - the native mapper template should be trivial. lemme take a look.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BBQ can use a similar technique, but the code is a bit more involved. Let's do it as a follow up.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to do this for BBQ/DiskBBQ? I think that in that case data is always contiguous...

* <li>Number of vectors to score</li>
* <li>Score results, as 4-byte floats</li>
* </ol>
*/
BULK_SPARSE
}

MethodHandle getHandle(Function function, DataType dataType, Operation operation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ private static MethodHandle bindFunction(String functionName, int capability, Fu
ADDRESS
);

FunctionDescriptor bulkSparse = FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS);

for (Function f : Function.values()) {
String funcName = switch (f) {
case COSINE -> "cos";
Expand All @@ -117,6 +119,7 @@ private static MethodHandle bindFunction(String functionName, int capability, Fu
case SINGLE -> "";
case BULK -> "_bulk";
case BULK_OFFSETS -> "_bulk_offsets";
case BULK_SPARSE -> "_bulk_sparse";
};

for (DataType type : DataType.values()) {
Expand All @@ -126,6 +129,8 @@ private static MethodHandle bindFunction(String functionName, int capability, Fu
// Only DOT_PRODUCT is needed for int4 — other functions are computed by
// applying correction terms on top of the raw dot product result.
if (f != Function.DOT_PRODUCT && type == DataType.INT4) continue;
// BULK_SPARSE only for INT7U and INT8 — no native sparse functions exist for FLOAT32 or INT4
if (op == Operation.BULK_SPARSE && (type == DataType.FLOAT32 || type == DataType.INT4)) continue;

String typeName = switch (type) {
case INT7U -> "i7u";
Expand All @@ -139,7 +144,7 @@ private static MethodHandle bindFunction(String functionName, int capability, Fu
case INT7U, INT4 -> intSingle;
case INT8, FLOAT32 -> floatSingle;
};
case BULK -> bulk;
case BULK, BULK_SPARSE -> bulk;
case BULK_OFFSETS -> bulkOffsets;
};

Expand All @@ -150,6 +155,8 @@ private static MethodHandle bindFunction(String functionName, int capability, Fu
for (BBQType type : BBQType.values()) {
// not implemented yet...
if (f == Function.COSINE || f == Function.SQUARE_DISTANCE) continue;
// BULK_SPARSE not yet implemented for BBQ
if (op == Operation.BULK_SPARSE) continue;

String typeName = switch (type) {
case D1Q4 -> "d1q4";
Expand All @@ -159,7 +166,7 @@ private static MethodHandle bindFunction(String functionName, int capability, Fu

FunctionDescriptor descriptor = switch (op) {
case SINGLE -> longSingle;
case BULK -> bulk;
case BULK, BULK_SPARSE -> bulk;
case BULK_OFFSETS -> bulkOffsets;
};

Expand Down Expand Up @@ -289,6 +296,22 @@ static boolean checkBBQBulk(
return true;
}

static boolean checkBulkSparse(
int elementBits,
MemorySegment addresses,
MemorySegment b,
int length,
int count,
MemorySegment result
) {
assert elementBits % 8 == 0 : "requires byte-aligned element types";
Objects.checkFromIndexSize(0L, (long) count * Long.BYTES, addresses.byteSize());
Objects.checkFromIndexSize(0L, (long) length * elementBits / 8, b.byteSize());
Objects.checkFromIndexSize(0L, (long) count * Float.BYTES, result.byteSize());
assert validateBulkSparse(addresses, count, length, elementBits, result);
return true;
}

static boolean checkBulkOffsets(
int elementBits,
MemorySegment a,
Expand Down Expand Up @@ -330,6 +353,29 @@ static boolean checkBBQBulkOffsets(
return true;
}

static boolean validateBulkSparse(MemorySegment addresses, int count, int length, int elementBits, MemorySegment result) {
if (count < 0) throw new IllegalArgumentException("count must be non-negative: " + count);
if (length <= 0) throw new IllegalArgumentException("length must be positive: " + length);
checkSegmentAlignment(addresses, Long.BYTES, "addresses", "long");
checkSegmentAlignment(result, Float.BYTES, "result", "float");
long vectorBytes = (long) length * elementBits / 8;
for (int i = 0; i < count; i++) {
long addr = addresses.getAtIndex(JAVA_LONG, i);
if (addr == 0) {
throw new IllegalArgumentException("address at index " + i + " is null");
}
MemorySegment vec = MemorySegment.ofAddress(addr).reinterpret(vectorBytes);
Objects.checkFromIndexSize(0L, vectorBytes, vec.byteSize());
}
return true;
}

private static void checkSegmentAlignment(MemorySegment segment, int alignment, String name, String type) {
if (segment.address() % alignment != 0) {
throw new IllegalArgumentException(name + " segment not aligned to " + type + " boundary");
}
}

private static final MethodHandle dotI7uHandle = HANDLES.get(
new OperationSignature<>(Function.DOT_PRODUCT, DataType.INT7U, Operation.SINGLE)
);
Expand Down Expand Up @@ -652,6 +698,33 @@ private static float applyCorrectionsDotProductBulk(

handlesWithChecks.put(op.getKey(), handleWithChecks);
}
case BULK_SPARSE -> {
MethodHandle handleWithChecks = switch (op.getKey().dataType()) {
case DataType dt -> {
MethodHandle checkMethod = lookup.findStatic(
JdkVectorSimilarityFunctions.class,
"checkBulkSparse",
MethodType.methodType(
boolean.class,
int.class,
MemorySegment.class,
MemorySegment.class,
int.class,
int.class,
MemorySegment.class
)
);
yield MethodHandles.guardWithTest(
MethodHandles.insertArguments(checkMethod, 0, dt.bits()),
op.getValue(),
MethodHandles.empty(op.getValue().type())
);
}
default -> throw new IllegalArgumentException("Unknown handle type " + op.getKey().dataType());
};

handlesWithChecks.put(op.getKey(), handleWithChecks);
}
case BULK_OFFSETS -> {
MethodHandle handleWithChecks = switch (op.getKey().dataType()) {
case BBQType bbq -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,67 @@ public void testInt7uBulkWithIdentityOffsets() {
assertScoresEquals(expectedScores, bulkScoresSeg);
}

// Tests bulk sparse similarity where vector addresses are slices of a single contiguous segment,
// verifying correct lookup and scoring via an address array with random ordinals.
public void testInt7uBulkSparse() {
assumeTrue(notSupportedMsg(), supported());
final int dims = size;
final int numVecs = randomIntBetween(2, 101);
var ordinals = new int[numVecs];
var vectors = new byte[numVecs][dims];
var vectorsSegment = arena.allocate((long) dims * numVecs);
var addressesSeg = arena.allocate(ValueLayout.ADDRESS.byteSize() * numVecs, ValueLayout.ADDRESS.byteAlignment());
for (int i = 0; i < numVecs; i++) {
ordinals[i] = randomInt(numVecs - 1);
randomBytesBetween(vectors[i], MIN_INT7_VALUE, MAX_INT7_VALUE);
MemorySegment.copy(vectors[i], 0, vectorsSegment, ValueLayout.JAVA_BYTE, (long) i * dims, dims);
}
for (int i = 0; i < numVecs; i++) {
addressesSeg.setAtIndex(ValueLayout.ADDRESS, i, vectorsSegment.asSlice((long) ordinals[i] * dims, dims));
}
int queryOrd = randomInt(numVecs - 1);
float[] expectedScores = new float[numVecs];
scalarSimilarityBulkWithOffsets(vectors[queryOrd], vectors, ordinals, expectedScores);

var nativeQuerySeg = vectorsSegment.asSlice((long) queryOrd * dims, dims);
var bulkScoresSeg = arena.allocate((long) numVecs * Float.BYTES);

similarityBulkSparse(addressesSeg, nativeQuerySeg, dims, numVecs, bulkScoresSeg);
assertScoresEquals(expectedScores, bulkScoresSeg);
}

// Tests bulk sparse similarity where each vector lives in its own independently allocated segment,
// ensuring the sparse path handles non-contiguous (scattered) memory correctly.
public void testInt7uBulkSparseScattered() {
assumeTrue(notSupportedMsg(), supported());
final int dims = size;
final int numVecs = randomIntBetween(2, 101);
var ordinals = new int[numVecs];
var vectors = new byte[numVecs][dims];
var segments = new MemorySegment[numVecs];
for (int i = 0; i < numVecs; i++) {
randomBytesBetween(vectors[i], MIN_INT7_VALUE, MAX_INT7_VALUE);
segments[i] = arena.allocate(dims);
MemorySegment.copy(vectors[i], 0, segments[i], ValueLayout.JAVA_BYTE, 0L, dims);
}
for (int i = 0; i < numVecs; i++) {
ordinals[i] = randomInt(numVecs - 1);
}
int queryOrd = randomInt(numVecs - 1);
float[] expectedScores = new float[numVecs];
scalarSimilarityBulkWithOffsets(vectors[queryOrd], vectors, ordinals, expectedScores);

var addressesSeg = arena.allocate(ValueLayout.ADDRESS.byteSize() * numVecs, ValueLayout.ADDRESS.byteAlignment());
for (int i = 0; i < numVecs; i++) {
addressesSeg.setAtIndex(ValueLayout.ADDRESS, i, segments[ordinals[i]]);
}
var nativeQuerySeg = segments[queryOrd];
var bulkScoresSeg = arena.allocate((long) numVecs * Float.BYTES);

similarityBulkSparse(addressesSeg, nativeQuerySeg, dims, numVecs, bulkScoresSeg);
assertScoresEquals(expectedScores, bulkScoresSeg);
}

public void testInt7uBulkWithOffsetsAndPitch() {
assumeTrue(notSupportedMsg(), supported());
final int dims = size;
Expand Down Expand Up @@ -289,6 +350,38 @@ public void testBulkIllegalDims() {
assertThat(ex.getMessage(), containsString("out of bounds for length"));
}

// Verifies that bulk sparse similarity rejects invalid arguments (undersized segments,
// negative dims/count) with appropriate out-of-bounds exceptions.
public void testBulkSparseIllegalArgs() {
assumeTrue(notSupportedMsg(), supported());
int count = 3;
var addresses = arena.allocate(ValueLayout.ADDRESS.byteSize() * count, ValueLayout.ADDRESS.byteAlignment());
var query = arena.allocate(size);
var scores = arena.allocate((long) count * Float.BYTES);

var tooSmallAddrs = arena.allocate(ValueLayout.ADDRESS.byteSize() * count - 1);
Exception ex = expectThrows(IOOBE, () -> similarityBulkSparse(tooSmallAddrs, query, size, count, scores));
assertThat(ex.getMessage(), containsString("out of bounds for length"));

var tooSmallQuery = arena.allocate(size - 1);
ex = expectThrows(IOOBE, () -> similarityBulkSparse(addresses, tooSmallQuery, size, count, scores));
assertThat(ex.getMessage(), containsString("out of bounds for length"));

var tooSmallScores = arena.allocate((long) count * Float.BYTES - 1);
ex = expectThrows(IOOBE, () -> similarityBulkSparse(addresses, query, size, count, tooSmallScores));
assertThat(ex.getMessage(), containsString("out of bounds for length"));

ex = expectThrows(IOOBE, () -> similarityBulkSparse(addresses, query, size, -1, scores));
assertThat(ex.getMessage(), containsString("out of bounds for length"));

ex = expectThrows(IOOBE, () -> similarityBulkSparse(addresses, query, -1, count, scores));
assertThat(ex.getMessage(), containsString("out of bounds for length"));

// null (zero) address in the addresses segment
ex = expectThrows(IAE, () -> similarityBulkSparse(addresses, query, size, count, scores));
assertThat(ex.getMessage(), containsString("null"));
}

int similarity(MemorySegment a, MemorySegment b, int length) {
try {
return (int) getVectorDistance().getHandle(
Expand Down Expand Up @@ -330,6 +423,18 @@ void similarityBulkWithOffsets(
}
}

void similarityBulkSparse(MemorySegment addresses, MemorySegment b, int dims, int count, MemorySegment result) {
try {
getVectorDistance().getHandle(
function,
VectorSimilarityFunctions.DataType.INT7U,
VectorSimilarityFunctions.Operation.BULK_SPARSE
).invokeExact(addresses, b, dims, count, result);
} catch (Throwable t) {
throw rethrow(t);
}
}

int scalarSimilarity(byte[] a, byte[] b) {
return switch (function) {
case DOT_PRODUCT -> dotProductScalar(a, b);
Expand Down
Loading
Loading