Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
d31d5fc
Moving caps to a separate file; move some common function to headers
ldematte Jan 22, 2026
b1a96e2
Add (disk)(bbq) scoring adjustment functions. Simple C++ for ARM, AVX…
ldematte Jan 22, 2026
68d1496
Merge remote-tracking branch 'upstream/main' into native/diskbbq-scoring
ldematte Jan 23, 2026
077e291
Small fix to native
ldematte Jan 23, 2026
20eeb7b
Wiring
ldematte Jan 23, 2026
58ee671
Improve test suite (randomization -> parametrization)
ldematte Jan 23, 2026
aaf3492
fixes
ldematte Jan 26, 2026
4a2ced8
Add initial AVX-512 implementation
ldematte Jan 26, 2026
2bb9132
Merge remote-tracking branch 'upstream/main' into native/diskbbq-scoring
ldematte Jan 27, 2026
68d6ffb
Use native scoring for x64 only
ldematte Jan 27, 2026
fb92b57
Enable for ARM too (native, non-vector)
ldematte Jan 27, 2026
219a0c0
Renaming
ldematte Jan 27, 2026
5f5eb10
Fix after merge
ldematte Jan 27, 2026
3a790ec
Merge remote-tracking branch 'upstream/main' into native/diskbbq-scoring
ldematte Jan 27, 2026
03ef882
Distance function as benchmark param
ldematte Jan 27, 2026
a58d21d
Remove not performant AVX-512 implementations
ldematte Jan 28, 2026
447cd0b
Extract common pointer unpacking
ldematte Jan 28, 2026
543f20b
Other PR feedback
ldematte Jan 28, 2026
8082e08
Merge remote-tracking branch 'upstream/main' into native/diskbbq-scoring
ldematte Jan 28, 2026
4b23b46
spotless
ldematte Jan 28, 2026
9efb846
Add indexBitScale for future compatibility (int2, int4)
ldematte Jan 29, 2026
ef8d745
Merge remote-tracking branch 'upstream/main' into native/diskbbq-scoring
ldematte Jan 29, 2026
98e8b42
Merge remote-tracking branch 'upstream/main' into native/diskbbq-scoring
ldematte Feb 4, 2026
725d00a
Update native code for int32 targetComponentSum
ldematte Feb 4, 2026
9ce5f46
Fix
ldematte Feb 4, 2026
6332c79
Merge remote-tracking branch 'upstream/main' into native/diskbbq-scoring
ldematte Feb 4, 2026
9ee4c6e
Publish vec binaries + update version
ldematte Feb 4, 2026
edc74dd
Fix
ldematte Feb 4, 2026
2609f81
Merge remote-tracking branch 'upstream/main' into native/diskbbq-scoring
ldematte Feb 4, 2026
83fd6ac
test fix
ldematte Feb 4, 2026
b68805a
Handling errors for values close to 0
ldematte Feb 5, 2026
59ec364
Merge remote-tracking branch 'upstream/main' into native/diskbbq-scoring
ldematte Feb 5, 2026
d77ed93
iter
ldematte Feb 5, 2026
2e63222
[CI] Auto commit changes from spotless
Feb 5, 2026
9557e3e
Merge remote-tracking branch 'upstream/main' into native/diskbbq-scoring
ldematte Feb 6, 2026
3b0511d
Renaming + const fixes
ldematte Feb 7, 2026
1570d8c
Publish vec binaries + update version
ldematte Feb 7, 2026
729edfd
D2Q4 impl as well
ldematte Feb 7, 2026
54009de
[CI] Auto commit changes from spotless
Feb 7, 2026
c9b0990
Adjust comments
ldematte Feb 7, 2026
2bcbb53
spotless
ldematte Feb 7, 2026
84c5240
Merge branch 'native/diskbbq-scoring' of github.com:ldematte/elastics…
ldematte Feb 7, 2026
87264b5
Merge remote-tracking branch 'upstream/main' into native/diskbbq-scoring
ldematte Feb 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/native/libraries/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ configurations {
}

var zstdVersion = "1.5.7"
var vecVersion = "1.0.34"
var vecVersion = "1.0.36"

repositories {
exclusiveContent {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,10 @@ enum Operation {
MethodHandle getHandle(Function function, DataType dataType, Operation operation);

MethodHandle getHandle(Function function, BBQType bbqType, Operation operation);

MethodHandle applyCorrectionsEuclideanBulk();

MethodHandle applyCorrectionsMaxInnerProductBulk();

MethodHandle applyCorrectionsDotProductBulk();
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ private record OperationSignature<E extends Enum<E>>(Function function, E dataTy

private static final Map<OperationSignature<?>, MethodHandle> HANDLES;

static final MethodHandle applyCorrectionsEuclideanBulk$mh;
static final MethodHandle applyCorrectionsMaxInnerProductBulk$mh;
static final MethodHandle applyCorrectionsDotProductBulk$mh;

private static final JdkVectorSimilarityFunctions INSTANCE;

/**
Expand Down Expand Up @@ -156,6 +160,26 @@ private static MethodHandle bindFunction(String functionName, int capability, Fu
}

HANDLES = Collections.unmodifiableMap(handles);

FunctionDescriptor score = FunctionDescriptor.of(
JAVA_FLOAT,
ADDRESS, // corrections
JAVA_INT, // bulkSize,
JAVA_INT, // dimensions,
JAVA_FLOAT, // queryLowerInterval,
JAVA_FLOAT, // queryUpperInterval,
JAVA_INT, // queryComponentSum,
JAVA_FLOAT, // queryAdditionalCorrection,
JAVA_FLOAT, // queryBitScale,
JAVA_FLOAT, // indexBitScale,
JAVA_FLOAT, // centroidDp,
ADDRESS // scores
);

applyCorrectionsEuclideanBulk$mh = bindFunction("diskbbq_apply_corrections_euclidean_bulk", caps, score);
applyCorrectionsMaxInnerProductBulk$mh = bindFunction("diskbbq_apply_corrections_maximum_inner_product_bulk", caps, score);
applyCorrectionsDotProductBulk$mh = bindFunction("diskbbq_apply_corrections_dot_product_bulk", caps, score);

INSTANCE = new JdkVectorSimilarityFunctions();
} else {
if (caps < 0) {
Expand All @@ -164,6 +188,9 @@ private static MethodHandle bindFunction(String functionName, int capability, Fu
enable them in your OS/Hypervisor/VM/container""");
}
HANDLES = null;
applyCorrectionsEuclideanBulk$mh = null;
applyCorrectionsMaxInnerProductBulk$mh = null;
applyCorrectionsDotProductBulk$mh = null;
INSTANCE = null;
}
} catch (Throwable t) {
Expand Down Expand Up @@ -379,8 +406,108 @@ private static void checkByteSize(MemorySegment a, MemorySegment b) {
}
}

private static float applyCorrectionsEuclideanBulk(
MemorySegment corrections,
int bulkSize,
int dimensions,
float queryLowerInterval,
float queryUpperInterval,
int queryComponentSum,
float queryAdditionalCorrection,
float queryBitScale,
float indexBitScale,
float centroidDp,
MemorySegment scores
) {
try {
return (float) applyCorrectionsEuclideanBulk$mh.invokeExact(
corrections,
bulkSize,
dimensions,
queryLowerInterval,
queryUpperInterval,
queryComponentSum,
queryAdditionalCorrection,
queryBitScale,
indexBitScale,
centroidDp,
scores
);
} catch (Throwable t) {
throw new AssertionError(t);
}
}

private static float applyCorrectionsMaxInnerProductBulk(
MemorySegment corrections,
int bulkSize,
int dimensions,
float queryLowerInterval,
float queryUpperInterval,
int queryComponentSum,
float queryAdditionalCorrection,
float queryBitScale,
float indexBitScale,
float centroidDp,
MemorySegment scores
) {
try {
return (float) applyCorrectionsMaxInnerProductBulk$mh.invokeExact(
corrections,
bulkSize,
dimensions,
queryLowerInterval,
queryUpperInterval,
queryComponentSum,
queryAdditionalCorrection,
queryBitScale,
indexBitScale,
centroidDp,
scores
);
} catch (Throwable t) {
throw new AssertionError(t);
}
}

private static float applyCorrectionsDotProductBulk(
MemorySegment corrections,
int bulkSize,
int dimensions,
float queryLowerInterval,
float queryUpperInterval,
int queryComponentSum,
float queryAdditionalCorrection,
float queryBitScale,
float indexBitScale,
float centroidDp,
MemorySegment scores
) {
try {
return (float) applyCorrectionsDotProductBulk$mh.invokeExact(
corrections,
bulkSize,
dimensions,
queryLowerInterval,
queryUpperInterval,
queryComponentSum,
queryAdditionalCorrection,
queryBitScale,
indexBitScale,
centroidDp,
scores
);
} catch (Throwable t) {
throw new AssertionError(t);
}
}

private static final Map<OperationSignature<?>, MethodHandle> HANDLES_WITH_CHECKS;

static final MethodHandle APPLY_CORRECTIONS_EUCLIDEAN_HANDLE_BULK;
static final MethodHandle APPLY_CORRECTIONS_MAX_INNER_PRODUCT_HANDLE_BULK;
static final MethodHandle APPLY_CORRECTIONS_DOT_PRODUCT_HANDLE_BULK;

static {
MethodHandles.Lookup lookup = MethodHandles.lookup();

Expand Down Expand Up @@ -535,6 +662,37 @@ private static void checkByteSize(MemorySegment a, MemorySegment b) {
}

HANDLES_WITH_CHECKS = Collections.unmodifiableMap(handlesWithChecks);

MethodType scoringFunction = MethodType.methodType(
float.class,
MemorySegment.class,
int.class,
int.class,
float.class,
float.class,
int.class,
float.class,
float.class,
float.class,
float.class,
MemorySegment.class
);

APPLY_CORRECTIONS_EUCLIDEAN_HANDLE_BULK = lookup.findStatic(
JdkVectorSimilarityFunctions.class,
"applyCorrectionsEuclideanBulk",
scoringFunction
);
APPLY_CORRECTIONS_MAX_INNER_PRODUCT_HANDLE_BULK = lookup.findStatic(
JdkVectorSimilarityFunctions.class,
"applyCorrectionsMaxInnerProductBulk",
scoringFunction
);
APPLY_CORRECTIONS_DOT_PRODUCT_HANDLE_BULK = lookup.findStatic(
JdkVectorSimilarityFunctions.class,
"applyCorrectionsDotProductBulk",
scoringFunction
);
} catch (ReflectiveOperationException e) {
throw new AssertionError(e);
}
Expand All @@ -555,5 +713,20 @@ public MethodHandle getHandle(Function function, BBQType bbqType, Operation oper
if (mh == null) throw new IllegalArgumentException("Signature not implemented: " + key);
return mh;
}

@Override
public MethodHandle applyCorrectionsEuclideanBulk() {
return APPLY_CORRECTIONS_EUCLIDEAN_HANDLE_BULK;
}

@Override
public MethodHandle applyCorrectionsMaxInnerProductBulk() {
return APPLY_CORRECTIONS_MAX_INNER_PRODUCT_HANDLE_BULK;
}

@Override
public MethodHandle applyCorrectionsDotProductBulk() {
return APPLY_CORRECTIONS_DOT_PRODUCT_HANDLE_BULK;
}
}
}
2 changes: 1 addition & 1 deletion libs/simdvec/native/publish_vec_binaries.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then
exit 1;
fi

VERSION="1.0.34"
VERSION="1.0.36"
ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}"
TEMP=$(mktemp -d)

Expand Down
141 changes: 141 additions & 0 deletions libs/simdvec/native/src/vec/c/aarch64/score_1.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

#include <stddef.h>
#include <arm_neon.h>
#include <math.h>
#include <limits>
#include "vec.h"
#include "vec_common.h"
#include "aarch64/aarch64_vec_common.h"

#include "score_common.h"

EXPORT f32_t diskbbq_apply_corrections_euclidean_bulk(
const int8_t* corrections,
const int32_t bulkSize,
const int32_t dimensions,
const f32_t queryLowerInterval,
const f32_t queryUpperInterval,
const int32_t queryComponentSum,
const f32_t queryAdditionalCorrection,
const f32_t queryBitScale,
const f32_t indexBitScale,
const f32_t centroidDp,
f32_t* scores
) {
f32_t maxScore = -std::numeric_limits<f32_t>::infinity();

const corrections_t c = unpack_corrections(corrections, bulkSize);

int i = 0;
for (; i < bulkSize; ++i) {
f32_t score = apply_corrections_euclidean_inner(
dimensions,
queryLowerInterval,
queryUpperInterval,
queryComponentSum,
queryAdditionalCorrection,
queryBitScale,
indexBitScale,
centroidDp,
*(c.lowerIntervals + i),
*(c.upperIntervals + i),
*(c.targetComponentSums + i),
*(c.additionalCorrections + i),
*(scores + i)
);
*(scores + i) = score;
maxScore = fmax(maxScore, score);
}

return maxScore;
}

EXPORT f32_t diskbbq_apply_corrections_maximum_inner_product_bulk(
const int8_t* corrections,
const int32_t bulkSize,
const int32_t dimensions,
const f32_t queryLowerInterval,
const f32_t queryUpperInterval,
const int32_t queryComponentSum,
const f32_t queryAdditionalCorrection,
const f32_t queryBitScale,
const f32_t indexBitScale,
const f32_t centroidDp,
f32_t* scores
) {
f32_t maxScore = -std::numeric_limits<f32_t>::infinity();

const corrections_t c = unpack_corrections(corrections, bulkSize);

int i = 0;
for (; i < bulkSize; ++i) {
f32_t score = apply_corrections_maximum_inner_product_inner(
dimensions,
queryLowerInterval,
queryUpperInterval,
queryComponentSum,
queryAdditionalCorrection,
queryBitScale,
indexBitScale,
centroidDp,
*(c.lowerIntervals + i),
*(c.upperIntervals + i),
*(c.targetComponentSums + i),
*(c.additionalCorrections + i),
*(scores + i)
);
*(scores + i) = score;
maxScore = fmax(maxScore, score);
}

return maxScore;
}

EXPORT f32_t diskbbq_apply_corrections_dot_product_bulk(
const int8_t* corrections,
const int32_t bulkSize,
const int32_t dimensions,
const f32_t queryLowerInterval,
const f32_t queryUpperInterval,
const int32_t queryComponentSum,
const f32_t queryAdditionalCorrection,
const f32_t queryBitScale,
const f32_t indexBitScale,
const f32_t centroidDp,
f32_t* scores
) {
f32_t maxScore = -std::numeric_limits<f32_t>::infinity();

const corrections_t c = unpack_corrections(corrections, bulkSize);

int i = 0;
for (; i < bulkSize; ++i) {
f32_t score = apply_corrections_dot_product_inner(
dimensions,
queryLowerInterval,
queryUpperInterval,
queryComponentSum,
queryAdditionalCorrection,
queryBitScale,
indexBitScale,
centroidDp,
*(c.lowerIntervals + i),
*(c.upperIntervals + i),
*(c.targetComponentSums + i),
*(c.additionalCorrections + i),
*(scores + i)
);
*(scores + i) = score;
maxScore = fmax(maxScore, score);
}

return maxScore;
}
Loading