From 2df0351931f529cb9b85719a5eb9502c114621d4 Mon Sep 17 00:00:00 2001 From: Chris Hegarty <62058229+ChrisHegarty@users.noreply.github.com> Date: Thu, 14 Aug 2025 12:30:34 +0100 Subject: [PATCH] Add remaining bulk float32 off-heap scoring similarities (#15037) This commit adds the remaining bulk float32 off-heap scoring similarities, cosine, euclidean, and max inner product. The changes in #14980 deliberately added only dot product, to avoid additional bloat on the PR and benchmarking. This PR now refactors things a little to allow for the remaining similarities to be added. Benchmarking will be carried out on them independently, as well as consideration for not negatively affecting dot product. relates #14980 --- lucene/CHANGES.txt | 2 +- .../jmh/VectorScorerFloat32Benchmark.java | 127 +++++++- .../jmh/TestVectorScorerFloat32Benchmark.java | 63 ++++ .../index/VectorSimilarityFunction.java | 3 +- .../org/apache/lucene/util/VectorUtil.java | 13 + ...ucene99MemorySegmentFlatVectorsScorer.java | 34 +-- ...ucene99MemorySegmentFloatVectorScorer.java | 210 ++++++++++--- ...emorySegmentFloatVectorScorerSupplier.java | 275 ++++++++++++++---- .../MemorySegmentBulkVectorOps.java | 263 +++++++++++++++++ 9 files changed, 875 insertions(+), 115 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index b3abbf854633..e8a2e3171acb 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -140,7 +140,7 @@ Optimizations * GITHUB#15001: Remove full integrity check from SortingStoredFieldsConsumer (Martijn van Groningen) -* GITHUB#14980: Add bulk off-heap scoring for float32 vectors (Chris Hegarty) +* GITHUB#14980, GITHUB#15037: Add bulk off-heap scoring for float32 vectors (Chris Hegarty) * GITHUB#15004: Wraps all iterator with likelyImpactsEnum under BlockMaxConjunctionBulkScorer. 
(Ge Song) diff --git a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerFloat32Benchmark.java b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerFloat32Benchmark.java index 215f754c3fe2..92dc5b6de5a1 100644 --- a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerFloat32Benchmark.java +++ b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerFloat32Benchmark.java @@ -16,7 +16,10 @@ */ package org.apache.lucene.benchmark.jmh; +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; import java.io.IOException; import java.nio.ByteBuffer; @@ -75,6 +78,11 @@ "-XX:+AlwaysPreTouch", "--add-modules=jdk.incubator.vector" }) +/** + * Benchmark to compare the performance of float32 vector scoring using the default and optimized + * scorers. While there are benchmark methods for each of the similarities, it is often most useful + * to compare equivalent subsets, e.g. 
.*dot.* + */ public class VectorScorerFloat32Benchmark { @Param({"1024"}) @@ -92,8 +100,8 @@ public class VectorScorerFloat32Benchmark { Directory dir; IndexInput in; KnnVectorValues values; - UpdateableRandomVectorScorer defDotScorer; - UpdateableRandomVectorScorer optDotScorer; + UpdateableRandomVectorScorer defDotScorer, defCosScorer, defEucScorer, defMipScorer; + UpdateableRandomVectorScorer optDotScorer, optCosScorer, optEucScorer, optMipScorer; @Setup(Level.Trial) public void setup() throws IOException { @@ -121,12 +129,24 @@ public void perIterationInit() throws IOException { values = vectorValues(size, numVectors, in, DOT_PRODUCT); var def = DefaultFlatVectorScorer.INSTANCE; defDotScorer = def.getRandomVectorScorerSupplier(DOT_PRODUCT, values.copy()).scorer(); + defCosScorer = def.getRandomVectorScorerSupplier(COSINE, values.copy()).scorer(); + defEucScorer = def.getRandomVectorScorerSupplier(EUCLIDEAN, values.copy()).scorer(); + defMipScorer = def.getRandomVectorScorerSupplier(MAXIMUM_INNER_PRODUCT, values.copy()).scorer(); defDotScorer.setScoringOrdinal(targetOrd); + defCosScorer.setScoringOrdinal(targetOrd); + defEucScorer.setScoringOrdinal(targetOrd); + defMipScorer.setScoringOrdinal(targetOrd); // optimized scorer var opt = FlatVectorScorerUtil.getLucene99FlatVectorsScorer(); optDotScorer = opt.getRandomVectorScorerSupplier(DOT_PRODUCT, values.copy()).scorer(); + optCosScorer = opt.getRandomVectorScorerSupplier(COSINE, values.copy()).scorer(); + optEucScorer = opt.getRandomVectorScorerSupplier(EUCLIDEAN, values.copy()).scorer(); + optMipScorer = opt.getRandomVectorScorerSupplier(MAXIMUM_INNER_PRODUCT, values.copy()).scorer(); optDotScorer.setScoringOrdinal(targetOrd); + optCosScorer.setScoringOrdinal(targetOrd); + optEucScorer.setScoringOrdinal(targetOrd); + optMipScorer.setScoringOrdinal(targetOrd); List list = IntStream.range(0, numVectors).boxed().collect(Collectors.toList()); Collections.shuffle(list, random); @@ -146,10 +166,21 @@ public void 
teardown() throws IOException { } public void pollute(Random random) throws IOException { + // exercise various similarities to ensure they don't have negative effects, e.g., + // type pollution on virtual calls, etc. float[] vec = randomVector(size, random); var opt = FlatVectorScorerUtil.getLucene99FlatVectorsScorer(); var scorer = opt.getRandomVectorScorer(DOT_PRODUCT, values.copy(), vec); + for (int i = 0; i < 2; i++) { + dotProductOptScorer(); + dotProductOptBulkScore(); + cosineOptScorer(); + cosineDefaultBulk(); + euclideanOptScorer(); + euclideanOptBulkScore(); + mipOptScorer(); + mipOptBulkScore(); for (int v = 0; v < numVectorsToScore; v++) { scores[v] = scorer.score(indices[v]); } @@ -157,6 +188,8 @@ public void pollute(Random random) throws IOException { } } + // -- dot product + @Benchmark public float[] dotProductDefault() throws IOException { for (int v = 0; v < numVectorsToScore; v++) { @@ -185,6 +218,96 @@ public float[] dotProductOptBulkScore() throws IOException { return scores; } + // -- euclidean + + @Benchmark + public float[] euclideanDefault() throws IOException { + for (int v = 0; v < numVectorsToScore; v++) { + scores[v] = defEucScorer.score(indices[v]); + } + return scores; + } + + @Benchmark + public float[] euclideanDefaultBulk() throws IOException { + defEucScorer.bulkScore(indices, scores, indices.length); + return scores; + } + + @Benchmark + public float[] euclideanOptScorer() throws IOException { + for (int v = 0; v < numVectorsToScore; v++) { + scores[v] = optEucScorer.score(indices[v]); + } + return scores; + } + + @Benchmark + public float[] euclideanOptBulkScore() throws IOException { + optEucScorer.bulkScore(indices, scores, indices.length); + return scores; + } + + // -- cosine + + @Benchmark + public float[] cosineDefault() throws IOException { + for (int v = 0; v < numVectorsToScore; v++) { + scores[v] = defCosScorer.score(indices[v]); + } + return scores; + } + + @Benchmark + public float[] cosineDefaultBulk() throws 
IOException { + defCosScorer.bulkScore(indices, scores, indices.length); + return scores; + } + + @Benchmark + public float[] cosineOptScorer() throws IOException { + for (int v = 0; v < numVectorsToScore; v++) { + scores[v] = optCosScorer.score(indices[v]); + } + return scores; + } + + @Benchmark + public float[] cosineOptBulkScore() throws IOException { + optCosScorer.bulkScore(indices, scores, indices.length); + return scores; + } + + // -- max inner product + + @Benchmark + public float[] mipDefault() throws IOException { + for (int v = 0; v < numVectorsToScore; v++) { + scores[v] = defMipScorer.score(indices[v]); + } + return scores; + } + + @Benchmark + public float[] mipDefaultBulk() throws IOException { + defMipScorer.bulkScore(indices, scores, indices.length); + return scores; + } + + @Benchmark + public float[] mipOptScorer() throws IOException { + for (int v = 0; v < numVectorsToScore; v++) { + scores[v] = optMipScorer.score(indices[v]); + } + return scores; + } + + @Benchmark + public float[] mipOptBulkScore() throws IOException { + optMipScorer.bulkScore(indices, scores, indices.length); + return scores; + } + static float[] randomVector(int dims, Random random) { float[] fa = new float[dims]; for (int i = 0; i < dims; ++i) { diff --git a/lucene/benchmark-jmh/src/test/org/apache/lucene/benchmark/jmh/TestVectorScorerFloat32Benchmark.java b/lucene/benchmark-jmh/src/test/org/apache/lucene/benchmark/jmh/TestVectorScorerFloat32Benchmark.java index 765ef0455a3c..2417159e1836 100644 --- a/lucene/benchmark-jmh/src/test/org/apache/lucene/benchmark/jmh/TestVectorScorerFloat32Benchmark.java +++ b/lucene/benchmark-jmh/src/test/org/apache/lucene/benchmark/jmh/TestVectorScorerFloat32Benchmark.java @@ -65,4 +65,67 @@ public void testDotProduct() throws IOException { actualScores = ArrayUtil.copyArray(bench.scores); assertArrayEquals(expectedScores, actualScores, delta); } + + public void testCosine() throws IOException { + Arrays.fill(bench.scores, 0.0f); + 
bench.cosineDefault(); + var expectedScores = ArrayUtil.copyArray(bench.scores); + + Arrays.fill(bench.scores, 0.0f); + bench.cosineDefaultBulk(); + var bulkScores = ArrayUtil.copyArray(bench.scores); + assertArrayEquals(expectedScores, bulkScores, delta); + + Arrays.fill(bench.scores, 0.0f); + bench.cosineOptScorer(); + var actualScores = ArrayUtil.copyArray(bench.scores); + assertArrayEquals(expectedScores, actualScores, delta); + + Arrays.fill(bench.scores, 0.0f); + bench.cosineOptBulkScore(); + actualScores = ArrayUtil.copyArray(bench.scores); + assertArrayEquals(expectedScores, actualScores, delta); + } + + public void testEuclidean() throws IOException { + Arrays.fill(bench.scores, 0.0f); + bench.euclideanDefault(); + var expectedScores = ArrayUtil.copyArray(bench.scores); + + Arrays.fill(bench.scores, 0.0f); + bench.euclideanDefaultBulk(); + var bulkScores = ArrayUtil.copyArray(bench.scores); + assertArrayEquals(expectedScores, bulkScores, delta); + + Arrays.fill(bench.scores, 0.0f); + bench.euclideanOptScorer(); + var actualScores = ArrayUtil.copyArray(bench.scores); + assertArrayEquals(expectedScores, actualScores, delta); + + Arrays.fill(bench.scores, 0.0f); + bench.euclideanOptBulkScore(); + actualScores = ArrayUtil.copyArray(bench.scores); + assertArrayEquals(expectedScores, actualScores, delta); + } + + public void testMip() throws IOException { + Arrays.fill(bench.scores, 0.0f); + bench.mipDefault(); + var expectedScores = ArrayUtil.copyArray(bench.scores); + + Arrays.fill(bench.scores, 0.0f); + bench.mipDefaultBulk(); + var bulkScores = ArrayUtil.copyArray(bench.scores); + assertArrayEquals(expectedScores, bulkScores, delta); + + Arrays.fill(bench.scores, 0.0f); + bench.mipOptScorer(); + var actualScores = ArrayUtil.copyArray(bench.scores); + assertArrayEquals(expectedScores, actualScores, delta); + + Arrays.fill(bench.scores, 0.0f); + bench.mipOptBulkScore(); + actualScores = ArrayUtil.copyArray(bench.scores); + assertArrayEquals(expectedScores, 
actualScores, delta); + } } diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java index f905eeae24c4..a692d917e606 100644 --- a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java +++ b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java @@ -19,6 +19,7 @@ import static org.apache.lucene.util.VectorUtil.cosine; import static org.apache.lucene.util.VectorUtil.dotProduct; import static org.apache.lucene.util.VectorUtil.dotProductScore; +import static org.apache.lucene.util.VectorUtil.normalizeDistanceToUnitInterval; import static org.apache.lucene.util.VectorUtil.normalizeToUnitInterval; import static org.apache.lucene.util.VectorUtil.scaleMaxInnerProductScore; import static org.apache.lucene.util.VectorUtil.squareDistance; @@ -34,7 +35,7 @@ public enum VectorSimilarityFunction { EUCLIDEAN { @Override public float compare(float[] v1, float[] v2) { - return 1 / (1 + squareDistance(v1, v2)); + return normalizeDistanceToUnitInterval(squareDistance(v1, v2)); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 04a3ddeffaf5..4f64aa8b816a 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -322,6 +322,19 @@ public static float normalizeToUnitInterval(float value) { return Math.max((1 + value) / 2, 0); } + /** + * Maps a non-negative squared distance to a similarity score in the range (0, 1]. + * + *

<p>Uses the transformation: {@code similarity = 1 / (1 + squaredDistance)}. Smaller distances + * yield scores closer to 1; larger distances approach 0. + * + * @param squaredDistance squared Euclidean distance (must be ≥ 0) + * @return similarity score in (0, 1] + */ + public static float normalizeDistanceToUnitInterval(float squaredDistance) { + return 1.0f / (1.0f + squaredDistance); + } + /** * Checks if a float vector only has finite components. * diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java index 456747e563f5..8f60de6cda13 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java @@ -50,15 +50,13 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier( private RandomVectorScorerSupplier getFloatScoringSupplier( FloatVectorValues vectorValues, VectorSimilarityFunction similarityType) throws IOException { - if (similarityType == VectorSimilarityFunction.DOT_PRODUCT) { // dot product for now - if (vectorValues instanceof HasIndexSlice sliceableValues - && sliceableValues.getSlice() != null) { - var scorer = - Lucene99MemorySegmentFloatVectorScorerSupplier.create( - similarityType, sliceableValues.getSlice(), vectorValues); - if (scorer.isPresent()) { - return scorer.get(); - } + if (vectorValues instanceof HasIndexSlice sliceableValues + && sliceableValues.getSlice() != null) { + var scorer = + Lucene99MemorySegmentFloatVectorScorerSupplier.create( + similarityType, sliceableValues.getSlice(), vectorValues); + if (scorer.isPresent()) { + return scorer.get(); } } return delegate.getRandomVectorScorerSupplier(similarityType, vectorValues); @@ -87,16 +85,14 @@ public RandomVectorScorer 
getRandomVectorScorer( VectorSimilarityFunction similarityType, KnnVectorValues vectorValues, float[] target) throws IOException { checkDimensions(target.length, vectorValues.dimension()); - if (similarityType == VectorSimilarityFunction.DOT_PRODUCT) { // just for now - if (vectorValues instanceof FloatVectorValues fvv - && fvv instanceof HasIndexSlice floatVectorValues - && floatVectorValues.getSlice() != null) { - var scorer = - Lucene99MemorySegmentFloatVectorScorer.create( - similarityType, floatVectorValues.getSlice(), fvv, target); - if (scorer.isPresent()) { - return scorer.get(); - } + if (vectorValues instanceof FloatVectorValues fvv + && fvv instanceof HasIndexSlice floatVectorValues + && floatVectorValues.getSlice() != null) { + var scorer = + Lucene99MemorySegmentFloatVectorScorer.create( + similarityType, floatVectorValues.getSlice(), fvv, target); + if (scorer.isPresent()) { + return scorer.get(); } } return delegate.getRandomVectorScorer(similarityType, vectorValues, target); diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorer.java index 49e07f36a64b..96d84b40bc9b 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorer.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorer.java @@ -16,8 +16,6 @@ */ package org.apache.lucene.internal.vectorization; -import static org.apache.lucene.util.VectorUtil.normalizeToUnitInterval; - import java.io.IOException; import java.lang.foreign.MemorySegment; import java.util.Optional; @@ -26,6 +24,7 @@ import org.apache.lucene.store.FilterIndexInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.VectorUtil; import 
org.apache.lucene.util.hnsw.RandomVectorScorer; abstract sealed class Lucene99MemorySegmentFloatVectorScorer @@ -35,16 +34,14 @@ abstract sealed class Lucene99MemorySegmentFloatVectorScorer final int vectorByteSize; final MemorySegment seg; final float[] query; + final float[] scratchScores = new float[4]; /** * Return an optional whose value, if present, is the scorer. Otherwise, an empty optional is * returned. */ public static Optional create( - VectorSimilarityFunction type, - IndexInput input, - FloatVectorValues values, - float[] queryVector) + VectorSimilarityFunction type, IndexInput input, FloatVectorValues values, float[] query) throws IOException { input = FilterIndexInput.unwrapOnlyTest(input); MemorySegment seg; @@ -54,21 +51,20 @@ public static Optional create( } checkInvariants(values.size(), values.getVectorByteLength(), input); return switch (type) { - case COSINE -> Optional.empty(); // of(new CosineScorer(msInput, values, queryVector)); - case DOT_PRODUCT -> Optional.of(new DotProductScorer(seg, values, queryVector)); - case EUCLIDEAN -> Optional.empty(); // of(new EuclideanScorer(msInput, values, queryVector)); - case MAXIMUM_INNER_PRODUCT -> - Optional.empty(); // of(new MaxInnerProductScorer(msInput, values, queryVector)); + case COSINE -> Optional.of(new CosineScorer(seg, values, query)); + case DOT_PRODUCT -> Optional.of(new DotProductScorer(seg, values, query)); + case EUCLIDEAN -> Optional.of(new EuclideanScorer(seg, values, query)); + case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductScorer(seg, values, query)); }; } Lucene99MemorySegmentFloatVectorScorer( - MemorySegment seg, FloatVectorValues values, float[] queryVector) { + MemorySegment seg, FloatVectorValues values, float[] query) { super(values); this.values = values; this.seg = seg; this.vectorByteSize = values.getVectorByteLength(); - this.query = queryVector; + this.query = query; } static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { @@ 
-83,13 +79,85 @@ final void checkOrdinal(int ord) { } } + @Override + public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + int i = 0; + final int limit = numNodes & ~3; + for (; i < limit; i += 4) { + long offset1 = (long) nodes[i] * vectorByteSize; + long offset2 = (long) nodes[i + 1] * vectorByteSize; + long offset3 = (long) nodes[i + 2] * vectorByteSize; + long offset4 = (long) nodes[i + 3] * vectorByteSize; + vectorOp(seg, scratchScores, offset1, offset2, offset3, offset4, query.length); + scores[i + 0] = normalizeRawScore(scratchScores[0]); + scores[i + 1] = normalizeRawScore(scratchScores[1]); + scores[i + 2] = normalizeRawScore(scratchScores[2]); + scores[i + 3] = normalizeRawScore(scratchScores[3]); + } + // Handle remaining 1–3 nodes in bulk (if any) + int remaining = numNodes - i; + if (remaining > 0) { + long addr1 = (long) nodes[i] * vectorByteSize; + long addr2 = (remaining > 1) ? (long) nodes[i + 1] * vectorByteSize : addr1; + long addr3 = (remaining > 2) ? 
(long) nodes[i + 2] * vectorByteSize : addr1; + vectorOp(seg, scratchScores, addr1, addr2, addr3, addr3, query.length); + scores[i] = normalizeRawScore(scratchScores[0]); + if (remaining > 1) scores[i + 1] = normalizeRawScore(scratchScores[1]); + if (remaining > 2) scores[i + 2] = normalizeRawScore(scratchScores[2]); + } + } + + abstract void vectorOp( + MemorySegment seg, + float[] scores, + long node1Offset, + long node2Offset, + long node3Offset, + long node4Offset, + int elementCount); + + abstract float normalizeRawScore(float value); + + static final class CosineScorer extends Lucene99MemorySegmentFloatVectorScorer { + + static final MemorySegmentBulkVectorOps.Cosine COS_OPS = + MemorySegmentBulkVectorOps.COS_INSTANCE; + + CosineScorer(MemorySegment seg, FloatVectorValues values, float[] query) { + super(seg, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + // just delegates to existing scorer that copies on-heap + return VectorSimilarityFunction.COSINE.compare(query, values.vectorValue(node)); + } + + @Override + void vectorOp( + MemorySegment seg, + float[] scores, + long node1Offset, + long node2Offset, + long node3Offset, + long node4Offset, + int elementCount) { + COS_OPS.cosineBulk( + seg, scores, query, node1Offset, node2Offset, node3Offset, node4Offset, elementCount); + } + + @Override + float normalizeRawScore(float rawScore) { + return VectorUtil.normalizeToUnitInterval(rawScore); + } + } + static final class DotProductScorer extends Lucene99MemorySegmentFloatVectorScorer { static final MemorySegmentBulkVectorOps.DotProduct DOT_OPS = MemorySegmentBulkVectorOps.DOT_INSTANCE; - final float[] scratchScores = new float[4]; - DotProductScorer(MemorySegment input, FloatVectorValues values, float[] query) { super(input, values, query); } @@ -102,32 +170,92 @@ public float score(int node) throws IOException { } @Override - public void bulkScore(int[] nodes, float[] scores, int numNodes) { - int i 
= 0; - final int limit = numNodes & ~3; - for (; i < limit; i += 4) { - long offset1 = (long) nodes[i] * vectorByteSize; - long offset2 = (long) nodes[i + 1] * vectorByteSize; - long offset3 = (long) nodes[i + 2] * vectorByteSize; - long offset4 = (long) nodes[i + 3] * vectorByteSize; - DOT_OPS.dotProductBulk( - seg, scratchScores, query, offset1, offset2, offset3, offset4, query.length); - scores[i + 0] = normalizeToUnitInterval(scratchScores[0]); - scores[i + 1] = normalizeToUnitInterval(scratchScores[1]); - scores[i + 2] = normalizeToUnitInterval(scratchScores[2]); - scores[i + 3] = normalizeToUnitInterval(scratchScores[3]); - } - // Handle remaining 1–3 nodes in bulk (if any) - int remaining = numNodes - i; - if (remaining > 0) { - long addr1 = (long) nodes[i] * vectorByteSize; - long addr2 = (remaining > 1) ? (long) nodes[i + 1] * vectorByteSize : addr1; - long addr3 = (remaining > 2) ? (long) nodes[i + 2] * vectorByteSize : addr1; - DOT_OPS.dotProductBulk(seg, scratchScores, query, addr1, addr2, addr3, addr1, query.length); - scores[i] = normalizeToUnitInterval(scratchScores[0]); - if (remaining > 1) scores[i + 1] = normalizeToUnitInterval(scratchScores[1]); - if (remaining > 2) scores[i + 2] = normalizeToUnitInterval(scratchScores[2]); - } + void vectorOp( + MemorySegment seg, + float[] scores, + long node1Offset, + long node2Offset, + long node3Offset, + long node4Offset, + int elementCount) { + DOT_OPS.dotProductBulk( + seg, scores, query, node1Offset, node2Offset, node3Offset, node4Offset, elementCount); + } + + @Override + float normalizeRawScore(float rawScore) { + return VectorUtil.normalizeToUnitInterval(rawScore); + } + } + + static final class EuclideanScorer extends Lucene99MemorySegmentFloatVectorScorer { + + static final MemorySegmentBulkVectorOps.SqrDistance SQR_OPS = + MemorySegmentBulkVectorOps.SQR_INSTANCE; + + EuclideanScorer(MemorySegment seg, FloatVectorValues values, float[] query) { + super(seg, values, query); + } + + @Override + public 
float score(int node) throws IOException { + checkOrdinal(node); + // just delegates to existing scorer that copies on-heap + return VectorSimilarityFunction.EUCLIDEAN.compare(query, values.vectorValue(node)); + } + + @Override + void vectorOp( + MemorySegment seg, + float[] scores, + long node1Offset, + long node2Offset, + long node3Offset, + long node4Offset, + int elementCount) { + SQR_OPS.sqrDistanceBulk( + seg, scores, query, node1Offset, node2Offset, node3Offset, node4Offset, elementCount); + } + + @Override + float normalizeRawScore(float rawScore) { + return VectorUtil.normalizeDistanceToUnitInterval(rawScore); + } + } + + static final class MaxInnerProductScorer extends Lucene99MemorySegmentFloatVectorScorer { + + static final MemorySegmentBulkVectorOps.DotProduct DOT_OPS = + MemorySegmentBulkVectorOps.DOT_INSTANCE; + + MaxInnerProductScorer(MemorySegment seg, FloatVectorValues values, float[] query) { + super(seg, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + // just delegates to existing scorer that copies on-heap + return VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.compare( + query, values.vectorValue(node)); + } + + @Override + void vectorOp( + MemorySegment seg, + float[] scores, + long node1Offset, + long node2Offset, + long node3Offset, + long node4Offset, + int elementCount) { + DOT_OPS.dotProductBulk( + seg, scores, query, node1Offset, node2Offset, node3Offset, node4Offset, elementCount); + } + + @Override + float normalizeRawScore(float rawScore) { + return VectorUtil.scaleMaxInnerProductScore(rawScore); } } } diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorerSupplier.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorerSupplier.java index 61417452a1e3..48290fb7abb7 100644 --- 
a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorerSupplier.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorerSupplier.java @@ -16,8 +16,6 @@ */ package org.apache.lucene.internal.vectorization; -import static org.apache.lucene.util.VectorUtil.normalizeToUnitInterval; - import java.io.IOException; import java.lang.foreign.MemorySegment; import java.util.Optional; @@ -26,6 +24,7 @@ import org.apache.lucene.store.FilterIndexInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; @@ -53,11 +52,10 @@ static Optional create( } checkInvariants(values.size(), values.getVectorByteLength(), input); return switch (type) { - case COSINE -> Optional.empty(); // of(new CosineSupplier(msInput, values)); + case COSINE -> Optional.of(new CosineSupplier(seg, values)); case DOT_PRODUCT -> Optional.of(new DotProductSupplier(seg, values)); - case EUCLIDEAN -> Optional.empty(); // of(new EuclideanSupplier(msInput, values)); - case MAXIMUM_INNER_PRODUCT -> - Optional.empty(); // of(new MaxInnerProductSupplier(msInput, values)); + case EUCLIDEAN -> Optional.of(new EuclideanSupplier(seg, values)); + case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductSupplier(seg, values)); }; } @@ -75,9 +73,47 @@ static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) } } - final void checkOrdinal(int ord) { - if (ord < 0 || ord >= maxOrd) { - throw new IllegalArgumentException("illegal ordinal: " + ord); + static final class CosineSupplier extends Lucene99MemorySegmentFloatVectorScorerSupplier { + + static final MemorySegmentBulkVectorOps.Cosine COS_OPS = + MemorySegmentBulkVectorOps.COS_INSTANCE; + + CosineSupplier(MemorySegment seg, 
FloatVectorValues values) { + super(seg, values); + } + + @Override + public UpdateableRandomVectorScorer scorer() { + return new AbstractBulkScorer(values) { + @Override + float vectorOp(MemorySegment seg, long q, long d, int elementCount) { + return COS_OPS.cosine(seg, q, d, dims); + } + + @Override + void vectorOp( + MemorySegment seg, + float[] scores, + long queryOffset, + long node1Offset, + long node2Offset, + long node3Offset, + long node4Offset, + int elementCount) { + COS_OPS.cosineBulk( + seg, scores, queryOffset, node1Offset, node2Offset, node3Offset, node4Offset, dims); + } + + @Override + float normalizeRawScore(float rawScore) { + return VectorUtil.normalizeToUnitInterval(rawScore); + } + }; + } + + @Override + public CosineSupplier copy() throws IOException { + return new CosineSupplier(seg, values.copy()); // TODO: check copy } } @@ -86,61 +122,35 @@ static final class DotProductSupplier extends Lucene99MemorySegmentFloatVectorSc static final MemorySegmentBulkVectorOps.DotProduct DOT_OPS = MemorySegmentBulkVectorOps.DOT_INSTANCE; - final float[] scratchScores = new float[4]; - DotProductSupplier(MemorySegment seg, FloatVectorValues values) { super(seg, values); } @Override public UpdateableRandomVectorScorer scorer() { - return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { - private int queryOrd; - + return new AbstractBulkScorer(values) { @Override - public float score(int node) throws IOException { - checkOrdinal(node); - long queryAddr = (long) queryOrd * vectorByteSize; - long addr = (long) node * vectorByteSize; - var raw = DOT_OPS.dotProduct(seg, queryAddr, addr, dims); - return normalizeToUnitInterval(raw); + float vectorOp(MemorySegment seg, long q, long d, int elementCount) { + return DOT_OPS.dotProduct(seg, q, d, dims); } @Override - public void bulkScore(int[] nodes, float[] scores, int numNodes) { - // TODO checkOrdinal(node1 ....); - int i = 0; - long queryAddr = (long) queryOrd * vectorByteSize; - 
final int limit = numNodes & ~3; - for (; i < limit; i += 4) { - long offset1 = (long) nodes[i + 0] * vectorByteSize; - long offset2 = (long) nodes[i + 1] * vectorByteSize; - long offset3 = (long) nodes[i + 2] * vectorByteSize; - long offset4 = (long) nodes[i + 3] * vectorByteSize; - DOT_OPS.dotProductBulk( - seg, scratchScores, queryAddr, offset1, offset2, offset3, offset4, dims); - scores[i + 0] = normalizeToUnitInterval(scratchScores[0]); - scores[i + 1] = normalizeToUnitInterval(scratchScores[1]); - scores[i + 2] = normalizeToUnitInterval(scratchScores[2]); - scores[i + 3] = normalizeToUnitInterval(scratchScores[3]); - } - // Handle remaining 1–3 nodes in bulk (if any) - int remaining = numNodes - i; - if (remaining > 0) { - long addr1 = (long) nodes[i] * vectorByteSize; - long addr2 = (remaining > 1) ? (long) nodes[i + 1] * vectorByteSize : addr1; - long addr3 = (remaining > 2) ? (long) nodes[i + 2] * vectorByteSize : addr1; - DOT_OPS.dotProductBulk(seg, scratchScores, queryAddr, addr1, addr2, addr3, addr3, dims); - scores[i] = normalizeToUnitInterval(scratchScores[0]); - if (remaining > 1) scores[i + 1] = normalizeToUnitInterval(scratchScores[1]); - if (remaining > 2) scores[i + 2] = normalizeToUnitInterval(scratchScores[2]); - } + void vectorOp( + MemorySegment seg, + float[] scores, + long queryOffset, + long node1Offset, + long node2Offset, + long node3Offset, + long node4Offset, + int elementCount) { + DOT_OPS.dotProductBulk( + seg, scores, queryOffset, node1Offset, node2Offset, node3Offset, node4Offset, dims); } @Override - public void setScoringOrdinal(int node) { - checkOrdinal(node); - queryOrd = node; + float normalizeRawScore(float rawScore) { + return VectorUtil.normalizeToUnitInterval(rawScore); } }; } @@ -150,4 +160,167 @@ public DotProductSupplier copy() throws IOException { return new DotProductSupplier(seg, values); } } + + static final class EuclideanSupplier extends Lucene99MemorySegmentFloatVectorScorerSupplier { + + static final 
MemorySegmentBulkVectorOps.SqrDistance SQR_OPS = + MemorySegmentBulkVectorOps.SQR_INSTANCE; + + EuclideanSupplier(MemorySegment seg, FloatVectorValues values) { + super(seg, values); + } + + @Override + public UpdateableRandomVectorScorer scorer() { + return new AbstractBulkScorer(values) { + @Override + float vectorOp(MemorySegment seg, long q, long d, int elementCount) { + return SQR_OPS.sqrDistance(seg, q, d, dims); + } + + @Override + void vectorOp( + MemorySegment seg, + float[] scores, + long queryOffset, + long node1Offset, + long node2Offset, + long node3Offset, + long node4Offset, + int elementCount) { + SQR_OPS.sqrDistanceBulk( + seg, scores, queryOffset, node1Offset, node2Offset, node3Offset, node4Offset, dims); + } + + @Override + float normalizeRawScore(float rawScore) { + return VectorUtil.normalizeDistanceToUnitInterval(rawScore); + } + }; + } + + @Override + public EuclideanSupplier copy() throws IOException { + return new EuclideanSupplier(seg, values); // TODO: need to copy ? 
+ } + } + + static final class MaxInnerProductSupplier + extends Lucene99MemorySegmentFloatVectorScorerSupplier { + + static final MemorySegmentBulkVectorOps.DotProduct DOT_OPS = + MemorySegmentBulkVectorOps.DOT_INSTANCE; + + MaxInnerProductSupplier(MemorySegment seg, FloatVectorValues values) { + super(seg, values); + } + + @Override + public UpdateableRandomVectorScorer scorer() { + return new AbstractBulkScorer(values) { + @Override + float vectorOp(MemorySegment seg, long q, long d, int elementCount) { + return DOT_OPS.dotProduct(seg, q, d, dims); + } + + @Override + void vectorOp( + MemorySegment seg, + float[] scores, + long queryOffset, + long node1Offset, + long node2Offset, + long node3Offset, + long node4Offset, + int elementCount) { + DOT_OPS.dotProductBulk( + seg, scores, queryOffset, node1Offset, node2Offset, node3Offset, node4Offset, dims); + } + + @Override + float normalizeRawScore(float rawScore) { + return VectorUtil.scaleMaxInnerProductScore(rawScore); + } + }; + } + + @Override + public MaxInnerProductSupplier copy() throws IOException { + return new MaxInnerProductSupplier(seg, values); + } + } + + abstract class AbstractBulkScorer + extends UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer { + private int queryOrd; + final float[] scratchScores = new float[4]; + + AbstractBulkScorer(FloatVectorValues values) { + super(values); + } + + final void checkOrdinal(int ord) { + if (ord < 0 || ord >= maxOrd) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + abstract float vectorOp(MemorySegment seg, long q, long d, int elementCount); + + abstract void vectorOp( + MemorySegment seg, + float[] scores, + long queryOffset, + long node1Offset, + long node2Offset, + long node3Offset, + long node4Offset, + int elementCount); + + abstract float normalizeRawScore(float rawScore); + + @Override + public float score(int node) { + checkOrdinal(node); + long queryAddr = (long) queryOrd * vectorByteSize; + long addr = 
(long) node * vectorByteSize; + var raw = vectorOp(seg, queryAddr, addr, dims); + return normalizeRawScore(raw); + } + + @Override + public void bulkScore(int[] nodes, float[] scores, int numNodes) { + int i = 0; + long queryAddr = (long) queryOrd * vectorByteSize; + final int limit = numNodes & ~3; + for (; i < limit; i += 4) { + long offset1 = (long) nodes[i] * vectorByteSize; + long offset2 = (long) nodes[i + 1] * vectorByteSize; + long offset3 = (long) nodes[i + 2] * vectorByteSize; + long offset4 = (long) nodes[i + 3] * vectorByteSize; + vectorOp(seg, scratchScores, queryAddr, offset1, offset2, offset3, offset4, dims); + scores[i + 0] = normalizeRawScore(scratchScores[0]); + scores[i + 1] = normalizeRawScore(scratchScores[1]); + scores[i + 2] = normalizeRawScore(scratchScores[2]); + scores[i + 3] = normalizeRawScore(scratchScores[3]); + } + // Handle remaining 1–3 nodes in bulk (if any) + int remaining = numNodes - i; + if (remaining > 0) { + long addr1 = (long) nodes[i] * vectorByteSize; + long addr2 = (remaining > 1) ? (long) nodes[i + 1] * vectorByteSize : addr1; + long addr3 = (remaining > 2) ? 
(long) nodes[i + 2] * vectorByteSize : addr1; + vectorOp(seg, scratchScores, queryAddr, addr1, addr2, addr3, addr1, dims); + scores[i] = normalizeRawScore(scratchScores[0]); + if (remaining > 1) scores[i + 1] = normalizeRawScore(scratchScores[1]); + if (remaining > 2) scores[i + 2] = normalizeRawScore(scratchScores[2]); + } + } + + @Override + public void setScoringOrdinal(int node) { + checkOrdinal(node); + queryOrd = node; + } + } } diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/MemorySegmentBulkVectorOps.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/MemorySegmentBulkVectorOps.java index 20aefce48318..0447389d7b81 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/MemorySegmentBulkVectorOps.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/MemorySegmentBulkVectorOps.java @@ -36,6 +36,8 @@ public final class MemorySegmentBulkVectorOps { static final ValueLayout.OfFloat LAYOUT_LE_FLOAT = ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(LE); public static final DotProduct DOT_INSTANCE = new DotProduct(); + public static final Cosine COS_INSTANCE = new Cosine(); + public static final SqrDistance SQR_INSTANCE = new SqrDistance(); private MemorySegmentBulkVectorOps() {} @@ -144,4 +146,265 @@ private void dotProductBulkImpl( scores[3] = sum4; } } + + // -- cosine + + public static final class Cosine { + + private Cosine() {} + + public void cosineBulk( + MemorySegment dataSeg, + float[] scores, + float[] q, + long d1, + long d2, + long d3, + long d4, + int elementCount) { + cosineBulkImpl(dataSeg, scores, q, -1L, d1, d2, d3, d4, elementCount); + } + + public void cosineBulk( + MemorySegment seg, + float[] scores, + long q, + long d1, + long d2, + long d3, + long d4, + int elementCount) { + cosineBulkImpl(seg, scores, null, q, d1, d2, d3, d4, elementCount); + } + + public float cosine(MemorySegment seg, long q, long d, int elementCount) { + int i = 0; + FloatVector sv = 
FloatVector.zero(FLOAT_SPECIES); + FloatVector qvNorm = FloatVector.zero(FLOAT_SPECIES); + FloatVector dvNorm = FloatVector.zero(FLOAT_SPECIES); + final int limit = FLOAT_SPECIES.loopBound(elementCount); + for (; i < limit; i += FLOAT_SPECIES.length()) { + final long offset = (long) i * Float.BYTES; + FloatVector qv = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, q + offset, LE); + FloatVector dv = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d + offset, LE); + sv = fma(qv, dv, sv); + qvNorm = fma(qv, qv, qvNorm); + dvNorm = fma(dv, dv, dvNorm); + } + float sum = sv.reduceLanes(VectorOperators.ADD); + float qNorm = qvNorm.reduceLanes(VectorOperators.ADD); + float dNorm = dvNorm.reduceLanes(VectorOperators.ADD); + + for (; i < elementCount; i++) { + final long offset = (long) i * Float.BYTES; + final float qValue = seg.get(LAYOUT_LE_FLOAT, q + offset); + final float dValue = seg.get(LAYOUT_LE_FLOAT, d + offset); + sum = fma(qValue, dValue, sum); + qNorm = fma(qValue, qValue, qNorm); + dNorm = fma(dValue, dValue, dNorm); + } + return (float) (sum / Math.sqrt((double) qNorm * (double) dNorm)); + } + + private void cosineBulkImpl( + MemorySegment seg, + float[] scores, + float[] qArray, + long qOffset, + long d1, + long d2, + long d3, + long d4, + int elementCount) { + int i = 0; + FloatVector sv1 = FloatVector.zero(FLOAT_SPECIES); + FloatVector sv2 = FloatVector.zero(FLOAT_SPECIES); + FloatVector sv3 = FloatVector.zero(FLOAT_SPECIES); + FloatVector sv4 = FloatVector.zero(FLOAT_SPECIES); + + FloatVector qvNorm = FloatVector.zero(FLOAT_SPECIES); + FloatVector dv1Norm = FloatVector.zero(FLOAT_SPECIES); + FloatVector dv2Norm = FloatVector.zero(FLOAT_SPECIES); + FloatVector dv3Norm = FloatVector.zero(FLOAT_SPECIES); + FloatVector dv4Norm = FloatVector.zero(FLOAT_SPECIES); + + final int limit = FLOAT_SPECIES.loopBound(elementCount); + for (; i < limit; i += FLOAT_SPECIES.length()) { + final long offset = (long) i * Float.BYTES; + FloatVector dv1 = 
FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d1 + offset, LE); + FloatVector dv2 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d2 + offset, LE); + FloatVector dv3 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d3 + offset, LE); + FloatVector dv4 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d4 + offset, LE); + FloatVector qv; + if (qOffset == -1L) { + qv = FloatVector.fromArray(FLOAT_SPECIES, qArray, i); + } else { + qv = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, qOffset + offset, LE); + } + qvNorm = fma(qv, qv, qvNorm); + dv1Norm = fma(dv1, dv1, dv1Norm); + sv1 = fma(qv, dv1, sv1); + dv2Norm = fma(dv2, dv2, dv2Norm); + sv2 = fma(qv, dv2, sv2); + dv3Norm = fma(dv3, dv3, dv3Norm); + sv3 = fma(qv, dv3, sv3); + dv4Norm = fma(dv4, dv4, dv4Norm); + sv4 = fma(qv, dv4, sv4); + } + float sum1 = sv1.reduceLanes(VectorOperators.ADD); + float sum2 = sv2.reduceLanes(VectorOperators.ADD); + float sum3 = sv3.reduceLanes(VectorOperators.ADD); + float sum4 = sv4.reduceLanes(VectorOperators.ADD); + float qNorm = qvNorm.reduceLanes(VectorOperators.ADD); + float d1Norm = dv1Norm.reduceLanes(VectorOperators.ADD); + float d2Norm = dv2Norm.reduceLanes(VectorOperators.ADD); + float d3Norm = dv3Norm.reduceLanes(VectorOperators.ADD); + float d4Norm = dv4Norm.reduceLanes(VectorOperators.ADD); + + for (; i < elementCount; i++) { + final long offset = (long) i * Float.BYTES; + final float qValue; + if (qOffset == -1L) { + qValue = qArray[i]; + } else { + qValue = seg.get(LAYOUT_LE_FLOAT, qOffset + offset); + } + final float d1Value = seg.get(LAYOUT_LE_FLOAT, d1 + offset); + final float d2Value = seg.get(LAYOUT_LE_FLOAT, d2 + offset); + final float d3Value = seg.get(LAYOUT_LE_FLOAT, d3 + offset); + final float d4Value = seg.get(LAYOUT_LE_FLOAT, d4 + offset); + sum1 = fma(qValue, d1Value, sum1); + sum2 = fma(qValue, d2Value, sum2); + sum3 = fma(qValue, d3Value, sum3); + sum4 = fma(qValue, d4Value, sum4); + qNorm = fma(qValue, qValue, qNorm); + d1Norm = fma(d1Value, 
d1Value, d1Norm); + d2Norm = fma(d2Value, d2Value, d2Norm); + d3Norm = fma(d3Value, d3Value, d3Norm); + d4Norm = fma(d4Value, d4Value, d4Norm); + } + scores[0] = (float) (sum1 / Math.sqrt((double) qNorm * (double) d1Norm)); + scores[1] = (float) (sum2 / Math.sqrt((double) qNorm * (double) d2Norm)); + scores[2] = (float) (sum3 / Math.sqrt((double) qNorm * (double) d3Norm)); + scores[3] = (float) (sum4 / Math.sqrt((double) qNorm * (double) d4Norm)); + } + } + + // -- square distance + + public static final class SqrDistance { + + private SqrDistance() {} + + public void sqrDistanceBulk( + MemorySegment dataSeg, + float[] scores, + float[] q, + long d1, + long d2, + long d3, + long d4, + int elementCount) { + sqrDistanceBulkImpl(dataSeg, scores, q, -1L, d1, d2, d3, d4, elementCount); + } + + public void sqrDistanceBulk( + MemorySegment seg, + float[] scores, + long q, + long d1, + long d2, + long d3, + long d4, + int elementCount) { + sqrDistanceBulkImpl(seg, scores, null, q, d1, d2, d3, d4, elementCount); + } + + public float sqrDistance(MemorySegment seg, long q, long d, int elementCount) { + int i = 0; + FloatVector sv = FloatVector.zero(FLOAT_SPECIES); + final int limit = FLOAT_SPECIES.loopBound(elementCount); + for (; i < limit; i += FLOAT_SPECIES.length()) { + final long offset = (long) i * Float.BYTES; + FloatVector qv = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, q + offset, LE); + FloatVector dv = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d + offset, LE); + FloatVector diff = qv.sub(dv); + sv = fma(diff, diff, sv); + } + float score = sv.reduceLanes(VectorOperators.ADD); + + for (; i < elementCount; i++) { + final long offset = (long) i * Float.BYTES; + float diff = seg.get(LAYOUT_LE_FLOAT, q + offset) - seg.get(LAYOUT_LE_FLOAT, d + offset); + score = fma(diff, diff, score); + } + return score; + } + + private void sqrDistanceBulkImpl( + MemorySegment seg, + float[] scores, + float[] qArray, + long qOffset, + long d1, + long d2, + long d3, + 
long d4, + int elementCount) { + int i = 0; + FloatVector sv1 = FloatVector.zero(FLOAT_SPECIES); + FloatVector sv2 = FloatVector.zero(FLOAT_SPECIES); + FloatVector sv3 = FloatVector.zero(FLOAT_SPECIES); + FloatVector sv4 = FloatVector.zero(FLOAT_SPECIES); + + final int limit = FLOAT_SPECIES.loopBound(elementCount); + for (; i < limit; i += FLOAT_SPECIES.length()) { + final long offset = (long) i * Float.BYTES; + FloatVector dv1 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d1 + offset, LE); + FloatVector dv2 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d2 + offset, LE); + FloatVector dv3 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d3 + offset, LE); + FloatVector dv4 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d4 + offset, LE); + FloatVector qv; + if (qOffset == -1L) { + qv = FloatVector.fromArray(FLOAT_SPECIES, qArray, i); + } else { + qv = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, qOffset + offset, LE); + } + FloatVector diff1 = qv.sub(dv1); + FloatVector diff2 = qv.sub(dv2); + FloatVector diff3 = qv.sub(dv3); + FloatVector diff4 = qv.sub(dv4); + sv1 = fma(diff1, diff1, sv1); + sv2 = fma(diff2, diff2, sv2); + sv3 = fma(diff3, diff3, sv3); + sv4 = fma(diff4, diff4, sv4); + } + float sum1 = sv1.reduceLanes(VectorOperators.ADD); + float sum2 = sv2.reduceLanes(VectorOperators.ADD); + float sum3 = sv3.reduceLanes(VectorOperators.ADD); + float sum4 = sv4.reduceLanes(VectorOperators.ADD); + + for (; i < elementCount; i++) { + final long offset = (long) i * Float.BYTES; + final float qValue; + if (qOffset == -1L) { + qValue = qArray[i]; + } else { + qValue = seg.get(LAYOUT_LE_FLOAT, qOffset + offset); + } + float diff1 = qValue - seg.get(LAYOUT_LE_FLOAT, d1 + offset); + float diff2 = qValue - seg.get(LAYOUT_LE_FLOAT, d2 + offset); + float diff3 = qValue - seg.get(LAYOUT_LE_FLOAT, d3 + offset); + float diff4 = qValue - seg.get(LAYOUT_LE_FLOAT, d4 + offset); + sum1 = fma(diff1, diff1, sum1); + sum2 = fma(diff2, diff2, sum2); + sum3 = 
fma(diff3, diff3, sum3); + sum4 = fma(diff4, diff4, sum4); + } + scores[0] = sum1; + scores[1] = sum2; + scores[2] = sum3; + scores[3] = sum4; + } + } }