diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/BenchmarkUtils.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/BenchmarkUtils.java index 9d0b411e566cc..ba443915cb952 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/BenchmarkUtils.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/BenchmarkUtils.java @@ -80,4 +80,11 @@ static RandomVectorScorer luceneScorer(QuantizedByteVectorValues values, VectorS throws IOException { return new Lucene99ScalarQuantizedVectorScorer(null).getRandomVectorScorer(sim, values, queryVec); } + + static RuntimeException rethrow(Throwable t) { + if (t instanceof Error err) { + throw err; + } + return t instanceof RuntimeException re ? re : new RuntimeException(t); + } } diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/ScalarOperations.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/ScalarOperations.java new file mode 100644 index 0000000000000..3143c512cfe2d --- /dev/null +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/ScalarOperations.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.benchmark.vector.scorer; + +/** + * Basic scalar implementations of similarity operations. + *
+ * It is tricky to get specifically the scalar implementations from Lucene, + * as it tries to push into Panama implementations. So just re-implement them here. + */ +class ScalarOperations { + + static int dotProduct(byte[] a, byte[] b) { + int res = 0; + for (int i = 0; i < a.length; i++) { + res += a[i] * b[i]; + } + return res; + } + + static int squareDistance(byte[] a, byte[] b) { + int res = 0; + for (int i = 0; i < a.length; i++) { + int diff = a[i] - b[i]; + res += diff * diff; + } + return res; + } +} diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorImplementation.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorImplementation.java new file mode 100644 index 0000000000000..a4aa7ff4c752f --- /dev/null +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorImplementation.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.benchmark.vector.scorer; + +public enum VectorImplementation { + SCALAR, + LUCENE, + NATIVE +} diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerJDKFloat32Benchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerFloat32OperationBenchmark.java similarity index 53% rename from benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerJDKFloat32Benchmark.java rename to benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerFloat32OperationBenchmark.java index abb20d1f06141..e9e167f0d3a96 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerJDKFloat32Benchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerFloat32OperationBenchmark.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.logging.NodeNamePatternConverter; import org.elasticsearch.nativeaccess.NativeAccess; import org.elasticsearch.nativeaccess.VectorSimilarityFunctions; +import org.elasticsearch.simdvec.VectorSimilarityType; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -34,12 +35,15 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.rethrow; + +@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.NANOSECONDS) @State(Scope.Benchmark) @Warmup(iterations = 3, time = 1) @Measurement(iterations = 5, time = 1) -public class VectorScorerJDKFloat32Benchmark { +public class VectorScorerFloat32OperationBenchmark { static { NodeNamePatternConverter.setGlobalNodeName("foo"); @@ -60,6 +64,22 @@ public class VectorScorerJDKFloat32Benchmark { @Param({ "1", "128", "207", "256", "300", "512", "702", "1024", "1536", "2048" }) public int size; + @Param({ "COSINE", "DOT_PRODUCT", "EUCLIDEAN" }) + public VectorSimilarityType function; + + @FunctionalInterface + private interface LuceneFunction { + float run(float[] vec1, float[] vec2); + } + + @FunctionalInterface + private interface NativeFunction { + float run(MemorySegment vec1, MemorySegment vec2, int length); + } + + private LuceneFunction luceneImpl; + private NativeFunction nativeImpl; + @Setup(Level.Iteration) public void init() { ThreadLocalRandom random = ThreadLocalRandom.current(); @@ -79,6 +99,19 @@ public void init() { MemorySegment.copy(MemorySegment.ofArray(floatsA), LAYOUT_LE_FLOAT, 0L, nativeSegA, LAYOUT_LE_FLOAT, 0L, floatsA.length); nativeSegB = arena.allocate((long) floatsB.length * Float.BYTES); MemorySegment.copy(MemorySegment.ofArray(floatsB), LAYOUT_LE_FLOAT, 0L, nativeSegB, LAYOUT_LE_FLOAT, 0L, floatsB.length); + + luceneImpl = switch (function) { + case COSINE -> VectorUtil::cosine; + case DOT_PRODUCT -> VectorUtil::dotProduct; + case EUCLIDEAN -> VectorUtil::squareDistance; + default -> throw new UnsupportedOperationException("Not used"); + }; + nativeImpl = switch (function) { + case COSINE -> VectorScorerFloat32OperationBenchmark::cosineFloat32; + case DOT_PRODUCT -> VectorScorerFloat32OperationBenchmark::dotProductFloat32; + case EUCLIDEAN -> VectorScorerFloat32OperationBenchmark::squareDistanceFloat32; + default -> throw new UnsupportedOperationException("Not used"); + }; } @TearDown @@ -86,88 +119,26 @@ public void teardown() { arena.close(); } - // -- cosine - - @Benchmark - @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public float cosineLucene() { - return VectorUtil.cosine(floatsA, floatsB); - } - - @Benchmark - @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public float cosineLuceneWithCopy() { - // add a copy to better reflect what Lucene has to do to get the target vector on-heap - MemorySegment.copy(nativeSegB, LAYOUT_LE_FLOAT, 0L, scratch, 0, scratch.length); - return VectorUtil.cosine(floatsA, scratch); - } - - @Benchmark - @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public float cosineNativeWithNativeSeg() { - return cosineFloat32(nativeSegA, nativeSegB, size); - } - - @Benchmark - @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public float cosineNativeWithHeapSeg() { - return cosineFloat32(heapSegA, heapSegB, size); - } - - // -- dot product - - @Benchmark - @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public float dotProductLucene() { - return VectorUtil.dotProduct(floatsA, floatsB); - } - - @Benchmark - @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public float dotProductLuceneWithCopy() { - // add a copy to better reflect what Lucene has to do to get the target vector on-heap - MemorySegment.copy(nativeSegB, LAYOUT_LE_FLOAT, 0L, scratch, 0, scratch.length); - return VectorUtil.dotProduct(floatsA, scratch); - } - - @Benchmark - @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public float dotProductNativeWithNativeSeg() { - return dotProductFloat32(nativeSegA, nativeSegB, size); - } - - @Benchmark - @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public float dotProductNativeWithHeapSeg() { - return dotProductFloat32(heapSegA, heapSegB, size); - } - - // -- square distance - @Benchmark - @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public float squareDistanceLucene() { - return VectorUtil.squareDistance(floatsA, floatsB); + public float lucene() { + return luceneImpl.run(floatsA, floatsB); } @Benchmark - @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public float squareDistanceLuceneWithCopy() { + public float luceneWithCopy() { // add a copy to better reflect what Lucene has to do to get the target vector on-heap MemorySegment.copy(nativeSegB, LAYOUT_LE_FLOAT, 0L, scratch, 0, scratch.length); - return VectorUtil.squareDistance(floatsA, scratch); + return luceneImpl.run(floatsA, scratch); } @Benchmark - @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public float squareDistanceNativeWithNativeSeg() { - return squareDistanceFloat32(nativeSegA, nativeSegB, size); + public float nativeWithNativeSeg() { + return nativeImpl.run(nativeSegA, nativeSegB, size); } @Benchmark - @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public float squareDistanceNativeWithHeapSeg() { - return squareDistanceFloat32(heapSegA, heapSegB, size); + public float nativeWithHeapSeg() { + return nativeImpl.run(heapSegA, heapSegB, size); } static final VectorSimilarityFunctions vectorSimilarityFunctions = vectorSimilarityFunctions(); @@ -176,45 +147,27 @@ static VectorSimilarityFunctions vectorSimilarityFunctions() { return NativeAccess.instance().getVectorSimilarityFunctions().get(); } - float cosineFloat32(MemorySegment a, MemorySegment b, int length) { + static float cosineFloat32(MemorySegment a, MemorySegment b, int length) { try { return (float) vectorSimilarityFunctions.cosineHandleFloat32().invokeExact(a, b, length); } catch (Throwable e) { - if (e instanceof Error err) { - throw err; - } else if (e instanceof RuntimeException re) { - throw re; - } else { - throw new RuntimeException(e); - } + throw rethrow(e); } } - float dotProductFloat32(MemorySegment a, MemorySegment b, int length) { + static float dotProductFloat32(MemorySegment a, MemorySegment b, int length) { try { return (float) vectorSimilarityFunctions.dotProductHandleFloat32().invokeExact(a, b, length); } catch (Throwable e) { - if (e instanceof Error err) { - throw err; - } else if (e instanceof RuntimeException re) { - throw re; - } else { - throw new RuntimeException(e); - } + throw rethrow(e); } } - float squareDistanceFloat32(MemorySegment a, MemorySegment b, int length) { + static float squareDistanceFloat32(MemorySegment a, MemorySegment b, int length) { try { return (float) vectorSimilarityFunctions.squareDistanceHandleFloat32().invokeExact(a, b, length); } catch (Throwable e) { - if (e instanceof Error err) { - throw err; - } else if (e instanceof RuntimeException re) { - throw re; - } else { - throw new RuntimeException(e); - } + throw rethrow(e); } } } diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt7uBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt7uBenchmark.java index e99fbdf6a6c11..1ca81fc146279 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt7uBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt7uBenchmark.java @@ -47,6 +47,8 @@ import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.supportsHeapSegments; import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.vectorValues; import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.writeInt7VectorData; +import static org.elasticsearch.benchmark.vector.scorer.ScalarOperations.dotProduct; +import static org.elasticsearch.benchmark.vector.scorer.ScalarOperations.squareDistance; /** * Benchmark that compares various scalar quantized vector similarity function @@ -73,14 +75,8 @@ public class VectorScorerInt7uBenchmark { public int dims; public static int numVectors = 2; // there are only two vectors to compare - public enum Implementation { - SCALAR, - LUCENE, - NATIVE - } - @Param - public Implementation implementation; + public VectorImplementation implementation; @Param({ "DOT_PRODUCT", "EUCLIDEAN" }) public VectorSimilarityType function; @@ -112,10 +108,7 @@ private ScalarDotProduct( @Override public float score(int node) throws IOException { - int dotProduct = 0; - for (int i = 0; i < vec1.length; i++) { - dotProduct += vec1[i] * vec2[i]; - } + int dotProduct = dotProduct(vec1, vec2); float adjustedDistance = dotProduct * scoreCorrectionConstant + vec1CorrectionConstant + vec2CorrectionConstant; return (1 + adjustedDistance) / 2; } @@ -142,11 +135,7 @@ private ScalarSquareDistance(byte[] vec1, byte[] vec2, float scoreCorrectionCons @Override public float score(int node) throws IOException { - int squareDistance = 0; - for (int i = 0; i < vec1.length; i++) { - int diff = vec1[i] - vec2[i]; - squareDistance += diff * diff; - } + int squareDistance = squareDistance(vec1, vec2); float adjustedDistance = squareDistance * scoreCorrectionConstant; return 1 / (1f + adjustedDistance); } diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt7uBulkBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt7uBulkBenchmark.java index 8d299d8cfcbe1..2711192cadff4 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt7uBulkBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt7uBulkBenchmark.java @@ -52,6 +52,8 @@ import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.supportsHeapSegments; import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.vectorValues; import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.writeInt7VectorData; +import static org.elasticsearch.benchmark.vector.scorer.ScalarOperations.dotProduct; +import static org.elasticsearch.benchmark.vector.scorer.ScalarOperations.squareDistance; /** * Benchmark that compares bulk scoring of various scalar quantized vector similarity function @@ -85,14 +87,8 @@ public class VectorScorerInt7uBulkBenchmark { public int numVectors; public int numVectorsToScore; - public enum Implementation { - SCALAR, - LUCENE, - NATIVE - } - @Param - public Implementation implementation; + public VectorImplementation implementation; @Param({ "DOT_PRODUCT", "EUCLIDEAN" }) public VectorSimilarityType function; @@ -117,10 +113,7 @@ private ScalarDotProduct(QuantizedByteVectorValues values, float scoreCorrection public float score(int ordinal) throws IOException { var vec2 = values.vectorValue(ordinal); var vec2CorrectionConstant = values.getScoreCorrectionConstant(ordinal); - int dotProduct = 0; - for (int i = 0; i < queryVector.length; i++) { - dotProduct += queryVector[i] * vec2[i]; - } + int dotProduct = dotProduct(queryVector, vec2); float adjustedDistance = dotProduct * scoreCorrectionConstant + queryVectorCorrectionConstant + vec2CorrectionConstant; return (1 + adjustedDistance) / 2; } @@ -151,11 +144,7 @@ private ScalarSquareDistance(QuantizedByteVectorValues values, float scoreCorrec @Override public float score(int ordinal) throws IOException { var vec2 = values.vectorValue(ordinal); - int squareDistance = 0; - for (int i = 0; i < queryVector.length; i++) { - int diff = queryVector[i] - vec2[i]; - squareDistance += diff * diff; - } + int squareDistance = squareDistance(queryVector, vec2); float adjustedDistance = squareDistance * scoreCorrectionConstant; return 1 / (1f + adjustedDistance); } diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerJDKInt7uBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt7uOperationBenchmark.java similarity index 73% rename from benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerJDKInt7uBenchmark.java rename to benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt7uOperationBenchmark.java index d965aae29a7f9..4f528f9dbb8c2 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerJDKInt7uBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt7uOperationBenchmark.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.logging.NodeNamePatternConverter; import org.elasticsearch.nativeaccess.NativeAccess; import org.elasticsearch.nativeaccess.VectorSimilarityFunctions; +import org.elasticsearch.simdvec.VectorSimilarityType; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -29,15 +30,18 @@ import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; -import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.randomInt7BytesBetween; +import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.rethrow; + +@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.NANOSECONDS) @State(Scope.Benchmark) @Warmup(iterations = 3, time = 1) @Measurement(iterations = 5, time = 1) -public class VectorScorerJDKInt7uBenchmark { +public class VectorScorerInt7uOperationBenchmark { static { NodeNamePatternConverter.setGlobalNodeName("foo"); @@ -55,6 +59,9 @@ public class VectorScorerJDKInt7uBenchmark { @Param({ "1", "128", "207", "256", "300", "512", "702", "1024", "1536", "2048" }) public int size; + @Param({ "DOT_PRODUCT" }) + public VectorSimilarityType function; + @Setup(Level.Iteration) public void init() { byteArrayA = new byte[size]; @@ -67,9 +74,9 @@ public void init() { heapSegB = MemorySegment.ofArray(byteArrayB); arena = Arena.ofConfined(); - nativeSegA = arena.allocate((long) byteArrayA.length); + nativeSegA = arena.allocate(byteArrayA.length); MemorySegment.copy(MemorySegment.ofArray(byteArrayA), 0L, nativeSegA, 0L, byteArrayA.length); - nativeSegB = arena.allocate((long) byteArrayB.length); + nativeSegB = arena.allocate(byteArrayB.length); MemorySegment.copy(MemorySegment.ofArray(byteArrayB), 0L, nativeSegB, 0L, byteArrayB.length); } @@ -79,20 +86,17 @@ public void teardown() { } @Benchmark - @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public int dotProductLucene() { + public int lucene() { return VectorUtil.dotProduct(byteArrayA, byteArrayB); } @Benchmark - @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public int dotProductNativeWithNativeSeg() { + public int nativeWithNativeSeg() { return dotProduct7u(nativeSegA, nativeSegB, size); } @Benchmark - @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public int dotProductNativeWithHeapSeg() { + public int nativeWithHeapSeg() { return dotProduct7u(heapSegA, heapSegB, size); } @@ -106,24 +110,7 @@ int dotProduct7u(MemorySegment a, MemorySegment b, int length) { try { return (int) vectorSimilarityFunctions.dotProductHandle7u().invokeExact(a, b, length); } catch (Throwable e) { - if (e instanceof Error err) { - throw err; - } else if (e instanceof RuntimeException re) { - throw re; - } else { - throw new RuntimeException(e); - } - } - } - - // Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive). - static final byte MIN_INT7_VALUE = 0; - static final byte MAX_INT7_VALUE = 127; - - static void randomInt7BytesBetween(byte[] bytes) { - var random = ThreadLocalRandom.current(); - for (int i = 0, len = bytes.length; i < len;) { - bytes[i++] = (byte) random.nextInt(MIN_INT7_VALUE, MAX_INT7_VALUE + 1); + throw rethrow(e); } } } diff --git a/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerJDKFloat32BenchmarkTests.java b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerFloat32OperationBenchmarkTests.java similarity index 51% rename from benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerJDKFloat32BenchmarkTests.java rename to benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerFloat32OperationBenchmarkTests.java index 20c3802334df6..2f33857ad0300 100644 --- a/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerJDKFloat32BenchmarkTests.java +++ b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerFloat32OperationBenchmarkTests.java @@ -12,18 +12,22 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import org.apache.lucene.util.Constants; +import org.elasticsearch.simdvec.VectorSimilarityType; import org.elasticsearch.test.ESTestCase; +import org.junit.AssumptionViolatedException; import org.junit.BeforeClass; import org.openjdk.jmh.annotations.Param; import java.util.Arrays; -public class VectorScorerJDKFloat32BenchmarkTests extends ESTestCase { +public class VectorScorerFloat32OperationBenchmarkTests extends ESTestCase { - final double delta; - final int size; + private VectorSimilarityType function; + private final double delta; + private final int size; - public VectorScorerJDKFloat32BenchmarkTests(int size) { + public VectorScorerFloat32OperationBenchmarkTests(VectorSimilarityType function, int size) { + this.function = function; this.size = size; delta = 1e-3 * size; } @@ -37,56 +41,24 @@ static boolean supportsHeapSegments() { return Runtime.version().feature() >= 22; } - public void testCosine() { + public void test() { for (int i = 0; i < 100; i++) { - var bench = new VectorScorerJDKFloat32Benchmark(); + var bench = new VectorScorerFloat32OperationBenchmark(); + bench.function = function; bench.size = size; bench.init(); try { - float expected = cosineFloat32Scalar(bench.floatsA, bench.floatsB); - assertEquals(expected, bench.cosineLucene(), delta); - assertEquals(expected, bench.cosineLuceneWithCopy(), delta); - assertEquals(expected, bench.cosineNativeWithNativeSeg(), delta); + float expected = switch (function) { + case COSINE -> cosineFloat32Scalar(bench.floatsA, bench.floatsB); + case DOT_PRODUCT -> dotProductFloat32Scalar(bench.floatsA, bench.floatsB); + case EUCLIDEAN -> squareDistanceFloat32Scalar(bench.floatsA, bench.floatsB); + case MAXIMUM_INNER_PRODUCT -> throw new AssumptionViolatedException("Not tested"); + }; + assertEquals(expected, bench.lucene(), delta); + assertEquals(expected, bench.luceneWithCopy(), delta); + assertEquals(expected, bench.nativeWithNativeSeg(), delta); if (supportsHeapSegments()) { - assertEquals(expected, bench.cosineNativeWithHeapSeg(), delta); - } - } finally { - bench.teardown(); - } - } - } - - public void testDotProduct() { - for (int i = 0; i < 100; i++) { - var bench = new VectorScorerJDKFloat32Benchmark(); - bench.size = size; - bench.init(); - try { - float expected = dotProductFloat32Scalar(bench.floatsA, bench.floatsB); - assertEquals(expected, bench.dotProductLucene(), delta); - assertEquals(expected, bench.dotProductLuceneWithCopy(), delta); - assertEquals(expected, bench.dotProductNativeWithNativeSeg(), delta); - if (supportsHeapSegments()) { - assertEquals(expected, bench.dotProductNativeWithHeapSeg(), delta); - } - } finally { - bench.teardown(); - } - } - } - - public void testSquareDistance() { - for (int i = 0; i < 100; i++) { - var bench = new VectorScorerJDKFloat32Benchmark(); - bench.size = size; - bench.init(); - try { - float expected = squareDistanceFloat32Scalar(bench.floatsA, bench.floatsB); - assertEquals(expected, bench.squareDistanceLucene(), delta); - assertEquals(expected, bench.squareDistanceLuceneWithCopy(), delta); - assertEquals(expected, bench.squareDistanceNativeWithNativeSeg(), delta); - if (supportsHeapSegments()) { - assertEquals(expected, bench.squareDistanceNativeWithHeapSeg(), delta); + assertEquals(expected, bench.nativeWithHeapSeg(), delta); } } finally { bench.teardown(); @@ -97,8 +69,13 @@ public void testSquareDistance() { @ParametersFactory public static Iterable