Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.benchmark.Utils;
import org.elasticsearch.nativeaccess.NativeAccess;
import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
Expand All @@ -21,11 +23,17 @@
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.rethrow;
import static org.elasticsearch.nativeaccess.Int4TestUtils.dotProductI4SinglePacked;
import static org.elasticsearch.nativeaccess.Int4TestUtils.packNibbles;
import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.randomInt4Bytes;
Expand All @@ -47,8 +55,17 @@ public class VectorScorerInt4OperationBenchmark {
Utils.configureBenchmarkLogging();
}

public byte[] unpacked;
public byte[] packed;
private int packedLen;

private byte[] unpacked;
private byte[] packed;

MemorySegment unpackedHeapSeg, packedHeapSeg;
MemorySegment unpackedNativeSeg, packedNativeSeg;

Arena arena;

private MethodHandle nativeImpl;

@Param({ "2", "128", "208", "256", "300", "512", "702", "1024", "1536", "2048" })
public int size;
Expand All @@ -58,6 +75,27 @@ public void init() {
unpacked = new byte[size];
randomInt4Bytes(ThreadLocalRandom.current(), unpacked);
packed = packNibbles(unpacked);
packedLen = packed.length;

unpackedHeapSeg = MemorySegment.ofArray(unpacked);
packedHeapSeg = MemorySegment.ofArray(packed);

arena = Arena.ofConfined();
unpackedNativeSeg = arena.allocate(unpacked.length);
MemorySegment.copy(unpacked, 0, unpackedNativeSeg, ValueLayout.JAVA_BYTE, 0L, unpacked.length);
packedNativeSeg = arena.allocate(packed.length);
MemorySegment.copy(packed, 0, packedNativeSeg, ValueLayout.JAVA_BYTE, 0L, packed.length);

nativeImpl = vectorSimilarityFunctions.getHandle(
VectorSimilarityFunctions.Function.DOT_PRODUCT,
VectorSimilarityFunctions.DataType.INT4,
VectorSimilarityFunctions.Operation.SINGLE
);
}

@TearDown
public void teardown() {
arena.close();
}

@Benchmark
Expand All @@ -69,4 +107,24 @@ public int scalar() {
public int lucene() {
return VectorUtil.int4DotProductSinglePacked(unpacked, packed);
}

@Benchmark
public int nativeWithNativeSeg() {
try {
return (int) nativeImpl.invokeExact(unpackedNativeSeg, packedNativeSeg, packedLen);
} catch (Throwable t) {
throw rethrow(t);
}
}

@Benchmark
public int nativeWithHeapSeg() {
try {
return (int) nativeImpl.invokeExact(unpackedHeapSeg, packedHeapSeg, packedLen);
} catch (Throwable t) {
throw rethrow(t);
}
}

static final VectorSimilarityFunctions vectorSimilarityFunctions = NativeAccess.instance().getVectorSimilarityFunctions().orElseThrow();
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import java.util.Arrays;

import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.supportsHeapSegments;

public class VectorScorerInt4BulkBenchmarkTests extends ESTestCase {

private final VectorSimilarityType function;
Expand All @@ -44,8 +46,9 @@ private VectorScorerInt4BulkBenchmark createBench(VectorImplementation impl, Vec
}

@BeforeClass
public static void skipWindows() {
public static void skipUnsupported() {
assumeFalse("doesn't work on windows yet", Constants.WINDOWS);
assumeTrue("native requires JDK22+", supportsHeapSegments());
}

public void testSequential() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import java.util.Arrays;

import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.supportsHeapSegments;

public class VectorScorerInt4OperationBenchmarkTests extends ESTestCase {

private final int size;
Expand All @@ -27,8 +29,9 @@ public VectorScorerInt4OperationBenchmarkTests(int size) {
}

@BeforeClass
public static void skipWindows() {
public static void skipUnsupported() {
assumeFalse("doesn't work on windows yet", Constants.WINDOWS);
assumeTrue("native requires JDK22+", supportsHeapSegments());
}

public void test() {
Expand All @@ -39,6 +42,8 @@ public void test() {

int expected = bench.scalar();
assertEquals(expected, bench.lucene());
assertEquals(expected, bench.nativeWithNativeSeg());
assertEquals(expected, bench.nativeWithHeapSeg());
}
}

Expand Down
Loading