Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -134,27 +134,27 @@ public void teardown() throws IOException {
IOUtils.close(dirMmap, inMmap, dirNiofs, inNiofs);
}

@Benchmark
public void scoreFromMemorySegmentOnlyVectorMmapScalar(Blackhole bh) throws IOException {
scoreFromMemorySegmentOnlyVector(bh, inMmap, scorerMmap);
}
// @Benchmark
// public void scoreFromMemorySegmentOnlyVectorMmapScalar(Blackhole bh) throws IOException {
// scoreFromMemorySegmentOnlyVector(bh, inMmap, scorerMmap);
// }

@Benchmark
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public void scoreFromMemorySegmentOnlyVectorMmapVect(Blackhole bh) throws IOException {
scoreFromMemorySegmentOnlyVector(bh, inMmap, scorerMmap);
}

@Benchmark
public void scoreFromMemorySegmentOnlyVectorNiofsScalar(Blackhole bh) throws IOException {
scoreFromMemorySegmentOnlyVector(bh, inNiofs, scorerNfios);
}
// @Benchmark
// public void scoreFromMemorySegmentOnlyVectorNiofsScalar(Blackhole bh) throws IOException {
// scoreFromMemorySegmentOnlyVector(bh, inNiofs, scorerNfios);
// }

@Benchmark
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public void scoreFromMemorySegmentOnlyVectorNiofsVect(Blackhole bh) throws IOException {
scoreFromMemorySegmentOnlyVector(bh, inNiofs, scorerNfios);
}
// @Benchmark
// @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
// public void scoreFromMemorySegmentOnlyVectorNiofsVect(Blackhole bh) throws IOException {
// scoreFromMemorySegmentOnlyVector(bh, inNiofs, scorerNfios);
// }

private void scoreFromMemorySegmentOnlyVector(Blackhole bh, IndexInput in, ES91OSQVectorsScorer scorer) throws IOException {
for (int j = 0; j < numQueries; j++) {
Expand All @@ -181,27 +181,27 @@ private void scoreFromMemorySegmentOnlyVector(Blackhole bh, IndexInput in, ES91O
}
}

@Benchmark
public void scoreFromMemorySegmentOnlyVectorBulkMmapScalar(Blackhole bh) throws IOException {
scoreFromMemorySegmentOnlyVectorBulk(bh, inMmap, scorerMmap);
}
// @Benchmark
// public void scoreFromMemorySegmentOnlyVectorBulkMmapScalar(Blackhole bh) throws IOException {
// scoreFromMemorySegmentOnlyVectorBulk(bh, inMmap, scorerMmap);
// }

@Benchmark
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public void scoreFromMemorySegmentOnlyVectorBulkMmapVect(Blackhole bh) throws IOException {
scoreFromMemorySegmentOnlyVectorBulk(bh, inMmap, scorerMmap);
}

@Benchmark
public void scoreFromMemorySegmentOnlyVectorBulkNiofsScalar(Blackhole bh) throws IOException {
scoreFromMemorySegmentOnlyVectorBulk(bh, inNiofs, scorerNfios);
}
// @Benchmark
// public void scoreFromMemorySegmentOnlyVectorBulkNiofsScalar(Blackhole bh) throws IOException {
// scoreFromMemorySegmentOnlyVectorBulk(bh, inNiofs, scorerNfios);
// }

@Benchmark
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public void scoreFromMemorySegmentOnlyVectorBulkNiofsVect(Blackhole bh) throws IOException {
scoreFromMemorySegmentOnlyVectorBulk(bh, inNiofs, scorerNfios);
}
// @Benchmark
// @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
// public void scoreFromMemorySegmentOnlyVectorBulkNiofsVect(Blackhole bh) throws IOException {
// scoreFromMemorySegmentOnlyVectorBulk(bh, inNiofs, scorerNfios);
// }

private void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh, IndexInput in, ES91OSQVectorsScorer scorer) throws IOException {
for (int j = 0; j < numQueries; j++) {
Expand Down Expand Up @@ -230,27 +230,27 @@ private void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh, IndexInput in, E
}
}

@Benchmark
public void scoreFromMemorySegmentAllBulkMmapScalar(Blackhole bh) throws IOException {
scoreFromMemorySegmentAllBulk(bh, inMmap, scorerMmap);
}
// @Benchmark
// public void scoreFromMemorySegmentAllBulkMmapScalar(Blackhole bh) throws IOException {
// scoreFromMemorySegmentAllBulk(bh, inMmap, scorerMmap);
// }

@Benchmark
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public void scoreFromMemorySegmentAllBulkMmapVect(Blackhole bh) throws IOException {
scoreFromMemorySegmentAllBulk(bh, inMmap, scorerMmap);
}

@Benchmark
public void scoreFromMemorySegmentAllBulkNiofsScalar(Blackhole bh) throws IOException {
scoreFromMemorySegmentAllBulk(bh, inNiofs, scorerNfios);
}
// @Benchmark
// public void scoreFromMemorySegmentAllBulkNiofsScalar(Blackhole bh) throws IOException {
// scoreFromMemorySegmentAllBulk(bh, inNiofs, scorerNfios);
// }

@Benchmark
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public void scoreFromMemorySegmentAllBulkNiofsVect(Blackhole bh) throws IOException {
scoreFromMemorySegmentAllBulk(bh, inNiofs, scorerNfios);
}
// @Benchmark
// @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
// public void scoreFromMemorySegmentAllBulkNiofsVect(Blackhole bh) throws IOException {
// scoreFromMemorySegmentAllBulk(bh, inNiofs, scorerNfios);
// }

private void scoreFromMemorySegmentAllBulk(Blackhole bh, IndexInput in, ES91OSQVectorsScorer scorer) throws IOException {
for (int j = 0; j < numQueries; j++) {
Expand Down
5 changes: 5 additions & 0 deletions docs/changelog/134623.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 134623
summary: Native OSQ scoring
area: Vector Search
type: enhancement
issues: []
2 changes: 1 addition & 1 deletion libs/native/libraries/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ dependencies {
libs "org.elasticsearch:zstd:${zstdVersion}:linux-aarch64"
libs "org.elasticsearch:zstd:${zstdVersion}:linux-x86-64"
libs "org.elasticsearch:zstd:${zstdVersion}:windows-x86-64"
libs "org.elasticsearch:vec:${vecVersion}@zip" // temporarily comment this out, if testing a locally built native lib
//libs "org.elasticsearch:vec:${vecVersion}@zip" // temporarily comment this out, if testing a locally built native lib
}

def extractLibs = tasks.register('extractLibs', Copy) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,8 @@ public interface VectorSimilarityFunctions {
* 4-byte float32 elements.
*/
MethodHandle squareDistanceHandleFloat32();

MethodHandle int4BitDotProductHandle();

MethodHandle int4BitDotProductBulkHandle();
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import static java.lang.foreign.ValueLayout.ADDRESS;
import static java.lang.foreign.ValueLayout.JAVA_FLOAT;
import static java.lang.foreign.ValueLayout.JAVA_INT;
import static java.lang.foreign.ValueLayout.JAVA_LONG;
import static org.elasticsearch.nativeaccess.jdk.LinkerHelper.downcallHandle;

public final class JdkVectorLibrary implements VectorLibrary {
Expand All @@ -36,6 +37,8 @@ public final class JdkVectorLibrary implements VectorLibrary {
static final MethodHandle cosf32$mh;
static final MethodHandle dotf32$mh;
static final MethodHandle sqrf32$mh;
static final MethodHandle int4Bit$mh;
static final MethodHandle int4BitBulk$mh;

public static final JdkVectorSimilarityFunctions INSTANCE;

Expand Down Expand Up @@ -100,6 +103,16 @@ public final class JdkVectorLibrary implements VectorLibrary {
LinkerHelperUtil.critical()
);
}
int4Bit$mh = downcallHandle(
"int4Bit",
FunctionDescriptor.of(JAVA_LONG, ADDRESS, ADDRESS, JAVA_LONG, JAVA_INT),
LinkerHelperUtil.critical()
);
int4BitBulk$mh = downcallHandle(
"int4BitBulk",
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_LONG, ADDRESS, JAVA_INT, JAVA_INT),
LinkerHelperUtil.critical()
);
INSTANCE = new JdkVectorSimilarityFunctions();
} else {
if (caps < 0) {
Expand All @@ -112,6 +125,8 @@ public final class JdkVectorLibrary implements VectorLibrary {
cosf32$mh = null;
dotf32$mh = null;
sqrf32$mh = null;
int4Bit$mh = null;
int4BitBulk$mh = null;
INSTANCE = null;
}
} catch (Throwable t) {
Expand Down Expand Up @@ -142,6 +157,34 @@ static int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
return dot7u(a, b, length);
}

static long int4BitDotProd(MemorySegment a, MemorySegment b, long offset, int length) {
if (a.byteSize() != 4L * length) {
throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + 4L * length);
}
return int4Bit(a, b, offset, length);
}

private static long int4Bit(MemorySegment a, MemorySegment b, long offset, int length) {
try {
return (long) JdkVectorLibrary.int4Bit$mh.invokeExact(a, b, offset, length);
} catch (Throwable t) {
throw new AssertionError(t);
}
}

static void int4BitDotProdBulk(MemorySegment a, MemorySegment b, long offset, MemorySegment s, int count, int length) {
assert length >= 0;
int4BitBulk(a, b, offset, s, count, length);
}

private static void int4BitBulk(MemorySegment a, MemorySegment b, long offset, MemorySegment s, int count, int length) {
try {
JdkVectorLibrary.int4BitBulk$mh.invokeExact(a, b, offset, s, count, length);
} catch (Throwable t) {
throw new AssertionError(t);
}
}

/**
* Computes the square distance of given unsigned int7 byte vectors.
*
Expand Down Expand Up @@ -247,6 +290,8 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
static final MethodHandle COS_HANDLE_FLOAT32;
static final MethodHandle DOT_HANDLE_FLOAT32;
static final MethodHandle SQR_HANDLE_FLOAT32;
static final MethodHandle DOT_HANDLE_4BIT;
static final MethodHandle DOT_HANDLE_4BIT_BULK;

static {
try {
Expand All @@ -259,6 +304,19 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
COS_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "cosineF32", mt);
DOT_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProductF32", mt);
SQR_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistanceF32", mt);
mt = MethodType.methodType(long.class, MemorySegment.class, MemorySegment.class, long.class, int.class);
DOT_HANDLE_4BIT = lookup.findStatic(JdkVectorSimilarityFunctions.class, "int4BitDotProd", mt);
mt = MethodType.methodType(
void.class,
MemorySegment.class,
MemorySegment.class,
long.class,
MemorySegment.class,
int.class,
int.class
);

DOT_HANDLE_4BIT_BULK = lookup.findStatic(JdkVectorSimilarityFunctions.class, "int4BitDotProdBulk", mt);
} catch (NoSuchMethodException | IllegalAccessException e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -288,5 +346,15 @@ public MethodHandle dotProductHandleFloat32() {
public MethodHandle squareDistanceHandleFloat32() {
return SQR_HANDLE_FLOAT32;
}

@Override
public MethodHandle int4BitDotProductHandle() {
return DOT_HANDLE_4BIT;
}

@Override
public MethodHandle int4BitDotProductBulkHandle() {
return DOT_HANDLE_4BIT_BULK;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.nativeaccess.jdk;

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.elasticsearch.nativeaccess.VectorSimilarityFunctionsTests;
import org.junit.AfterClass;
import org.junit.BeforeClass;

import java.lang.foreign.MemorySegment;

public class JDKVectorLibraryInt4BitTests extends VectorSimilarityFunctionsTests {

public JDKVectorLibraryInt4BitTests(int size) {
super(size);
}

@BeforeClass
public static void beforeClass() {
VectorSimilarityFunctionsTests.setup();
}

@AfterClass
public static void afterClass() {
VectorSimilarityFunctionsTests.cleanup();
}

@ParametersFactory
public static Iterable<Object[]> parametersFactory() {
return VectorSimilarityFunctionsTests.parametersFactory();
}

@Override
public boolean supported() {
if (super.supported()) {
var arch = System.getProperty("os.arch");
var osName = System.getProperty("os.name");
// only implemented in this architecture
return arch.equals("aarch64") && (osName.startsWith("Mac"));
}
return false;
}

private static int discretize(int value, int bucket) {
return ((value + (bucket - 1)) / bucket) * bucket;
}

public void testInt4Bin() {
assumeTrue(notSupportedMsg(), supported());
final int length = discretize(size, 64) / 8;
final int numVecs = randomIntBetween(2, 101);
var values = new byte[numVecs][length];
var segment = arena.allocate((long) numVecs * length);
for (int i = 0; i < numVecs; i++) {
random().nextBytes(values[i]);
MemorySegment.copy(MemorySegment.ofArray(values[i]), 0L, segment, (long) i * length, length);
}

final int loopTimes = 1000;
byte[] query = new byte[4 * length];
float[] scores = new float[numVecs];
float[] scoresExpected = new float[numVecs];
var querySegment = arena.allocate(4L * length);
for (int i = 0; i < loopTimes; i++) {
int ord = randomInt(numVecs - 1);
long offset = (long) ord * length;
random().nextBytes(query);
MemorySegment.copy(MemorySegment.ofArray(query), 0L, querySegment, 0, 4 * length);
for (int j = 0; j < numVecs; j++) {
scoresExpected[j] = int4BitScalar(query, values[j], length);
}
assertEquals(scoresExpected[ord], (float) int4Bit(querySegment, segment, offset, length), 0.0f);
int4BitBulk(querySegment, segment, 0L, MemorySegment.ofArray(scores), numVecs, length);
assertArrayEquals(scoresExpected, scores, 0.0f);
}
}

long int4Bit(MemorySegment a, MemorySegment b, long offset, int length) {
try {
return (long) getVectorDistance().int4BitDotProductHandle().invokeExact(a, b, offset, length);
} catch (Throwable e) {
if (e instanceof Error err) {
throw err;
} else if (e instanceof RuntimeException re) {
throw re;
} else {
throw new RuntimeException(e);
}
}
}

void int4BitBulk(MemorySegment a, MemorySegment b, long offset, MemorySegment scores, int count, int length) {
try {
getVectorDistance().int4BitDotProductBulkHandle().invokeExact(a, b, offset, scores, count, length);
} catch (Throwable e) {
if (e instanceof Error err) {
throw err;
} else if (e instanceof RuntimeException re) {
throw re;
} else {
throw new RuntimeException(e);
}
}
}

/** Computes the dot product of the given vectors a and b. */
static long int4BitScalar(byte[] a, byte[] b, int length) {
long subRet0 = 0;
long subRet1 = 0;
long subRet2 = 0;
long subRet3 = 0;
for (int r = 0; r < length; r++) {
final byte value = b[r];
subRet0 += Integer.bitCount((a[r] & value) & 0xFF);
subRet1 += Integer.bitCount((a[r + length] & value) & 0xFF);
subRet2 += Integer.bitCount((a[r + 2 * length] & value) & 0xFF);
subRet3 += Integer.bitCount((a[r + 3 * length] & value) & 0xFF);
}
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
}
}
Loading