diff --git a/docs/changelog/141718.yaml b/docs/changelog/141718.yaml new file mode 100644 index 0000000000000..2c957dd6bc0ba --- /dev/null +++ b/docs/changelog/141718.yaml @@ -0,0 +1,5 @@ +area: Vector Search +issues: [] +pr: 141718 +summary: Enable zero-copy SIMD vector scoring on searchable snapshots (frozen tier) +type: enhancement diff --git a/libs/core/src/main/java/org/elasticsearch/core/DirectAccessInput.java b/libs/core/src/main/java/org/elasticsearch/core/DirectAccessInput.java new file mode 100644 index 0000000000000..173e7b0317770 --- /dev/null +++ b/libs/core/src/main/java/org/elasticsearch/core/DirectAccessInput.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.core; + +import java.io.IOException; +import java.nio.ByteBuffer; + +/** + * An optional interface that an IndexInput can implement to provide direct + * access to the underlying data as a {@link ByteBuffer}. This enables + * zero-copy access to memory-mapped data for SIMD-accelerated vector scoring. + * + *

The byte buffer is passed to the caller's action and is only valid for + * the duration of that call. All ref-counting and resource releases, if any, + * is handled internally. + */ +public interface DirectAccessInput { + + /** + * If a direct byte buffer view is available for the given range, passes it + * to {@code action} and returns {@code true}. Otherwise returns + * {@code false} without invoking the action. + * + *

The byte buffer is read-only and valid only for the duration of the + * action. Callers must not retain references to it after the action returns. + * + * @param offset the byte offset within the input + * @param length the number of bytes requested + * @param action the action to perform with the byte buffer + * @return {@code true} if a buffer was available and the action was invoked + */ + boolean withByteBufferSlice(long offset, long length, CheckedConsumer action) throws IOException; +} diff --git a/libs/native/src/main/java/module-info.java b/libs/native/src/main/java/module-info.java index dc74ecb5c6329..cb22d59fb54b4 100644 --- a/libs/native/src/main/java/module-info.java +++ b/libs/native/src/main/java/module-info.java @@ -20,8 +20,10 @@ to org.elasticsearch.server, org.elasticsearch.blobcache, + org.elasticsearch.searchablesnapshots, org.elasticsearch.simdvec, - org.elasticsearch.systemd; + org.elasticsearch.systemd, + org.elasticsearch.xpack.stateless; uses NativeLibraryProvider; diff --git a/libs/simdvec/src/main/java/module-info.java b/libs/simdvec/src/main/java/module-info.java index 8cf2bac6a9504..deafe7f482bdd 100644 --- a/libs/simdvec/src/main/java/module-info.java +++ b/libs/simdvec/src/main/java/module-info.java @@ -26,9 +26,10 @@ * at runtime is selected by the multi-release classloader. 
*/ module org.elasticsearch.simdvec { + requires org.elasticsearch.base; + requires org.elasticsearch.logging; requires org.elasticsearch.nativeaccess; requires org.apache.lucene.core; - requires org.elasticsearch.logging; exports org.elasticsearch.simdvec to org.elasticsearch.server; } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/MemorySegmentAccessInputAccess.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/MemorySegmentAccessInputAccess.java index 85618edc8da42..29e11725e2e2b 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/MemorySegmentAccessInputAccess.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/MemorySegmentAccessInputAccess.java @@ -15,8 +15,10 @@ import java.util.Objects; public interface MemorySegmentAccessInputAccess { + /** Returns the underlying {@link MemorySegmentAccessInput}, or {@code null} if not available. */ MemorySegmentAccessInput get(); + /** Unwraps to the underlying {@link MemorySegmentAccessInput} if available, otherwise returns the input unchanged. */ static IndexInput unwrap(IndexInput input) { MemorySegmentAccessInput memorySeg = input instanceof MemorySegmentAccessInputAccess msaia ? msaia.get() : null; return Objects.requireNonNullElse((IndexInput) memorySeg, input); diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java index a4acb878acdee..80c1d4fb7eae7 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java @@ -33,6 +33,14 @@ public static ESVectorizationProvider getInstance() { /** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. 
*/ public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension, int bulkSize) throws IOException; + /** + * Create a new {@link ESNextOSQVectorsScorer} for the given {@link IndexInput}. + * The input should be unwrapped before calling this method. If the input is + * still a {@code FilterIndexInput} that does not implement + * {@code MemorySegmentAccessInput} or {@code DirectAccessInput}, an + * {@link IllegalArgumentException} is thrown. Non-wrapper inputs (e.g. + * {@code ByteBuffersIndexInput}) are accepted and use a heap-copy fallback. + */ public abstract ESNextOSQVectorsScorer newESNextOSQVectorsScorer( IndexInput input, byte queryBits, @@ -42,7 +50,10 @@ public abstract ESNextOSQVectorsScorer newESNextOSQVectorsScorer( int bulkSize ) throws IOException; - /** Create a new {@link ES92Int7VectorsScorer} for the given {@link IndexInput}. */ + /** + * Create a new {@link ES92Int7VectorsScorer} for the given {@link IndexInput}. + * See {@link #newESNextOSQVectorsScorer} for input type requirements. + */ public abstract ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension, int bulkSize) throws IOException; // visible for tests diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/IndexInputUtils.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/IndexInputUtils.java new file mode 100644 index 0000000000000..41985352a1af0 --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/IndexInputUtils.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.simdvec.internal; + +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.elasticsearch.core.CheckedFunction; +import org.elasticsearch.core.DirectAccessInput; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.util.function.IntFunction; + +/** + * Utility for obtaining a {@link MemorySegment} view of data in an + * {@link IndexInput} and passing it to a caller-supplied action. The + * segment may come from a {@link MemorySegmentAccessInput} (mmap), + * a direct {@link java.nio.ByteBuffer} view (e.g. blob-cache), or a + * heap copy as a last resort. + * + *

All resource management (ref-counting, buffer release) is handled + * internally — callers never see a closeable resource. + */ +public final class IndexInputUtils { + + private IndexInputUtils() {} + + /** + * Obtains a memory segment for the next {@code length} bytes of the + * index input, passes it to {@code action}, and returns the result. + * The position of the index input is advanced by {@code length}. + * + *

This method first tries to obtain a slice via + * {@link MemorySegmentAccessInput#segmentSliceOrNull}. If that + * returns {@code null}, it tries a direct {@link java.nio.ByteBuffer} + * view via {@link DirectAccessInput}. As a last resort it copies the + * data onto the heap using a byte array obtained from + * {@code scratchSupplier}. + * + *

The memory segment passed to {@code action} is valid only for + * the duration of the call. Callers must not retain references to it. + * + * @param in the index input positioned at the data to read + * @param length the number of bytes to read + * @param scratchSupplier supplies a byte array of at least the requested + * length, used only on the heap-copy fallback path + * @param action the function to apply to the memory segment + * @return the result of applying {@code action} + */ + public static R withSlice( + IndexInput in, + long length, + IntFunction scratchSupplier, + CheckedFunction action + ) throws IOException { + checkInputType(in); + if (in instanceof MemorySegmentAccessInput msai) { + long offset = in.getFilePointer(); + MemorySegment slice = msai.segmentSliceOrNull(offset, length); + if (slice != null) { + in.skipBytes(length); + return action.apply(slice); + } + } + if (in instanceof DirectAccessInput dai) { + long offset = in.getFilePointer(); + @SuppressWarnings("unchecked") + R[] result = (R[]) new Object[1]; + boolean available = dai.withByteBufferSlice(offset, length, bb -> { + in.skipBytes(length); + result[0] = action.apply(MemorySegment.ofBuffer(bb)); + }); + if (available) { + return result[0]; + } + } + return action.apply(copyOnHeap(in, Math.toIntExact(length), scratchSupplier)); + } + + /** + * Checks that a {@link FilterIndexInput} wrapper also implements + * {@link MemorySegmentAccessInput} or {@link DirectAccessInput}, + * so that zero-copy access is preserved through the wrapper chain. + */ + public static void checkInputType(IndexInput in) { + if (in instanceof FilterIndexInput && (in instanceof MemorySegmentAccessInput || in instanceof DirectAccessInput) == false) { + throw new IllegalArgumentException( + "IndexInput is a FilterIndexInput (" + + in.getClass().getName() + + ") that does not implement MemorySegmentAccessInput or DirectAccessInput. 
" + + "Ensure the wrapper implements DirectAccessInput or is unwrapped before constructing the scorer." + ); + } + } + + /** + * Reads the given number of bytes from the current position of the + * given IndexInput into a heap-backed memory segment. The returned + * segment is sliced to exactly {@code bytesToRead} bytes, even if + * the underlying array is larger. + */ + private static MemorySegment copyOnHeap(IndexInput in, int bytesToRead, IntFunction scratchSupplier) throws IOException { + byte[] buf = scratchSupplier.apply(bytesToRead); + in.readBytes(buf, 0, bytesToRead); + return MemorySegment.ofArray(buf).asSlice(0, bytesToRead); + } +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java index 117aad9a85290..c2a009473ce83 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java @@ -12,13 +12,12 @@ import org.apache.lucene.store.IndexInput; import java.io.IOException; -import java.lang.foreign.MemorySegment; /** Panamized scorer for 7-bit quantized vectors stored as an {@link IndexInput}. 
**/ public final class MemorySegmentES92Int7VectorsScorer extends MemorySegmentES92PanamaInt7VectorsScorer { - public MemorySegmentES92Int7VectorsScorer(IndexInput in, int dimensions, int bulkSize, MemorySegment memorySegment) { - super(in, dimensions, bulkSize, memorySegment); + public MemorySegmentES92Int7VectorsScorer(IndexInput in, int dimensions, int bulkSize) { + super(in, dimensions, bulkSize); } @Override diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92PanamaInt7VectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92PanamaInt7VectorsScorer.java index 535c6a3b8eb89..51ce3b9d4a01c 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92PanamaInt7VectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92PanamaInt7VectorsScorer.java @@ -24,6 +24,7 @@ import java.io.IOException; import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; import java.nio.ByteOrder; import static java.nio.ByteOrder.LITTLE_ENDIAN; @@ -58,44 +59,55 @@ abstract class MemorySegmentES92PanamaInt7VectorsScorer extends ES92Int7VectorsS INT_SPECIES = VectorSpecies.of(int.class, VectorShape.forBitSize(VECTOR_BITSIZE)); } - protected final MemorySegment memorySegment; + private byte[] scratch; - protected MemorySegmentES92PanamaInt7VectorsScorer(IndexInput in, int dimensions, int bulkSize, MemorySegment memorySegment) { + protected MemorySegmentES92PanamaInt7VectorsScorer(IndexInput in, int dimensions, int bulkSize) { super(in, dimensions, bulkSize); - this.memorySegment = memorySegment; + IndexInputUtils.checkInputType(in); + } + + protected byte[] getScratch(int len) { + if (scratch == null || scratch.length < len) { + scratch = new byte[len]; + } + return scratch; } protected long panamaInt7DotProduct(byte[] q) throws IOException { assert dimensions == q.length; + return 
IndexInputUtils.withSlice(in, dimensions, this::getScratch, segment -> panamaInt7DotProductImpl(q, segment, dimensions)); + } + + private static long panamaInt7DotProductImpl(byte[] q, MemorySegment segment, int dimensions) { int i = 0; int res = 0; // only vectorize if we'll at least enter the loop a single time if (dimensions >= 16) { // compute vectorized dot product consistent with VPDPBUSD instruction if (VECTOR_BITSIZE >= 512) { - i += BYTE_SPECIES_128.loopBound(dimensions); - res += dotProductBody512(q, i); + int limit = BYTE_SPECIES_128.loopBound(dimensions); + res += dotProductBody512Impl(q, segment, 0, limit); + i = limit; } else if (VECTOR_BITSIZE == 256) { - i += BYTE_SPECIES_64.loopBound(dimensions); - res += dotProductBody256(q, i); + int limit = BYTE_SPECIES_64.loopBound(dimensions); + res += dotProductBody256Impl(q, segment, 0, limit); + i = limit; } else { // tricky: we don't have SPECIES_32, so we workaround with "overlapping read" - i += BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length()); - res += dotProductBody128(q, i); + int limit = BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length()); + res += dotProductBody128Impl(q, segment, 0, limit); + i = limit; } - // scalar tail - while (i < dimensions) { - res += in.readByte() * q[i++]; - } - return res; - } else { - return super.int7DotProduct(q); } + // scalar tail + for (; i < dimensions; i++) { + res += segment.get(ValueLayout.JAVA_BYTE, i) * q[i]; + } + return res; } - private int dotProductBody512(byte[] q, int limit) throws IOException { + private static int dotProductBody512Impl(byte[] q, MemorySegment memorySegment, long offset, int limit) { IntVector acc = IntVector.zero(INT_SPECIES_512); - long offset = in.getFilePointer(); for (int i = 0; i < limit; i += BYTE_SPECIES_128.length()) { ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i); ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN); @@ -109,15 
+121,11 @@ private int dotProductBody512(byte[] q, int limit) throws IOException { Vector prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0); acc = acc.add(prod32); } - - in.seek(offset + limit); // advance the input stream - // reduce return acc.reduceLanes(ADD); } - private int dotProductBody256(byte[] q, int limit) throws IOException { + private static int dotProductBody256Impl(byte[] q, MemorySegment memorySegment, long offset, int limit) { IntVector acc = IntVector.zero(INT_SPECIES_256); - long offset = in.getFilePointer(); for (int i = 0; i < limit; i += BYTE_SPECIES_64.length()) { ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i); ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN); @@ -127,14 +135,11 @@ private int dotProductBody256(byte[] q, int limit) throws IOException { Vector vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0); acc = acc.add(va32.mul(vb32)); } - in.seek(offset + limit); - // reduce return acc.reduceLanes(ADD); } - private int dotProductBody128(byte[] q, int limit) throws IOException { + private static int dotProductBody128Impl(byte[] q, MemorySegment memorySegment, long offset, int limit) { IntVector acc = IntVector.zero(IntVector.SPECIES_128); - long offset = in.getFilePointer(); // 4 bytes at a time (re-loading half the vector each time!) 
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) { // load 8 bytes @@ -149,109 +154,72 @@ private int dotProductBody128(byte[] q, int limit) throws IOException { // 32-bit add acc = acc.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0)); } - in.seek(offset + limit); - // reduce return acc.reduceLanes(ADD); } protected void panamaInt7DotProductBulk(byte[] q, int count, float[] scores) throws IOException { assert dimensions == q.length; + IndexInputUtils.withSlice(in, (long) dimensions * count, this::getScratch, segment -> { + panamaInt7DotProductBulkImpl(q, segment, dimensions, count, scores); + return null; + }); + } + + private static void panamaInt7DotProductBulkImpl(byte[] q, MemorySegment memorySegment, int dimensions, int count, float[] scores) { // only vectorize if we'll at least enter the loop a single time if (dimensions >= 16) { // compute vectorized dot product consistent with VPDPBUSD instruction if (VECTOR_BITSIZE >= 512) { - dotProductBody512Bulk(q, count, scores); + dotProductBulkVectorized512(q, memorySegment, dimensions, count, scores); } else if (VECTOR_BITSIZE == 256) { - dotProductBody256Bulk(q, count, scores); + dotProductBulkVectorized256(q, memorySegment, dimensions, count, scores); } else { // tricky: we don't have SPECIES_32, so we workaround with "overlapping read" - dotProductBody128Bulk(q, count, scores); + dotProductBulkVectorized128(q, memorySegment, dimensions, count, scores); } } else { - super.int7DotProductBulk(q, count, scores); + for (int iter = 0; iter < count; iter++) { + long base = (long) iter * dimensions; + long res = 0; + for (int i = 0; i < dimensions; i++) { + res += memorySegment.get(ValueLayout.JAVA_BYTE, base + i) * q[i]; + } + scores[iter] = res; + } } } - private void dotProductBody512Bulk(byte[] q, int count, float[] scores) throws IOException { + private static void dotProductBulkVectorized512(byte[] q, MemorySegment memorySegment, int dimensions, int count, float[] scores) { int limit = 
BYTE_SPECIES_128.loopBound(dimensions); for (int iter = 0; iter < count; iter++) { - IntVector acc = IntVector.zero(INT_SPECIES_512); - long offset = in.getFilePointer(); - int i = 0; - for (; i < limit; i += BYTE_SPECIES_128.length()) { - ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i); - ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN); - - // 16-bit multiply: avoid AVX-512 heavy multiply on zmm - Vector va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0); - Vector vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0); - Vector prod16 = va16.mul(vb16); - - // 32-bit add - Vector prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0); - acc = acc.add(prod32); - } - - in.seek(offset + limit); // advance the input stream - // reduce - long res = acc.reduceLanes(ADD); - for (; i < dimensions; i++) { - res += in.readByte() * q[i]; + long base = (long) iter * dimensions; + long res = dotProductBody512Impl(q, memorySegment, base, limit); + for (int i = limit; i < dimensions; i++) { + res += memorySegment.get(ValueLayout.JAVA_BYTE, base + i) * q[i]; } scores[iter] = res; } } - private void dotProductBody256Bulk(byte[] q, int count, float[] scores) throws IOException { + private static void dotProductBulkVectorized256(byte[] q, MemorySegment memorySegment, int dimensions, int count, float[] scores) { int limit = BYTE_SPECIES_128.loopBound(dimensions); for (int iter = 0; iter < count; iter++) { - IntVector acc = IntVector.zero(INT_SPECIES_256); - long offset = in.getFilePointer(); - int i = 0; - for (; i < limit; i += BYTE_SPECIES_64.length()) { - ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i); - ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN); - - // 32-bit multiply and add into accumulator - Vector va32 = va8.convertShape(B2I, INT_SPECIES_256, 0); - Vector vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0); - acc = 
acc.add(va32.mul(vb32)); - } - in.seek(offset + limit); - // reduce - long res = acc.reduceLanes(ADD); - for (; i < dimensions; i++) { - res += in.readByte() * q[i]; + long base = (long) iter * dimensions; + long res = dotProductBody256Impl(q, memorySegment, base, limit); + for (int i = limit; i < dimensions; i++) { + res += memorySegment.get(ValueLayout.JAVA_BYTE, base + i) * q[i]; } scores[iter] = res; } } - private void dotProductBody128Bulk(byte[] q, int count, float[] scores) throws IOException { + private static void dotProductBulkVectorized128(byte[] q, MemorySegment memorySegment, int dimensions, int count, float[] scores) { int limit = BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length()); for (int iter = 0; iter < count; iter++) { - IntVector acc = IntVector.zero(IntVector.SPECIES_128); - long offset = in.getFilePointer(); - // 4 bytes at a time (re-loading half the vector each time!) - int i = 0; - for (; i < limit; i += ByteVector.SPECIES_64.length() >> 1) { - // load 8 bytes - ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i); - ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN); - - // process first "half" only: 16-bit multiply - Vector va16 = va8.convert(B2S, 0); - Vector vb16 = vb8.convert(B2S, 0); - Vector prod16 = va16.mul(vb16); - - // 32-bit add - acc = acc.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0)); - } - in.seek(offset + limit); - // reduce - long res = acc.reduceLanes(ADD); - for (; i < dimensions; i++) { - res += in.readByte() * q[i]; + long base = (long) iter * dimensions; + long res = dotProductBody128Impl(q, memorySegment, base, limit); + for (int i = limit; i < dimensions; i++) { + res += memorySegment.get(ValueLayout.JAVA_BYTE, base + i) * q[i]; } scores[iter] = res; } @@ -267,30 +235,55 @@ protected void applyCorrectionsBulk( float[] scores, int bulkSize ) throws IOException { + IndexInputUtils.withSlice(in, 16L * bulkSize, this::getScratch, 
memorySegment -> { + applyCorrectionsBulkImpl( + memorySegment, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + bulkSize, + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + dimensions + ); + return null; + }); + } + + private static void applyCorrectionsBulkImpl( + MemorySegment memorySegment, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores, + int bulkSize, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + int dimensions + ) { int limit = FLOAT_SPECIES.loopBound(bulkSize); int i = 0; - long offset = in.getFilePointer(); float ay = queryLowerInterval; float ly = (queryUpperInterval - ay) * SEVEN_BIT_SCALE; float y1 = queryComponentSum; for (; i < limit; i += FLOAT_SPECIES.length()) { - var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); - var lx = FloatVector.fromMemorySegment( - FLOAT_SPECIES, - memorySegment, - offset + 4 * bulkSize + i * Float.BYTES, - ByteOrder.LITTLE_ENDIAN - ).sub(ax).mul(SEVEN_BIT_SCALE); + var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); + var lx = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, 4 * bulkSize + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN) + .sub(ax) + .mul(SEVEN_BIT_SCALE); var targetComponentSums = IntVector.fromMemorySegment( INT_SPECIES, memorySegment, - offset + 8 * bulkSize + i * Integer.BYTES, + 8 * bulkSize + i * Integer.BYTES, ByteOrder.LITTLE_ENDIAN ).convert(VectorOperators.I2F, 0); var additionalCorrections = FloatVector.fromMemorySegment( FLOAT_SPECIES, memorySegment, - offset + 12 * bulkSize + i * Float.BYTES, + 12 * bulkSize + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN ); var qcDist = FloatVector.fromArray(FLOAT_SPECIES, scores, i); @@ -328,17 +321,11 @@ protected void applyCorrectionsBulk( var floatVectorMask = 
FLOAT_SPECIES.indexInRange(i, bulkSize); var intVectorMask = INT_SPECIES.indexInRange(i, bulkSize); - var ax = FloatVector.fromMemorySegment( - FLOAT_SPECIES, - memorySegment, - offset + (long) i * Float.BYTES, - ByteOrder.LITTLE_ENDIAN, - floatVectorMask - ); + var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, (long) i * Float.BYTES, LITTLE_ENDIAN, floatVectorMask); var upper = FloatVector.fromMemorySegment( FLOAT_SPECIES, memorySegment, - offset + 4L * bulkSize + (long) i * Float.BYTES, + 4L * bulkSize + (long) i * Float.BYTES, ByteOrder.LITTLE_ENDIAN, floatVectorMask ); @@ -347,7 +334,7 @@ protected void applyCorrectionsBulk( var targetComponentSums = IntVector.fromMemorySegment( INT_SPECIES, memorySegment, - offset + 8L * bulkSize + (long) i * Integer.BYTES, + 8L * bulkSize + (long) i * Integer.BYTES, ByteOrder.LITTLE_ENDIAN, intVectorMask ).convert(VectorOperators.I2F, 0); @@ -355,7 +342,7 @@ protected void applyCorrectionsBulk( var additionalCorrections = FloatVector.fromMemorySegment( FLOAT_SPECIES, memorySegment, - offset + 12L * bulkSize + (long) i * Float.BYTES, + 12L * bulkSize + (long) i * Float.BYTES, ByteOrder.LITTLE_ENDIAN, floatVectorMask ); @@ -385,6 +372,6 @@ protected void applyCorrectionsBulk( } } } - in.seek(offset + 16L * bulkSize); } + } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java index 8a010c23abe99..ac79f585f08c5 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java @@ -40,6 +40,14 @@ public static ESVectorizationProvider getInstance() { /** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. 
*/ public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension, int bulkSize) throws IOException; + /** + * Create a new {@link ESNextOSQVectorsScorer} for the given {@link IndexInput}. + * The input should be unwrapped before calling this method. If the input is + * still a {@code FilterIndexInput} that does not implement + * {@code MemorySegmentAccessInput} or {@code DirectAccessInput}, an + * {@link IllegalArgumentException} is thrown. Non-wrapper inputs (e.g. + * {@code ByteBuffersIndexInput}) are accepted and use a heap-copy fallback. + */ public abstract ESNextOSQVectorsScorer newESNextOSQVectorsScorer( IndexInput input, byte queryBits, @@ -49,7 +57,10 @@ public abstract ESNextOSQVectorsScorer newESNextOSQVectorsScorer( int bulkSize ) throws IOException; - /** Create a new {@link ES92Int7VectorsScorer} for the given {@link IndexInput}. */ + /** + * Create a new {@link ES92Int7VectorsScorer} for the given {@link IndexInput}. + * See {@link #newESNextOSQVectorsScorer} for input type requirements. 
+ */ public abstract ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension, int bulkSize) throws IOException; // visible for tests diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSBitToInt4ESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSBitToInt4ESNextOSQVectorsScorer.java index d8fa469c41c05..7d4124c4adf6a 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSBitToInt4ESNextOSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSBitToInt4ESNextOSQVectorsScorer.java @@ -18,6 +18,7 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.simdvec.internal.IndexInputUtils; import java.io.IOException; import java.lang.foreign.Arena; @@ -33,8 +34,8 @@ /** Panamized scorer for quantized vectors stored as a {@link MemorySegment}. 
*/ final class MSBitToInt4ESNextOSQVectorsScorer extends MemorySegmentESNextOSQVectorsScorer.MemorySegmentScorer { - MSBitToInt4ESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, int bulkSize, MemorySegment memorySegment) { - super(in, dimensions, dataLength, bulkSize, memorySegment); + MSBitToInt4ESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, int bulkSize) { + super(in, dimensions, dataLength, bulkSize); } @Override @@ -56,9 +57,10 @@ public long quantizeScore(byte[] q) throws IOException { } private long nativeQuantizeScore(byte[] q) throws IOException { - long offset = in.getFilePointer(); - var datasetMemorySegment = memorySegment.asSlice(offset, length); + return IndexInputUtils.withSlice(in, length, this::getScratch, segment -> nativeQuantizeScoreImpl(q, segment, length)); + } + private static long nativeQuantizeScoreImpl(byte[] q, MemorySegment datasetMemorySegment, int length) { final long qScore; if (SUPPORTS_HEAP_SEGMENTS) { var queryMemorySegment = MemorySegment.ofArray(q); @@ -70,29 +72,31 @@ private long nativeQuantizeScore(byte[] q) throws IOException { qScore = dotProductD1Q4(datasetMemorySegment, queryMemorySegment, length); } } - in.skipBytes(length); return qScore; } private long quantizeScore256(byte[] q) throws IOException { + return IndexInputUtils.withSlice(in, length, this::getScratch, segment -> quantizeScore256Impl(q, segment, length)); + } + + private static long quantizeScore256Impl(byte[] q, MemorySegment memorySegment, int length) { long subRet0 = 0; long subRet1 = 0; long subRet2 = 0; long subRet3 = 0; int i = 0; - long offset = in.getFilePointer(); if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { int limit = ByteVector.SPECIES_256.loopBound(length); var sum0 = LongVector.zero(LONG_SPECIES_256); var sum1 = LongVector.zero(LONG_SPECIES_256); var sum2 = LongVector.zero(LONG_SPECIES_256); var sum3 = LongVector.zero(LONG_SPECIES_256); - for (; i < limit; i += 
ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) { + for (; i < limit; i += ByteVector.SPECIES_256.length()) { var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs(); var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs(); var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs(); - var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, i, ByteOrder.LITTLE_ENDIAN); sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); @@ -110,12 +114,12 @@ private long quantizeScore256(byte[] q) throws IOException { var sum2 = LongVector.zero(LONG_SPECIES_128); var sum3 = LongVector.zero(LONG_SPECIES_128); int limit = ByteVector.SPECIES_128.loopBound(length); - for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) { + for (; i < limit; i += ByteVector.SPECIES_128.length()) { var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs(); var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs(); var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs(); - var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, i, ByteOrder.LITTLE_ENDIAN); sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum2 = 
sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); @@ -127,23 +131,22 @@ private long quantizeScore256(byte[] q) throws IOException { subRet3 += sum3.reduceLanes(VectorOperators.ADD); } // process scalar tail - in.seek(offset); for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = in.readLong(); + final long value = memorySegment.get(LAYOUT_LE_LONG, i); subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); } for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = in.readInt(); + final int value = memorySegment.get(LAYOUT_LE_INT, i); subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); } for (; i < length; i++) { - int dValue = in.readByte() & 0xFF; + final int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, i) & 0xFF; subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); @@ -153,20 +156,22 @@ private long quantizeScore256(byte[] q) throws IOException { } private long quantizeScore128(byte[] q) throws IOException { + return IndexInputUtils.withSlice(in, length, this::getScratch, segment -> quantizeScore128Impl(q, segment, length)); + } + + private static long quantizeScore128Impl(byte[] q, MemorySegment memorySegment, int length) { long subRet0 = 0; long subRet1 = 0; long subRet2 = 0; long subRet3 = 0; 
int i = 0; - long offset = in.getFilePointer(); - var sum0 = IntVector.zero(INT_SPECIES_128); var sum1 = IntVector.zero(INT_SPECIES_128); var sum2 = IntVector.zero(INT_SPECIES_128); var sum3 = IntVector.zero(INT_SPECIES_128); int limit = ByteVector.SPECIES_128.loopBound(length); - for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) { - var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + for (; i < limit; i += ByteVector.SPECIES_128.length()) { + var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, i, ByteOrder.LITTLE_ENDIAN); var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts(); var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts(); @@ -181,23 +186,22 @@ private long quantizeScore128(byte[] q) throws IOException { subRet2 += sum2.reduceLanes(VectorOperators.ADD); subRet3 += sum3.reduceLanes(VectorOperators.ADD); // process scalar tail - in.seek(offset); for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = in.readLong(); + final long value = memorySegment.get(LAYOUT_LE_LONG, i); subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); } for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = in.readInt(); + final int value = memorySegment.get(LAYOUT_LE_INT, i); subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); subRet2 += Integer.bitCount((int) 
BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); } for (; i < length; i++) { - int dValue = in.readByte() & 0xFF; + final int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, i) & 0xFF; subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); @@ -240,24 +244,29 @@ public boolean quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOE } private void nativeQuantizeScoreBulk(MemorySegment querySegment, int count, MemorySegment scoresSegment) throws IOException { - long initialOffset = in.getFilePointer(); var datasetLengthInBytes = (long) length * count; - MemorySegment datasetSegment = memorySegment.asSlice(initialOffset, datasetLengthInBytes); - - dotProductD1Q4Bulk(datasetSegment, querySegment, length, count, scoresSegment); - - in.skipBytes(datasetLengthInBytes); + IndexInputUtils.withSlice(in, datasetLengthInBytes, this::getScratch, datasetSegment -> { + dotProductD1Q4Bulk(datasetSegment, querySegment, length, count, scoresSegment); + return null; + }); } private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IOException { + var datasetLengthInBytes = (long) length * count; + IndexInputUtils.withSlice(in, datasetLengthInBytes, this::getScratch, segment -> { + quantizeScore128BulkImpl(q, segment, length, count, scores); + return null; + }); + } + + private static void quantizeScore128BulkImpl(byte[] q, MemorySegment memorySegment, int length, int count, float[] scores) { + int offset = 0; for (int iter = 0; iter < count; iter++) { long subRet0 = 0; long subRet1 = 0; long subRet2 = 0; long subRet3 = 0; int i = 0; - long offset = in.getFilePointer(); - var sum0 = IntVector.zero(INT_SPECIES_128); var sum1 = IntVector.zero(INT_SPECIES_128); var sum2 = IntVector.zero(INT_SPECIES_128); @@ -279,23 +288,22 @@ private void 
quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IO subRet2 += sum2.reduceLanes(VectorOperators.ADD); subRet3 += sum3.reduceLanes(VectorOperators.ADD); // process scalar tail - in.seek(offset); - for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = in.readLong(); + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES, offset += Long.BYTES) { + final long value = memorySegment.get(LAYOUT_LE_LONG, offset); subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); } - for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = in.readInt(); + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES, offset += Integer.BYTES) { + final int value = memorySegment.get(LAYOUT_LE_INT, offset); subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); } - for (; i < length; i++) { - int dValue = in.readByte() & 0xFF; + for (; i < length; i++, offset++) { + final int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, offset) & 0xFF; subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); @@ -306,13 +314,21 @@ private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IO } private void quantizeScore256Bulk(byte[] q, int count, float[] scores) 
throws IOException { + var datasetLengthInBytes = (long) length * count; + IndexInputUtils.withSlice(in, datasetLengthInBytes, this::getScratch, segment -> { + quantizeScore256BulkImpl(q, segment, length, count, scores); + return null; + }); + } + + private static void quantizeScore256BulkImpl(byte[] q, MemorySegment memorySegment, int length, int count, float[] scores) { + int offset = 0; for (int iter = 0; iter < count; iter++) { long subRet0 = 0; long subRet1 = 0; long subRet2 = 0; long subRet3 = 0; int i = 0; - long offset = in.getFilePointer(); if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { int limit = ByteVector.SPECIES_256.loopBound(length); var sum0 = LongVector.zero(LONG_SPECIES_256); @@ -359,23 +375,22 @@ private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IO subRet3 += sum3.reduceLanes(VectorOperators.ADD); } // process scalar tail - in.seek(offset); - for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = in.readLong(); + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES, offset += Long.BYTES) { + final long value = memorySegment.get(LAYOUT_LE_LONG, offset); subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); } - for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = in.readInt(); + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES, offset += Integer.BYTES) { + final int value = memorySegment.get(LAYOUT_LE_INT, offset); subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); subRet2 
+= Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); } - for (; i < length; i++) { - int dValue = in.readByte() & 0xFF; + for (; i < length; i++, offset++) { + final int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, offset) & 0xFF; subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); @@ -476,24 +491,25 @@ private float nativeApplyCorrectionsBulk( MemorySegment scoresSegment, int bulkSize ) throws IOException { - long offset = in.getFilePointer(); - - final float maxScore = ScoreCorrections.nativeApplyCorrectionsBulk( - similarityFunction, - memorySegment.asSlice(offset), - bulkSize, - dimensions, - queryLowerInterval, - queryUpperInterval, - queryComponentSum, - queryAdditionalCorrection, - FOUR_BIT_SCALE, - ONE_BIT_SCALE, - centroidDp, - scoresSegment + return IndexInputUtils.withSlice( + in, + 16L * bulkSize, + this::getScratch, + seg -> ScoreCorrections.nativeApplyCorrectionsBulk( + similarityFunction, + seg, + bulkSize, + dimensions, + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + FOUR_BIT_SCALE, + ONE_BIT_SCALE, + centroidDp, + scoresSegment + ) ); - in.seek(offset + 16L * bulkSize); - return maxScore; } private float applyCorrections128Bulk( @@ -506,31 +522,59 @@ private float applyCorrections128Bulk( float[] scores, int bulkSize ) throws IOException { + return IndexInputUtils.withSlice( + in, + 16L * bulkSize, + this::getScratch, + seg -> applyCorrections128BulkImpl( + seg, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + bulkSize, + queryLowerInterval, + queryUpperInterval, + queryComponentSum + ) + ); + } + + private float applyCorrections128BulkImpl( + MemorySegment memorySegment, + float queryAdditionalCorrection, + VectorSimilarityFunction 
similarityFunction, + float centroidDp, + float[] scores, + int bulkSize, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum + ) { int limit = FLOAT_SPECIES_128.loopBound(bulkSize); int i = 0; - long offset = in.getFilePointer(); float ay = queryLowerInterval; float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; float y1 = queryComponentSum; float maxScore = Float.NEGATIVE_INFINITY; for (; i < limit; i += FLOAT_SPECIES_128.length()) { - var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); + var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); var lx = FloatVector.fromMemorySegment( FLOAT_SPECIES_128, memorySegment, - offset + 4L * bulkSize + i * Float.BYTES, + 4L * bulkSize + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN ).sub(ax); var targetComponentSums = IntVector.fromMemorySegment( INT_SPECIES_128, memorySegment, - offset + 8L * bulkSize + i * Integer.BYTES, + 8L * bulkSize + i * Integer.BYTES, ByteOrder.LITTLE_ENDIAN ).convert(VectorOperators.I2F, 0); var additionalCorrections = FloatVector.fromMemorySegment( FLOAT_SPECIES_128, memorySegment, - offset + 12L * bulkSize + i * Float.BYTES, + 12L * bulkSize + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN ); var qcDist = FloatVector.fromArray(FLOAT_SPECIES_128, scores, i); @@ -568,6 +612,7 @@ private float applyCorrections128Bulk( } if (limit < bulkSize) { maxScore = applyCorrectionsIndividually( + memorySegment, queryAdditionalCorrection, similarityFunction, centroidDp, @@ -575,14 +620,13 @@ private float applyCorrections128Bulk( scores, bulkSize, limit, - offset, + i, ay, ly, y1, maxScore ); } - in.seek(offset + 16L * bulkSize); return maxScore; } @@ -596,31 +640,59 @@ private float applyCorrections256Bulk( float[] scores, int bulkSize ) throws IOException { + return IndexInputUtils.withSlice( + in, + 16L * bulkSize, + this::getScratch, + memorySegment -> 
applyCorrections256BulkImpl( + memorySegment, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + bulkSize, + queryLowerInterval, + queryUpperInterval, + queryComponentSum + ) + ); + } + + private float applyCorrections256BulkImpl( + MemorySegment memorySegment, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores, + int bulkSize, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum + ) { int limit = FLOAT_SPECIES_256.loopBound(bulkSize); int i = 0; - long offset = in.getFilePointer(); float ay = queryLowerInterval; float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; float y1 = queryComponentSum; float maxScore = Float.NEGATIVE_INFINITY; for (; i < limit; i += FLOAT_SPECIES_256.length()) { - var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); + var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); var lx = FloatVector.fromMemorySegment( FLOAT_SPECIES_256, memorySegment, - offset + 4L * bulkSize + i * Float.BYTES, + 4L * bulkSize + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN ).sub(ax); var targetComponentSums = IntVector.fromMemorySegment( INT_SPECIES_256, memorySegment, - offset + 8L * bulkSize + i * Integer.BYTES, + 8L * bulkSize + i * Integer.BYTES, ByteOrder.LITTLE_ENDIAN ).convert(VectorOperators.I2F, 0); var additionalCorrections = FloatVector.fromMemorySegment( FLOAT_SPECIES_256, memorySegment, - offset + 12L * bulkSize + i * Float.BYTES, + 12L * bulkSize + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN ); var qcDist = FloatVector.fromArray(FLOAT_SPECIES_256, scores, i); @@ -658,6 +730,7 @@ private float applyCorrections256Bulk( } if (limit < bulkSize) { maxScore = applyCorrectionsIndividually( + memorySegment, queryAdditionalCorrection, similarityFunction, centroidDp, @@ -665,14 +738,13 @@ private float 
applyCorrections256Bulk( scores, bulkSize, limit, - offset, + i, ay, ly, y1, maxScore ); } - in.seek(offset + 16L * bulkSize); return maxScore; } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSD7Q7ESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSD7Q7ESNextOSQVectorsScorer.java index 8ab4c52b1272e..604df54c78019 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSD7Q7ESNextOSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSD7Q7ESNextOSQVectorsScorer.java @@ -20,9 +20,9 @@ final class MSD7Q7ESNextOSQVectorsScorer extends MemorySegmentESNextOSQVectorsSc private final MemorySegmentES92Int7VectorsScorer int7Scorer; - MSD7Q7ESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, int bulkSize, MemorySegment memorySegment) { - super(in, dimensions, dataLength, bulkSize, memorySegment); - this.int7Scorer = new MemorySegmentES92Int7VectorsScorer(in, dimensions, bulkSize, memorySegment); + MSD7Q7ESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, int bulkSize) { + super(in, dimensions, dataLength, bulkSize); + this.int7Scorer = new MemorySegmentES92Int7VectorsScorer(in, dimensions, bulkSize); } @Override diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSDibitToInt4ESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSDibitToInt4ESNextOSQVectorsScorer.java index 57e48f83f1661..ce7ab96c936ee 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSDibitToInt4ESNextOSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSDibitToInt4ESNextOSQVectorsScorer.java @@ -18,6 +18,7 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.BitUtil; 
import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.simdvec.internal.IndexInputUtils; import java.io.IOException; import java.lang.foreign.Arena; @@ -33,8 +34,8 @@ /** Panamized scorer for quantized vectors stored as a {@link MemorySegment}. */ final class MSDibitToInt4ESNextOSQVectorsScorer extends MemorySegmentESNextOSQVectorsScorer.MemorySegmentScorer { - MSDibitToInt4ESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, int bulkSize, MemorySegment memorySegment) { - super(in, dimensions, dataLength, bulkSize, memorySegment); + MSDibitToInt4ESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, int bulkSize) { + super(in, dimensions, dataLength, bulkSize); } @Override @@ -56,9 +57,10 @@ public long quantizeScore(byte[] q) throws IOException { } private long nativeQuantizeScore(byte[] q) throws IOException { - long offset = in.getFilePointer(); - var datasetMemorySegment = memorySegment.asSlice(offset, length); + return IndexInputUtils.withSlice(in, length, this::getScratch, segment -> nativeQuantizeScoreImpl(q, segment, length)); + } + private static long nativeQuantizeScoreImpl(byte[] q, MemorySegment datasetMemorySegment, int length) { final long qScore; if (SUPPORTS_HEAP_SEGMENTS) { var queryMemorySegment = MemorySegment.ofArray(q); @@ -70,7 +72,6 @@ private long nativeQuantizeScore(byte[] q) throws IOException { qScore = dotProductD2Q4(datasetMemorySegment, queryMemorySegment, length); } } - in.skipBytes(length); return qScore; } @@ -87,25 +88,28 @@ private long quantizeScore128DibitToInt4(byte[] q) throws IOException { } private long quantizeScore256(byte[] q) throws IOException { + int size = length / 2; + return IndexInputUtils.withSlice(in, size, this::getScratch, segment -> quantizeScore256Impl(q, segment, size)); + } + + private static long quantizeScore256Impl(byte[] q, MemorySegment memorySegment, int size) { long subRet0 = 0; long subRet1 = 0; long subRet2 = 0; long subRet3 = 0; int i = 0; - long offset 
= in.getFilePointer(); - int size = length / 2; if (size >= ByteVector.SPECIES_256.vectorByteSize() * 2) { int limit = ByteVector.SPECIES_256.loopBound(size); var sum0 = LongVector.zero(LONG_SPECIES_256); var sum1 = LongVector.zero(LONG_SPECIES_256); var sum2 = LongVector.zero(LONG_SPECIES_256); var sum3 = LongVector.zero(LONG_SPECIES_256); - for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) { + for (; i < limit; i += ByteVector.SPECIES_256.length()) { var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + size).reinterpretAsLongs(); var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + size * 2).reinterpretAsLongs(); var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + size * 3).reinterpretAsLongs(); - var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, i, ByteOrder.LITTLE_ENDIAN); sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); @@ -123,12 +127,12 @@ private long quantizeScore256(byte[] q) throws IOException { var sum2 = LongVector.zero(LONG_SPECIES_128); var sum3 = LongVector.zero(LONG_SPECIES_128); int limit = ByteVector.SPECIES_128.loopBound(size); - for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) { + for (; i < limit; i += ByteVector.SPECIES_128.length()) { var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size).reinterpretAsLongs(); var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size * 2).reinterpretAsLongs(); var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size * 3).reinterpretAsLongs(); - var vd = 
LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, i, ByteOrder.LITTLE_ENDIAN); sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); @@ -140,23 +144,22 @@ private long quantizeScore256(byte[] q) throws IOException { subRet3 += sum3.reduceLanes(VectorOperators.ADD); } // process scalar tail - in.seek(offset); for (final int upperBound = size & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = in.readLong(); + final long value = memorySegment.get(LAYOUT_LE_LONG, i); subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + size) & value); subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * size) & value); subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * size) & value); } for (final int upperBound = size & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = in.readInt(); + final int value = memorySegment.get(LAYOUT_LE_INT, i); subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + size) & value); subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * size) & value); subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * size) & value); } for (; i < size; i++) { - int dValue = in.readByte() & 0xFF; + int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, i) & 0xFF; subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); subRet1 += Integer.bitCount((q[i + size] & dValue) & 0xFF); subRet2 += Integer.bitCount((q[i + 2 * size] & dValue) & 0xFF); @@ -166,21 +169,24 @@ private long quantizeScore256(byte[] q) throws IOException { } private long quantizeScore128(byte[] q) 
throws IOException { + int size = length / 2; + return IndexInputUtils.withSlice(in, size, this::getScratch, segment -> quantizeScore128Impl(q, segment, size)); + } + + private static long quantizeScore128Impl(byte[] q, MemorySegment memorySegment, int size) { long subRet0 = 0; long subRet1 = 0; long subRet2 = 0; long subRet3 = 0; int i = 0; - long offset = in.getFilePointer(); var sum0 = IntVector.zero(INT_SPECIES_128); var sum1 = IntVector.zero(INT_SPECIES_128); var sum2 = IntVector.zero(INT_SPECIES_128); var sum3 = IntVector.zero(INT_SPECIES_128); - int size = length / 2; int limit = ByteVector.SPECIES_128.loopBound(size); - for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) { - var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + for (; i < limit; i += ByteVector.SPECIES_128.length()) { + var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, i, ByteOrder.LITTLE_ENDIAN); var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size).reinterpretAsInts(); var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size * 2).reinterpretAsInts(); @@ -195,23 +201,22 @@ private long quantizeScore128(byte[] q) throws IOException { subRet2 += sum2.reduceLanes(VectorOperators.ADD); subRet3 += sum3.reduceLanes(VectorOperators.ADD); // process scalar tail - in.seek(offset); for (final int upperBound = size & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = in.readLong(); + final long value = memorySegment.get(LAYOUT_LE_LONG, i); subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + size) & value); subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * size) & value); subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * size) & value); } for (final int upperBound = size 
& -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = in.readInt(); + final int value = memorySegment.get(LAYOUT_LE_INT, i); subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + size) & value); subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * size) & value); subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * size) & value); } for (; i < size; i++) { - int dValue = in.readByte() & 0xFF; + int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, i) & 0xFF; subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); subRet1 += Integer.bitCount((q[i + size] & dValue) & 0xFF); subRet2 += Integer.bitCount((q[i + 2 * size] & dValue) & 0xFF); @@ -254,13 +259,11 @@ public boolean quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOE } private void nativeQuantizeScoreBulk(MemorySegment queryMemorySegment, int count, MemorySegment scoresSegment) throws IOException { - long initialOffset = in.getFilePointer(); var datasetLengthInBytes = (long) length * count; - MemorySegment datasetSegment = memorySegment.asSlice(initialOffset, datasetLengthInBytes); - - dotProductD2Q4Bulk(datasetSegment, queryMemorySegment, length, count, scoresSegment); - - in.skipBytes(datasetLengthInBytes); + IndexInputUtils.withSlice(in, datasetLengthInBytes, this::getScratch, datasetSegment -> { + dotProductD2Q4Bulk(datasetSegment, queryMemorySegment, length, count, scoresSegment); + return null; + }); } private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IOException { @@ -366,24 +369,25 @@ private float nativeApplyCorrectionsBulk( MemorySegment scoresSegment, int bulkSize ) throws IOException { - long offset = in.getFilePointer(); - - final float maxScore = ScoreCorrections.nativeApplyCorrectionsBulk( - similarityFunction, - memorySegment.asSlice(offset), - bulkSize, - dimensions, - queryLowerInterval, - queryUpperInterval, - 
queryComponentSum, - queryAdditionalCorrection, - FOUR_BIT_SCALE, - TWO_BIT_SCALE, - centroidDp, - scoresSegment + return IndexInputUtils.withSlice( + in, + 16L * bulkSize, + this::getScratch, + seg -> ScoreCorrections.nativeApplyCorrectionsBulk( + similarityFunction, + seg, + bulkSize, + dimensions, + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + FOUR_BIT_SCALE, + TWO_BIT_SCALE, + centroidDp, + scoresSegment + ) ); - in.seek(offset + 16L * bulkSize); - return maxScore; } private float applyCorrections128Bulk( @@ -396,31 +400,59 @@ private float applyCorrections128Bulk( float[] scores, int bulkSize ) throws IOException { + return IndexInputUtils.withSlice( + in, + 16L * bulkSize, + this::getScratch, + memorySegment -> applyCorrections128BulkImpl( + memorySegment, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + bulkSize, + queryLowerInterval, + queryUpperInterval, + queryComponentSum + ) + ); + } + + private float applyCorrections128BulkImpl( + MemorySegment memorySegment, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores, + int bulkSize, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum + ) { int limit = FLOAT_SPECIES_128.loopBound(bulkSize); int i = 0; - long offset = in.getFilePointer(); float ay = queryLowerInterval; float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; float y1 = queryComponentSum; float maxScore = Float.NEGATIVE_INFINITY; for (; i < limit; i += FLOAT_SPECIES_128.length()) { - var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); + var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); var lx = FloatVector.fromMemorySegment( FLOAT_SPECIES_128, memorySegment, - offset + 4L * bulkSize + i * Float.BYTES, + 4L * bulkSize + i * Float.BYTES, 
ByteOrder.LITTLE_ENDIAN ).sub(ax).mul(TWO_BIT_SCALE); var targetComponentSums = IntVector.fromMemorySegment( INT_SPECIES_128, memorySegment, - offset + 8L * bulkSize + i * Integer.BYTES, + 8L * bulkSize + i * Integer.BYTES, ByteOrder.LITTLE_ENDIAN ).convert(VectorOperators.I2F, 0); var additionalCorrections = FloatVector.fromMemorySegment( FLOAT_SPECIES_128, memorySegment, - offset + 12L * bulkSize + i * Float.BYTES, + 12L * bulkSize + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN ); var qcDist = FloatVector.fromArray(FLOAT_SPECIES_128, scores, i); @@ -457,8 +489,8 @@ private float applyCorrections128Bulk( } } if (limit < bulkSize) { - // missing vectors to score maxScore = applyCorrectionsIndividually( + memorySegment, queryAdditionalCorrection, similarityFunction, centroidDp, @@ -466,14 +498,13 @@ private float applyCorrections128Bulk( scores, bulkSize, limit, - offset, + 0, ay, ly, y1, maxScore ); } - in.seek(offset + 16L * bulkSize); return maxScore; } @@ -487,31 +518,59 @@ private float applyCorrections256Bulk( float[] scores, int bulkSize ) throws IOException { + return IndexInputUtils.withSlice( + in, + 16L * bulkSize, + this::getScratch, + memorySegment -> applyCorrections256BulkImpl( + memorySegment, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + bulkSize, + queryLowerInterval, + queryUpperInterval, + queryComponentSum + ) + ); + } + + private float applyCorrections256BulkImpl( + MemorySegment memorySegment, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores, + int bulkSize, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum + ) { int limit = FLOAT_SPECIES_256.loopBound(bulkSize); int i = 0; - long offset = in.getFilePointer(); float ay = queryLowerInterval; float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; float y1 = queryComponentSum; float maxScore = Float.NEGATIVE_INFINITY; for (; i < limit; i += FLOAT_SPECIES_256.length()) { - 
var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); + var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); var lx = FloatVector.fromMemorySegment( FLOAT_SPECIES_256, memorySegment, - offset + 4L * bulkSize + i * Float.BYTES, + 4L * bulkSize + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN ).sub(ax).mul(TWO_BIT_SCALE); var targetComponentSums = IntVector.fromMemorySegment( INT_SPECIES_256, memorySegment, - offset + 8L * bulkSize + i * Integer.BYTES, + 8L * bulkSize + i * Integer.BYTES, ByteOrder.LITTLE_ENDIAN ).convert(VectorOperators.I2F, 0); var additionalCorrections = FloatVector.fromMemorySegment( FLOAT_SPECIES_256, memorySegment, - offset + 12L * bulkSize + i * Float.BYTES, + 12L * bulkSize + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN ); var qcDist = FloatVector.fromArray(FLOAT_SPECIES_256, scores, i); @@ -550,6 +609,7 @@ private float applyCorrections256Bulk( if (limit < bulkSize) { // missing vectors to score maxScore = applyCorrectionsIndividually( + memorySegment, queryAdditionalCorrection, similarityFunction, centroidDp, @@ -557,14 +617,13 @@ private float applyCorrections256Bulk( scores, bulkSize, limit, - offset, + 0, ay, ly, y1, maxScore ); } - in.seek(offset + 16L * bulkSize); return maxScore; } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSInt4SymmetricESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSInt4SymmetricESNextOSQVectorsScorer.java index 948596a9c9a16..6b0fa715263f1 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSInt4SymmetricESNextOSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSInt4SymmetricESNextOSQVectorsScorer.java @@ -18,6 +18,7 @@ import org.apache.lucene.store.IndexInput; import 
org.apache.lucene.util.BitUtil; import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.simdvec.internal.IndexInputUtils; import java.io.IOException; import java.lang.foreign.Arena; @@ -33,8 +34,8 @@ /** Panamized scorer for quantized vectors stored as a {@link MemorySegment}. */ final class MSInt4SymmetricESNextOSQVectorsScorer extends MemorySegmentESNextOSQVectorsScorer.MemorySegmentScorer { - MSInt4SymmetricESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, int bulkSize, MemorySegment memorySegment) { - super(in, dimensions, dataLength, bulkSize, memorySegment); + MSInt4SymmetricESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, int bulkSize) { + super(in, dimensions, dataLength, bulkSize); } @Override @@ -56,9 +57,10 @@ public long quantizeScore(byte[] q) throws IOException { } private long nativeQuantizeScore(byte[] q) throws IOException { - long offset = in.getFilePointer(); - var datasetMemorySegment = memorySegment.asSlice(offset, length); + return IndexInputUtils.withSlice(in, length, this::getScratch, segment -> nativeQuantizeScoreImpl(q, segment, length)); + } + private static long nativeQuantizeScoreImpl(byte[] q, MemorySegment datasetMemorySegment, int length) { final long qScore; if (SUPPORTS_HEAP_SEGMENTS) { var queryMemorySegment = MemorySegment.ofArray(q); @@ -70,7 +72,6 @@ private long nativeQuantizeScore(byte[] q) throws IOException { qScore = dotProductD4Q4(datasetMemorySegment, queryMemorySegment, length); } } - in.skipBytes(length); return qScore; } @@ -91,25 +92,28 @@ private long quantizeScoreSymmetric256(byte[] q) throws IOException { } private long quantizeScore256(byte[] q) throws IOException { + int size = length / 4; + return IndexInputUtils.withSlice(in, size, this::getScratch, segment -> quantizeScore256Impl(q, segment, size)); + } + + private static long quantizeScore256Impl(byte[] q, MemorySegment memorySegment, int size) { long subRet0 = 0; long subRet1 = 0; long subRet2 = 0; long 
subRet3 = 0; int i = 0; - long offset = in.getFilePointer(); - int size = length / 4; if (size >= ByteVector.SPECIES_256.vectorByteSize() * 2) { int limit = ByteVector.SPECIES_256.loopBound(size); var sum0 = LongVector.zero(LONG_SPECIES_256); var sum1 = LongVector.zero(LONG_SPECIES_256); var sum2 = LongVector.zero(LONG_SPECIES_256); var sum3 = LongVector.zero(LONG_SPECIES_256); - for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) { + for (; i < limit; i += ByteVector.SPECIES_256.length()) { var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + size).reinterpretAsLongs(); var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + size * 2).reinterpretAsLongs(); var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + size * 3).reinterpretAsLongs(); - var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, i, ByteOrder.LITTLE_ENDIAN); sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); @@ -127,12 +131,12 @@ private long quantizeScore256(byte[] q) throws IOException { var sum2 = LongVector.zero(LONG_SPECIES_128); var sum3 = LongVector.zero(LONG_SPECIES_128); int limit = ByteVector.SPECIES_128.loopBound(size); - for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) { + for (; i < limit; i += ByteVector.SPECIES_128.length()) { var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size).reinterpretAsLongs(); var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size * 2).reinterpretAsLongs(); var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size * 
3).reinterpretAsLongs(); - var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, i, ByteOrder.LITTLE_ENDIAN); sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); @@ -144,23 +148,22 @@ private long quantizeScore256(byte[] q) throws IOException { subRet3 += sum3.reduceLanes(VectorOperators.ADD); } // process scalar tail - in.seek(offset); for (final int upperBound = size & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = in.readLong(); + final long value = memorySegment.get(LAYOUT_LE_LONG, i); subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + size) & value); subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * size) & value); subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * size) & value); } for (final int upperBound = size & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = in.readInt(); + final int value = memorySegment.get(LAYOUT_LE_INT, i); subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + size) & value); subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * size) & value); subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * size) & value); } for (; i < size; i++) { - int dValue = in.readByte() & 0xFF; + int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, i) & 0xFF; subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); subRet1 += Integer.bitCount((q[i + size] & dValue) & 0xFF); subRet2 += Integer.bitCount((q[i + 2 * size] & dValue) & 0xFF); @@ -170,21 +173,24 @@ private long quantizeScore256(byte[] q) throws IOException { } private 
long quantizeScore128(byte[] q) throws IOException { + int size = length / 4; + return IndexInputUtils.withSlice(in, size, this::getScratch, segment -> quantizeScore128Impl(q, segment, size)); + } + + private static long quantizeScore128Impl(byte[] q, MemorySegment memorySegment, int size) { long subRet0 = 0; long subRet1 = 0; long subRet2 = 0; long subRet3 = 0; int i = 0; - long offset = in.getFilePointer(); var sum0 = IntVector.zero(INT_SPECIES_128); var sum1 = IntVector.zero(INT_SPECIES_128); var sum2 = IntVector.zero(INT_SPECIES_128); var sum3 = IntVector.zero(INT_SPECIES_128); - int size = length / 4; int limit = ByteVector.SPECIES_128.loopBound(size); - for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) { - var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + for (; i < limit; i += ByteVector.SPECIES_128.length()) { + var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, i, ByteOrder.LITTLE_ENDIAN); var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size).reinterpretAsInts(); var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size * 2).reinterpretAsInts(); @@ -199,23 +205,22 @@ private long quantizeScore128(byte[] q) throws IOException { subRet2 += sum2.reduceLanes(VectorOperators.ADD); subRet3 += sum3.reduceLanes(VectorOperators.ADD); // process scalar tail - in.seek(offset); for (final int upperBound = size & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = in.readLong(); + final long value = memorySegment.get(LAYOUT_LE_LONG, i); subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + size) & value); subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * size) & value); subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * size) & value); } 
for (final int upperBound = size & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = in.readInt(); + final int value = memorySegment.get(LAYOUT_LE_INT, i); subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + size) & value); subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * size) & value); subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * size) & value); } for (; i < size; i++) { - int dValue = in.readByte() & 0xFF; + int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, i) & 0xFF; subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); subRet1 += Integer.bitCount((q[i + size] & dValue) & 0xFF); subRet2 += Integer.bitCount((q[i + 2 * size] & dValue) & 0xFF); @@ -259,13 +264,11 @@ public boolean quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOE } private void nativeQuantizeScoreBulk(MemorySegment queryMemorySegment, int count, MemorySegment scoresSegment) throws IOException { - long initialOffset = in.getFilePointer(); var datasetLengthInBytes = (long) length * count; - MemorySegment datasetSegment = memorySegment.asSlice(initialOffset, datasetLengthInBytes); - - dotProductD4Q4Bulk(datasetSegment, queryMemorySegment, length, count, scoresSegment); - - in.skipBytes(datasetLengthInBytes); + IndexInputUtils.withSlice(in, datasetLengthInBytes, this::getScratch, datasetSegment -> { + dotProductD4Q4Bulk(datasetSegment, queryMemorySegment, length, count, scoresSegment); + return null; + }); } private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IOException { @@ -371,24 +374,25 @@ private float nativeApplyCorrectionsBulk( MemorySegment scoresSegment, int bulkSize ) throws IOException { - long offset = in.getFilePointer(); - - final float maxScore = ScoreCorrections.nativeApplyCorrectionsBulk( - similarityFunction, - memorySegment.asSlice(offset), - bulkSize, - dimensions, - queryLowerInterval, - 
queryUpperInterval, - queryComponentSum, - queryAdditionalCorrection, - FOUR_BIT_SCALE, - FOUR_BIT_SCALE, - centroidDp, - scoresSegment + return IndexInputUtils.withSlice( + in, + 16L * bulkSize, + this::getScratch, + memorySegment -> ScoreCorrections.nativeApplyCorrectionsBulk( + similarityFunction, + memorySegment, + bulkSize, + dimensions, + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + FOUR_BIT_SCALE, + FOUR_BIT_SCALE, + centroidDp, + scoresSegment + ) ); - in.seek(offset + 16L * bulkSize); - return maxScore; } private float applyCorrections128Bulk( @@ -401,36 +405,62 @@ private float applyCorrections128Bulk( float[] scores, int bulkSize ) throws IOException { + return IndexInputUtils.withSlice( + in, + 16L * bulkSize, + this::getScratch, + memorySegment -> applyCorrections128BulkImpl( + memorySegment, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + bulkSize, + queryLowerInterval, + queryUpperInterval, + queryComponentSum + ) + ); + } + + private float applyCorrections128BulkImpl( + MemorySegment memorySegment, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores, + int bulkSize, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum + ) { int limit = FLOAT_SPECIES_128.loopBound(bulkSize); int i = 0; - long offset = in.getFilePointer(); float ay = queryLowerInterval; float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; float y1 = queryComponentSum; float maxScore = Float.NEGATIVE_INFINITY; for (; i < limit; i += FLOAT_SPECIES_128.length()) { - var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); + var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); var lx = FloatVector.fromMemorySegment( FLOAT_SPECIES_128, memorySegment, - offset + 4L * bulkSize + i * 
Float.BYTES, + 4L * bulkSize + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN ).sub(ax).mul(FOUR_BIT_SCALE); var targetComponentSums = IntVector.fromMemorySegment( INT_SPECIES_128, memorySegment, - offset + 8L * bulkSize + i * Integer.BYTES, + 8L * bulkSize + i * Integer.BYTES, ByteOrder.LITTLE_ENDIAN ).convert(VectorOperators.I2F, 0); var additionalCorrections = FloatVector.fromMemorySegment( FLOAT_SPECIES_128, memorySegment, - offset + 12L * bulkSize + i * Float.BYTES, + 12L * bulkSize + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN ); var qcDist = FloatVector.fromArray(FLOAT_SPECIES_128, scores, i); - // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * - // qcDist; var res1 = ax.mul(ay).mul(dimensions); var res2 = lx.mul(ay).mul(targetComponentSums); var res3 = ax.mul(ly).mul(y1); @@ -464,6 +494,7 @@ private float applyCorrections128Bulk( if (limit < bulkSize) { // missing vectors to score maxScore = applyCorrectionsIndividually( + memorySegment, queryAdditionalCorrection, similarityFunction, centroidDp, @@ -471,14 +502,13 @@ private float applyCorrections128Bulk( scores, bulkSize, limit, - offset, + 0, ay, ly, y1, maxScore ); } - in.seek(offset + 16L * bulkSize); return maxScore; } @@ -492,31 +522,59 @@ private float applyCorrections256Bulk( float[] scores, int bulkSize ) throws IOException { + return IndexInputUtils.withSlice( + in, + 16L * bulkSize, + this::getScratch, + memorySegment -> applyCorrections256BulkImpl( + memorySegment, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores, + bulkSize, + queryLowerInterval, + queryUpperInterval, + queryComponentSum + ) + ); + } + + private float applyCorrections256BulkImpl( + MemorySegment memorySegment, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores, + int bulkSize, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum + ) { int limit = 
FLOAT_SPECIES_256.loopBound(bulkSize); int i = 0; - long offset = in.getFilePointer(); float ay = queryLowerInterval; float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; float y1 = queryComponentSum; float maxScore = Float.NEGATIVE_INFINITY; for (; i < limit; i += FLOAT_SPECIES_256.length()) { - var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); + var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); var lx = FloatVector.fromMemorySegment( FLOAT_SPECIES_256, memorySegment, - offset + 4L * bulkSize + i * Float.BYTES, + 4L * bulkSize + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN ).sub(ax).mul(FOUR_BIT_SCALE); var targetComponentSums = IntVector.fromMemorySegment( INT_SPECIES_256, memorySegment, - offset + 8L * bulkSize + i * Integer.BYTES, + 8L * bulkSize + i * Integer.BYTES, ByteOrder.LITTLE_ENDIAN ).convert(VectorOperators.I2F, 0); var additionalCorrections = FloatVector.fromMemorySegment( FLOAT_SPECIES_256, memorySegment, - offset + 12L * bulkSize + i * Float.BYTES, + 12L * bulkSize + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN ); var qcDist = FloatVector.fromArray(FLOAT_SPECIES_256, scores, i); @@ -555,6 +613,7 @@ private float applyCorrections256Bulk( if (limit < bulkSize) { // missing vectors to score maxScore = applyCorrectionsIndividually( + memorySegment, queryAdditionalCorrection, similarityFunction, centroidDp, @@ -562,14 +621,13 @@ private float applyCorrections256Bulk( scores, bulkSize, limit, - offset, + 0, ay, ly, y1, maxScore ); } - in.seek(offset + 16L * bulkSize); return maxScore; } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java index 6d923ff40f7f9..cc0f4f52a56d7 100644 --- 
a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java @@ -21,9 +21,11 @@ import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.VectorUtil; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; +import org.elasticsearch.simdvec.internal.IndexInputUtils; import java.io.IOException; import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; import java.nio.ByteOrder; import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; @@ -46,11 +48,20 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore private static final VectorSpecies FLOAT_SPECIES_128 = FloatVector.SPECIES_128; private static final VectorSpecies FLOAT_SPECIES_256 = FloatVector.SPECIES_256; - private final MemorySegment memorySegment; + static final ValueLayout.OfLong LAYOUT_LE_LONG = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + static final ValueLayout.OfInt LAYOUT_LE_INT = ValueLayout.JAVA_INT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); - public MemorySegmentES91OSQVectorsScorer(IndexInput in, int dimensions, int bulkSize, MemorySegment memorySegment) { + private byte[] scratch; + + public MemorySegmentES91OSQVectorsScorer(IndexInput in, int dimensions, int bulkSize) { super(in, dimensions, bulkSize); - this.memorySegment = memorySegment; + } + + private byte[] getScratch(int len) { + if (scratch == null || scratch.length < len) { + scratch = new byte[len]; + } + return scratch; } @Override @@ -68,24 +79,27 @@ public long quantizeScore(byte[] q) throws IOException { } private long quantizeScore256(byte[] q) throws IOException { + return IndexInputUtils.withSlice(in, length, this::getScratch, segment -> quantizeScore256Impl(q, segment, length)); + } + + private static long quantizeScore256Impl(byte[] q, 
MemorySegment memorySegment, int length) { long subRet0 = 0; long subRet1 = 0; long subRet2 = 0; long subRet3 = 0; int i = 0; - long offset = in.getFilePointer(); if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { int limit = ByteVector.SPECIES_256.loopBound(length); var sum0 = LongVector.zero(LONG_SPECIES_256); var sum1 = LongVector.zero(LONG_SPECIES_256); var sum2 = LongVector.zero(LONG_SPECIES_256); var sum3 = LongVector.zero(LONG_SPECIES_256); - for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) { + for (; i < limit; i += ByteVector.SPECIES_256.length()) { var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs(); var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs(); var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs(); - var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, i, ByteOrder.LITTLE_ENDIAN); sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); @@ -103,12 +117,12 @@ private long quantizeScore256(byte[] q) throws IOException { var sum2 = LongVector.zero(LONG_SPECIES_128); var sum3 = LongVector.zero(LONG_SPECIES_128); int limit = ByteVector.SPECIES_128.loopBound(length); - for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) { + for (; i < limit; i += ByteVector.SPECIES_128.length()) { var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs(); var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 
2).reinterpretAsLongs(); var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs(); - var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, i, ByteOrder.LITTLE_ENDIAN); sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); @@ -120,23 +134,22 @@ private long quantizeScore256(byte[] q) throws IOException { subRet3 += sum3.reduceLanes(VectorOperators.ADD); } // process scalar tail - in.seek(offset); for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = in.readLong(); + final long value = memorySegment.get(LAYOUT_LE_LONG, i); subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); } for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = in.readInt(); + final int value = memorySegment.get(LAYOUT_LE_INT, i); subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); } for (; i < length; i++) { - int dValue = in.readByte() & 0xFF; + int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, i) & 0xFF; subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); subRet2 += Integer.bitCount((q[i + 2 * length] 
& dValue) & 0xFF); @@ -146,20 +159,23 @@ private long quantizeScore256(byte[] q) throws IOException { } private long quantizeScore128(byte[] q) throws IOException { + return IndexInputUtils.withSlice(in, length, this::getScratch, segment -> quantizeScore128Impl(q, segment, length)); + } + + private static long quantizeScore128Impl(byte[] q, MemorySegment memorySegment, int length) { long subRet0 = 0; long subRet1 = 0; long subRet2 = 0; long subRet3 = 0; int i = 0; - long offset = in.getFilePointer(); var sum0 = IntVector.zero(INT_SPECIES_128); var sum1 = IntVector.zero(INT_SPECIES_128); var sum2 = IntVector.zero(INT_SPECIES_128); var sum3 = IntVector.zero(INT_SPECIES_128); int limit = ByteVector.SPECIES_128.loopBound(length); - for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) { - var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + for (; i < limit; i += ByteVector.SPECIES_128.length()) { + var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, i, ByteOrder.LITTLE_ENDIAN); var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts(); var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts(); @@ -174,23 +190,22 @@ private long quantizeScore128(byte[] q) throws IOException { subRet2 += sum2.reduceLanes(VectorOperators.ADD); subRet3 += sum3.reduceLanes(VectorOperators.ADD); // process scalar tail - in.seek(offset); for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = in.readLong(); + final long value = memorySegment.get(LAYOUT_LE_LONG, i); subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); 
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); } for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = in.readInt(); + final int value = memorySegment.get(LAYOUT_LE_INT, i); subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); } for (; i < length; i++) { - int dValue = in.readByte() & 0xFF; + int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, i) & 0xFF; subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); @@ -216,13 +231,21 @@ public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOExce } private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IOException { + var datasetLengthInBytes = (long) length * count; + IndexInputUtils.withSlice(in, datasetLengthInBytes, this::getScratch, segment -> { + quantizeScore128BulkImpl(q, count, scores, segment, length); + return null; + }); + } + + private static void quantizeScore128BulkImpl(byte[] q, int count, float[] scores, MemorySegment memorySegment, int length) { + long offset = 0L; for (int iter = 0; iter < count; iter++) { long subRet0 = 0; long subRet1 = 0; long subRet2 = 0; long subRet3 = 0; int i = 0; - long offset = in.getFilePointer(); var sum0 = IntVector.zero(INT_SPECIES_128); var sum1 = IntVector.zero(INT_SPECIES_128); @@ -245,23 +268,22 @@ private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IO subRet2 += sum2.reduceLanes(VectorOperators.ADD); subRet3 += sum3.reduceLanes(VectorOperators.ADD); // process scalar tail - in.seek(offset); - for (final int upperBound 
= length & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = in.readLong(); + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES, offset += Long.BYTES) { + final long value = memorySegment.get(LAYOUT_LE_LONG, offset); subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); } - for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = in.readInt(); + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES, offset += Integer.BYTES) { + final int value = memorySegment.get(LAYOUT_LE_INT, offset); subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); } - for (; i < length; i++) { - int dValue = in.readByte() & 0xFF; + for (; i < length; i++, offset++) { + int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, offset) & 0xFF; subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); @@ -272,13 +294,21 @@ private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IO } private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IOException { + var datasetLengthInBytes = (long) length * count; + IndexInputUtils.withSlice(in, datasetLengthInBytes, this::getScratch, segment -> { + quantizeScore256BulkImpl(q, count, scores, segment, length); + return null; + }); + } 
+ + private static void quantizeScore256BulkImpl(byte[] q, int count, float[] scores, MemorySegment memorySegment, int length) { + long offset = 0L; for (int iter = 0; iter < count; iter++) { long subRet0 = 0; long subRet1 = 0; long subRet2 = 0; long subRet3 = 0; int i = 0; - long offset = in.getFilePointer(); if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { int limit = ByteVector.SPECIES_256.loopBound(length); var sum0 = LongVector.zero(LONG_SPECIES_256); @@ -325,23 +355,22 @@ private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IO subRet3 += sum3.reduceLanes(VectorOperators.ADD); } // process scalar tail - in.seek(offset); - for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { - final long value = in.readLong(); + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES, offset += Long.BYTES) { + final long value = memorySegment.get(LAYOUT_LE_LONG, offset); subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); } - for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - final int value = in.readInt(); + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES, offset += Integer.BYTES) { + final int value = memorySegment.get(LAYOUT_LE_INT, offset); subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); } - for (; i < length; i++) { - int dValue = in.readByte() & 0xFF; + for (; i < 
length; i++, offset++) { + int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, offset) & 0xFF; subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); @@ -366,26 +395,38 @@ public float scoreBulk( // 128 / 8 == 16 if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { - return score256Bulk( - q, - queryLowerInterval, - queryUpperInterval, - queryComponentSum, - queryAdditionalCorrection, - similarityFunction, - centroidDp, - scores + return IndexInputUtils.withSlice( + in, + (14L + length) * this.bulkSize, + this::getScratch, + segment -> score256Bulk( + segment, + q, + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores + ) ); } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { - return score128Bulk( - q, - queryLowerInterval, - queryUpperInterval, - queryComponentSum, - queryAdditionalCorrection, - similarityFunction, - centroidDp, - scores + return IndexInputUtils.withSlice( + in, + (14L + length) * this.bulkSize, + this::getScratch, + segment -> score128Bulk( + segment, + q, + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores + ) ); } } @@ -402,6 +443,7 @@ public float scoreBulk( } private float score128Bulk( + MemorySegment memorySegment, byte[] q, float queryLowerInterval, float queryUpperInterval, @@ -410,11 +452,11 @@ private float score128Bulk( VectorSimilarityFunction similarityFunction, float centroidDp, float[] scores - ) throws IOException { - quantizeScore128Bulk(q, this.bulkSize, scores); + ) { + quantizeScore128BulkImpl(q, this.bulkSize, scores, memorySegment, length); int limit = FLOAT_SPECIES_128.loopBound(this.bulkSize); int i = 0; - long offset = in.getFilePointer(); + 
long offset = (long) length * this.bulkSize; float ay = queryLowerInterval; float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; float y1 = queryComponentSum; @@ -472,11 +514,11 @@ private float score128Bulk( } } } - in.seek(offset + 14L * this.bulkSize); return maxScore; } private float score256Bulk( + MemorySegment memorySegment, byte[] q, float queryLowerInterval, float queryUpperInterval, @@ -485,11 +527,11 @@ private float score256Bulk( VectorSimilarityFunction similarityFunction, float centroidDp, float[] scores - ) throws IOException { - quantizeScore256Bulk(q, this.bulkSize, scores); + ) { + quantizeScore256BulkImpl(q, this.bulkSize, scores, memorySegment, length); int limit = FLOAT_SPECIES_256.loopBound(this.bulkSize); int i = 0; - long offset = in.getFilePointer(); + long offset = (long) length * this.bulkSize; float ay = queryLowerInterval; float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; float y1 = queryComponentSum; @@ -547,7 +589,6 @@ private float score256Bulk( } } } - in.seek(offset + 14L * this.bulkSize); return maxScore; } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java index 9fcf70ebb7c92..434c31004e5ff 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java @@ -16,9 +16,12 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MemorySegmentAccessInput; import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.core.DirectAccessInput; import org.elasticsearch.nativeaccess.NativeAccess; import org.elasticsearch.simdvec.ESNextOSQVectorsScorer; +import 
org.elasticsearch.simdvec.internal.IndexInputUtils; import java.io.IOException; import java.lang.foreign.MemorySegment; @@ -39,18 +42,17 @@ public MemorySegmentESNextOSQVectorsScorer( byte indexBits, int dimensions, int dataLength, - int bulkSize, - MemorySegment memorySegment + int bulkSize ) { super(in, queryBits, indexBits, dimensions, dataLength); if (queryBits == 4 && indexBits == 1) { - this.scorer = new MSBitToInt4ESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize, memorySegment); + this.scorer = new MSBitToInt4ESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize); } else if (queryBits == 4 && indexBits == 4) { - this.scorer = new MSInt4SymmetricESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize, memorySegment); + this.scorer = new MSInt4SymmetricESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize); } else if (queryBits == 4 && indexBits == 2) { - this.scorer = new MSDibitToInt4ESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize, memorySegment); + this.scorer = new MSDibitToInt4ESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize); } else if (queryBits == 7 && indexBits == 7) { - this.scorer = new MSD7Q7ESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize, memorySegment); + this.scorer = new MSD7Q7ESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize); } else { throw new IllegalArgumentException("Unsupported query/index bits combination: " + queryBits + "/" + indexBits); } @@ -171,20 +173,48 @@ abstract static sealed class MemorySegmentScorer permits MSBitToInt4ESNextOSQVec static final VectorSpecies FLOAT_SPECIES_128 = FloatVector.SPECIES_128; static final VectorSpecies FLOAT_SPECIES_256 = FloatVector.SPECIES_256; - protected final MemorySegment memorySegment; + static final ValueLayout.OfLong LAYOUT_LE_LONG = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + static final ValueLayout.OfInt LAYOUT_LE_INT = ValueLayout.JAVA_INT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + 
protected final IndexInput in; protected final int length; protected final int dimensions; protected final int bulkSize; - MemorySegmentScorer(IndexInput in, int dimensions, int dataLength, int bulkSize, MemorySegment segment) { + private byte[] scratch; + + /** + * Creates a new MemorySegmentScorer. The index input is validated by + * {@code IndexInputUtils.checkInputType}; inputs that fail that check + * (e.g. an un-unwrapped filter input) cause an {@link IllegalArgumentException}. + * + *

Memory segment access is handled by + * {@link org.elasticsearch.simdvec.internal.IndexInputUtils#withSlice + * IndexInputUtils.withSlice}, which probes the index input for + * {@link MemorySegmentAccessInput} / + * {@link DirectAccessInput} support and + * falls back to a heap copy when neither is available. + * + * @param in the index input + * @param dimensions the vector dimensions + * @param dataLength the length in bytes, per data vector + * @param bulkSize the number of vectors per bulk + */ + MemorySegmentScorer(IndexInput in, int dimensions, int dataLength, int bulkSize) { + IndexInputUtils.checkInputType(in); this.in = in; this.length = dataLength; this.dimensions = dimensions; - this.memorySegment = segment; this.bulkSize = bulkSize; } + protected byte[] getScratch(int len) { + if (scratch == null || scratch.length < len) { + scratch = new byte[len]; + } + return scratch; + } + abstract long quantizeScore(byte[] q) throws IOException; abstract boolean quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException; @@ -225,6 +255,7 @@ abstract float scoreBulk( ) throws IOException; protected float applyCorrectionsIndividually( + MemorySegment memorySegment, float queryAdditionalCorrection, VectorSimilarityFunction similarityFunction, float centroidDp, @@ -286,4 +317,5 @@ protected float applyCorrectionsIndividually( return maxScore; } } + } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java index a87e3626210ce..534b737c96984 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java @@ -11,7 +11,6 @@ import org.apache.lucene.store.FilterIndexInput; import 
org.apache.lucene.store.IndexInput; -import org.apache.lucene.store.MemorySegmentAccessInput; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import org.elasticsearch.simdvec.ES92Int7VectorsScorer; import org.elasticsearch.simdvec.ESNextOSQVectorsScorer; @@ -19,7 +18,6 @@ import org.elasticsearch.simdvec.internal.MemorySegmentES92Int7VectorsScorer; import java.io.IOException; -import java.lang.foreign.MemorySegment; final class PanamaESVectorizationProvider extends ESVectorizationProvider { @@ -42,43 +40,30 @@ public ESNextOSQVectorsScorer newESNextOSQVectorsScorer( int dimension, int dataLength, int bulkSize - ) throws IOException { - IndexInput unwrappedInput = FilterIndexInput.unwrapOnlyTest(input); - unwrappedInput = MemorySegmentAccessInputAccess.unwrap(unwrappedInput); + ) { if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS - && unwrappedInput instanceof MemorySegmentAccessInput msai && ((queryBits == 4 && (indexBits == 1 || indexBits == 2 || indexBits == 4)) || (queryBits == 7 && indexBits == 7))) { - MemorySegment ms = msai.segmentSliceOrNull(0, unwrappedInput.length()); - if (ms != null) { - return new MemorySegmentESNextOSQVectorsScorer(unwrappedInput, queryBits, indexBits, dimension, dataLength, bulkSize, ms); - } + IndexInput unwrappedInput = FilterIndexInput.unwrapOnlyTest(input); + unwrappedInput = MemorySegmentAccessInputAccess.unwrap(unwrappedInput); + return new MemorySegmentESNextOSQVectorsScorer(unwrappedInput, queryBits, indexBits, dimension, dataLength, bulkSize); } return new ESNextOSQVectorsScorer(input, queryBits, indexBits, dimension, dataLength, bulkSize); } @Override public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension, int bulkSize) throws IOException { - IndexInput unwrappedInput = FilterIndexInput.unwrapOnlyTest(input); - unwrappedInput = MemorySegmentAccessInputAccess.unwrap(unwrappedInput); - if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS && unwrappedInput instanceof 
MemorySegmentAccessInput msai) { - MemorySegment ms = msai.segmentSliceOrNull(0, unwrappedInput.length()); - if (ms != null) { - return new MemorySegmentES91OSQVectorsScorer(unwrappedInput, dimension, bulkSize, ms); - } + if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + IndexInput unwrappedInput = FilterIndexInput.unwrapOnlyTest(input); + unwrappedInput = MemorySegmentAccessInputAccess.unwrap(unwrappedInput); + return new MemorySegmentES91OSQVectorsScorer(unwrappedInput, dimension, bulkSize); } return new OnHeapES91OSQVectorsScorer(input, dimension, bulkSize); } @Override - public ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension, int bulkSize) throws IOException { + public ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension, int bulkSize) { IndexInput unwrappedInput = FilterIndexInput.unwrapOnlyTest(input); unwrappedInput = MemorySegmentAccessInputAccess.unwrap(unwrappedInput); - if (unwrappedInput instanceof MemorySegmentAccessInput msai) { - MemorySegment ms = msai.segmentSliceOrNull(0, unwrappedInput.length()); - if (ms != null) { - return new MemorySegmentES92Int7VectorsScorer(unwrappedInput, dimension, bulkSize, ms); - } - } - return new ES92Int7VectorsScorer(input, dimension, bulkSize); + return new MemorySegmentES92Int7VectorsScorer(unwrappedInput, dimension, bulkSize); } } diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java index aa4f5c355923e..c08afba09c2be 100644 --- a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java +++ b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java @@ -20,8 +20,8 @@ public final class MemorySegmentES92Int7VectorsScorer extends MemorySegmentES92P private static final boolean NATIVE_SUPPORTED 
= NativeAccess.instance().getVectorSimilarityFunctions().isPresent(); - public MemorySegmentES92Int7VectorsScorer(IndexInput in, int dimensions, int bulkSize, MemorySegment memorySegment) { - super(in, dimensions, bulkSize, memorySegment); + public MemorySegmentES92Int7VectorsScorer(IndexInput in, int dimensions, int bulkSize) { + super(in, dimensions, bulkSize); } @Override @@ -37,23 +37,22 @@ public long int7DotProduct(byte[] q) throws IOException { } else { return panamaInt7DotProduct(q); } - } private long nativeInt7DotProduct(byte[] q) throws IOException { - final MemorySegment segment = memorySegment.asSlice(in.getFilePointer(), dimensions); - final MemorySegment querySegment = MemorySegment.ofArray(q); - final long res = Similarities.dotProductI7u(segment, querySegment, dimensions); - in.skipBytes(dimensions); - return res; + return IndexInputUtils.withSlice(in, dimensions, this::getScratch, segment -> { + final MemorySegment querySegment = MemorySegment.ofArray(q); + return Similarities.dotProductI7u(segment, querySegment, dimensions); + }); } private void nativeInt7DotProductBulk(byte[] q, int count, float[] scores) throws IOException { - final MemorySegment scoresSegment = MemorySegment.ofArray(scores); - final MemorySegment segment = memorySegment.asSlice(in.getFilePointer(), dimensions * count); - final MemorySegment querySegment = MemorySegment.ofArray(q); - Similarities.dotProductI7uBulk(segment, querySegment, dimensions, count, scoresSegment); - in.skipBytes(dimensions * count); + IndexInputUtils.withSlice(in, (long) dimensions * count, this::getScratch, segment -> { + final MemorySegment scoresSegment = MemorySegment.ofArray(scores); + final MemorySegment querySegment = MemorySegment.ofArray(q); + Similarities.dotProductI7uBulk(segment, querySegment, dimensions, count, scoresSegment); + return null; + }); } @Override diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ESNextOSQVectorsScorerTests.java 
b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ESNextOSQVectorsScorerTests.java index eb8bd664d0196..0407f55471529 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ESNextOSQVectorsScorerTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ESNextOSQVectorsScorerTests.java @@ -109,6 +109,7 @@ public void testQuantizeScore() throws Exception { assertEquals(in.getFilePointer(), slice.getFilePointer()); } assertEquals((long) length * numVectors, slice.getFilePointer()); + assertEquals((long) length * numVectors, in.getFilePointer()); } } } diff --git a/libs/simdvec/src/test21/java/org/elasticsearch/simdvec/internal/IndexInputUtilsTests.java b/libs/simdvec/src/test21/java/org/elasticsearch/simdvec/internal/IndexInputUtilsTests.java new file mode 100644 index 0000000000000..5edfaa16a7af7 --- /dev/null +++ b/libs/simdvec/src/test21/java/org/elasticsearch/simdvec/internal/IndexInputUtilsTests.java @@ -0,0 +1,185 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ +package org.elasticsearch.simdvec.internal; + +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.store.NIOFSDirectory; +import org.elasticsearch.core.CheckedConsumer; +import org.elasticsearch.core.DirectAccessInput; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.Arrays; + +/** + * Tests that {@link IndexInputUtils#withSlice} correctly handles all + * three input types: {@link MemorySegmentAccessInput} (mmap), + * {@link DirectAccessInput} (byte-buffer), and plain {@link IndexInput} + * (heap-copy fallback). + */ +public class IndexInputUtilsTests extends ESTestCase { + + private static final String FILE_NAME = "test.bin"; + + // -- withSlice path tests ------------------------------------------------- + + public void testWithSliceMemorySegmentAccessInput() throws Exception { + byte[] data = randomByteArrayOfLength(256); + try (Directory dir = new MMapDirectory(createTempDir())) { + writeData(dir, data); + try (IndexInput in = dir.openInput(FILE_NAME, IOContext.DEFAULT)) { + assertTrue(in instanceof MemorySegmentAccessInput); + verifyWithSlice(in, data); + } + } + } + + public void testWithSliceDirectAccessInput() throws Exception { + byte[] data = randomByteArrayOfLength(256); + try (Directory dir = new NIOFSDirectory(createTempDir())) { + writeData(dir, data); + try (IndexInput rawIn = dir.openInput(FILE_NAME, IOContext.DEFAULT)) { + IndexInput in = new DirectAccessWrapper("dai", rawIn, data); + assertTrue(in instanceof DirectAccessInput); + verifyWithSlice(in, data); + } + } + } + + public void testWithSlicePlainIndexInput() throws Exception { 
+ byte[] data = randomByteArrayOfLength(256); + try (Directory dir = new NIOFSDirectory(createTempDir())) { + writeData(dir, data); + try (IndexInput in = dir.openInput(FILE_NAME, IOContext.DEFAULT)) { + assertFalse(in instanceof MemorySegmentAccessInput); + assertFalse(in instanceof DirectAccessInput); + verifyWithSlice(in, data); + } + } + } + + // -- constructor validation tests ----------------------------------------- + + public void testES92ConstructorAcceptsPlainInput() throws Exception { + byte[] data = randomByteArrayOfLength(256); + try (Directory dir = new NIOFSDirectory(createTempDir())) { + writeData(dir, data); + try (IndexInput in = dir.openInput(FILE_NAME, IOContext.DEFAULT)) { + new MemorySegmentES92Int7VectorsScorer(in, 64, 16); + } + } + } + + public void testES92ConstructorAcceptsMMapInput() throws Exception { + byte[] data = randomByteArrayOfLength(256); + try (Directory dir = new MMapDirectory(createTempDir())) { + writeData(dir, data); + try (IndexInput in = dir.openInput(FILE_NAME, IOContext.DEFAULT)) { + new MemorySegmentES92Int7VectorsScorer(in, 64, 16); + } + } + } + + public void testES92ConstructorAcceptsDirectAccessInput() throws Exception { + byte[] data = randomByteArrayOfLength(256); + try (Directory dir = new NIOFSDirectory(createTempDir())) { + writeData(dir, data); + try (IndexInput rawIn = dir.openInput(FILE_NAME, IOContext.DEFAULT)) { + IndexInput in = new DirectAccessWrapper("dai", rawIn, data); + new MemorySegmentES92Int7VectorsScorer(in, 64, 16); + } + } + } + + public void testES92ConstructorRejectsUnwrappedFilterIndexInput() throws Exception { + byte[] data = randomByteArrayOfLength(256); + try (Directory dir = new NIOFSDirectory(createTempDir())) { + writeData(dir, data); + try (IndexInput rawIn = dir.openInput(FILE_NAME, IOContext.DEFAULT)) { + IndexInput wrapped = new FilterIndexInput("plain-wrapper", rawIn) {}; + expectThrows(IllegalArgumentException.class, () -> new MemorySegmentES92Int7VectorsScorer(wrapped, 64, 
16)); + } + } + } + + // -- helpers -------------------------------------------------------------- + + private void verifyWithSlice(IndexInput in, byte[] expectedData) throws IOException { + int firstChunk = 64; + byte[] result1 = IndexInputUtils.withSlice(in, firstChunk, byte[]::new, segment -> { + byte[] buf = new byte[(int) segment.byteSize()]; + MemorySegment.ofArray(buf).copyFrom(segment); + return buf; + }); + assertArrayEquals(Arrays.copyOfRange(expectedData, 0, firstChunk), result1); + assertEquals(firstChunk, in.getFilePointer()); + + int secondChunk = 128; + byte[] result2 = IndexInputUtils.withSlice(in, secondChunk, byte[]::new, segment -> { + byte[] buf = new byte[(int) segment.byteSize()]; + MemorySegment.ofArray(buf).copyFrom(segment); + return buf; + }); + assertArrayEquals(Arrays.copyOfRange(expectedData, firstChunk, firstChunk + secondChunk), result2); + assertEquals(firstChunk + secondChunk, in.getFilePointer()); + + int remaining = expectedData.length - firstChunk - secondChunk; + byte[] result3 = IndexInputUtils.withSlice(in, remaining, byte[]::new, segment -> { + byte[] buf = new byte[(int) segment.byteSize()]; + MemorySegment.ofArray(buf).copyFrom(segment); + return buf; + }); + assertArrayEquals(Arrays.copyOfRange(expectedData, firstChunk + secondChunk, expectedData.length), result3); + assertEquals(expectedData.length, in.getFilePointer()); + } + + private static void writeData(Directory dir, byte[] data) throws IOException { + try (IndexOutput out = dir.createOutput(FILE_NAME, IOContext.DEFAULT)) { + out.writeBytes(data, 0, data.length); + } + } + + /** + * Wraps an existing IndexInput with DirectAccessInput support, + * serving byte-buffer slices from the provided data array. 
+ */ + static class DirectAccessWrapper extends FilterIndexInput implements DirectAccessInput { + private final byte[] data; + + DirectAccessWrapper(String resourceDescription, IndexInput delegate, byte[] data) { + super(resourceDescription, delegate); + this.data = data; + } + + @Override + public boolean withByteBufferSlice(long offset, long length, CheckedConsumer action) throws IOException { + ByteBuffer bb = ByteBuffer.wrap(data, (int) offset, (int) length).asReadOnlyBuffer(); + action.accept(bb); + return true; + } + + @Override + public IndexInput clone() { + return new DirectAccessWrapper("clone", in.clone(), data); + } + + @Override + public IndexInput slice(String sliceDescription, long offset, long length) throws IOException { + return new DirectAccessWrapper(sliceDescription, in.slice(sliceDescription, offset, length), data); + } + } +} diff --git a/qa/vector/build.gradle b/qa/vector/build.gradle index c078b80ed04f9..95390d3a68de0 100644 --- a/qa/vector/build.gradle +++ b/qa/vector/build.gradle @@ -80,6 +80,11 @@ dependencies { implementation project(':libs:logging') implementation project(':server') implementation project(':libs:gpu-codec') + implementation(project(':x-pack:plugin:searchable-snapshots')) { + capabilities { + requireCapability("org.elasticsearch.plugin:searchable-snapshots-test-artifacts") + } + } testImplementation project(":libs:x-content") testImplementation project(":test:framework") @@ -133,48 +138,6 @@ tasks.named("thirdPartyAudit").configure { 'com.google.appengine.api.urlfetch.URLFetchService', 'com.google.appengine.api.urlfetch.URLFetchServiceFactory', - // optional apache http client dependencies - 'org.apache.http.ConnectionReuseStrategy', - 'org.apache.http.Header', - 'org.apache.http.HttpEntity', - 'org.apache.http.HttpEntityEnclosingRequest', - 'org.apache.http.HttpHost', - 'org.apache.http.HttpRequest', - 'org.apache.http.HttpResponse', - 'org.apache.http.HttpVersion', - 'org.apache.http.RequestLine', - 
'org.apache.http.StatusLine', - 'org.apache.http.client.AuthenticationHandler', - 'org.apache.http.client.HttpClient', - 'org.apache.http.client.HttpRequestRetryHandler', - 'org.apache.http.client.RedirectHandler', - 'org.apache.http.client.RequestDirector', - 'org.apache.http.client.UserTokenHandler', - 'org.apache.http.client.methods.HttpEntityEnclosingRequestBase', - 'org.apache.http.client.methods.HttpRequestBase', - 'org.apache.http.config.Registry', - 'org.apache.http.config.RegistryBuilder', - 'org.apache.http.conn.ClientConnectionManager', - 'org.apache.http.conn.ConnectionKeepAliveStrategy', - 'org.apache.http.conn.params.ConnManagerParams', - 'org.apache.http.conn.params.ConnRouteParams', - 'org.apache.http.conn.routing.HttpRoutePlanner', - 'org.apache.http.conn.scheme.PlainSocketFactory', - 'org.apache.http.conn.scheme.SchemeRegistry', - 'org.apache.http.conn.socket.PlainConnectionSocketFactory', - 'org.apache.http.conn.ssl.SSLSocketFactory', - 'org.apache.http.conn.ssl.X509HostnameVerifier', - 'org.apache.http.entity.AbstractHttpEntity', - 'org.apache.http.impl.client.DefaultHttpClient', - 'org.apache.http.impl.client.HttpClientBuilder', - 'org.apache.http.impl.conn.PoolingHttpClientConnectionManager', - 'org.apache.http.params.HttpConnectionParams', - 'org.apache.http.params.HttpParams', - 'org.apache.http.params.HttpProtocolParams', - 'org.apache.http.protocol.HttpContext', - 'org.apache.http.protocol.HttpProcessor', - 'org.apache.http.protocol.HttpRequestExecutor', - // grpc/proto stuff 'com.google.api.gax.grpc.GrpcCallContext', 'com.google.api.gax.grpc.GrpcCallSettings', @@ -219,7 +182,7 @@ tasks.named("thirdPartyAudit").configure { 'io.opentelemetry.sdk.metrics.export.PeriodicMetricReaderBuilder', 'io.opentelemetry.sdk.resources.Resource', ) - + if (buildParams.graalVmRuntime == false) { ignoreMissingClasses( 'org.graalvm.nativeimage.hosted.Feature', diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java 
b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java index 7452eb328fd8c..7346b77abd46d 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java @@ -25,7 +25,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.TieredMergePolicy; -import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.store.Directory; import org.apache.lucene.util.NamedThreadFactory; import org.elasticsearch.cli.ProcessInfo; import org.elasticsearch.common.Strings; @@ -63,6 +63,7 @@ import java.util.Locale; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -122,6 +123,44 @@ enum MergePolicyType { LOG_DOC } + /** + * Factory that creates a directory for a given index path. + */ + @FunctionalInterface + interface DirectoryFactory { + Directory create(Path indexPath) throws IOException; + } + + record DirectoryTypeConfig(DirectoryFactory factory, boolean shared, boolean preWarm) {} + + private static final Map directoryTypeRegistry = new ConcurrentHashMap<>(); + + static { + directoryTypeRegistry.put("default", new DirectoryTypeConfig(KnnIndexer::getDirectory, false, false)); + directoryTypeRegistry.put("frozen", new DirectoryTypeConfig(KnnIndexer::openFrozenDirectory, false, true)); + } + + /** + * Registers a custom directory type that can be referenced via {@code "directory_type"} + * in the test configuration JSON. + * + * @param name the name used in configuration (e.g. 
"serverless") + * @param factory creates a Directory for the given index path + * @param shared if true, a single directory instance is used for both write and read phases + * @param preWarm if true, the directory is pre-warmed before search + */ + static void registerDirectoryType(String name, DirectoryFactory factory, boolean shared, boolean preWarm) { + directoryTypeRegistry.put(name, new DirectoryTypeConfig(factory, shared, preWarm)); + } + + static DirectoryTypeConfig getDirectoryTypeConfig(String name) { + DirectoryTypeConfig config = directoryTypeRegistry.get(name); + if (config == null) { + throw new IllegalArgumentException("Unknown directory_type: '" + name + "'. Known types: " + directoryTypeRegistry.keySet()); + } + return config; + } + private static String formatIndexPath(TestConfiguration args) { List suffix = new ArrayList<>(); switch (args.indexType()) { @@ -347,49 +386,20 @@ public static void main(String[] args) throws Exception { Codec codec = createCodec(testConfiguration, exec); Path indexPath = PathUtils.get(indexPathName); MergePolicy mergePolicy = getMergePolicy(testConfiguration); - if (testConfiguration.reindex() || testConfiguration.forceMerge()) { - KnnIndexer knnIndexer = new KnnIndexer( - testConfiguration.docVectors(), - indexPath, - codec, - testConfiguration.indexThreads(), - testConfiguration.vectorEncoding().luceneEncoding, - testConfiguration.dimensions(), - testConfiguration.vectorSpace(), - testConfiguration.numDocs(), - mergePolicy, - testConfiguration.writerBufferSizeInMb(), - testConfiguration.writerMaxBufferedDocs() - ); - if (testConfiguration.reindex() == false && Files.exists(indexPath) == false) { - throw new IllegalArgumentException("Index path does not exist: " + indexPath); - } - if (testConfiguration.reindex()) { - knnIndexer.createIndex(indexResults); - } - if (testConfiguration.forceMerge()) { - knnIndexer.forceMerge(indexResults, testConfiguration.forceMergeMaxNumSegments()); - } - } - numSegments(indexPath, 
indexResults); - if (testConfiguration.queryVectors() != null && testConfiguration.numQueries() > 0) { - if (parsedArgs.warmUpIterations() > 0) { - logger.info("Running the searches for " + parsedArgs.warmUpIterations() + " warm up iterations"); - } - // Warm up - for (int warmUpCount = 0; warmUpCount < parsedArgs.warmUpIterations(); warmUpCount++) { - for (int i = 0; i < results.length; i++) { - var ignoreResults = new Results(indexPathName, indexType, testConfiguration.numDocs()); - KnnSearcher knnSearcher = new KnnSearcher(indexPath, testConfiguration); - knnSearcher.runSearch(ignoreResults, testConfiguration.searchParams().get(i)); - } - } - - for (int i = 0; i < results.length; i++) { - KnnSearcher knnSearcher = new KnnSearcher(indexPath, testConfiguration); - knnSearcher.runSearch(results[i], testConfiguration.searchParams().get(i)); - } - } + DirectoryTypeConfig dirConfig = getDirectoryTypeConfig(testConfiguration.directoryType()); + + runTestConfiguration( + testConfiguration, + indexPath, + codec, + mergePolicy, + dirConfig, + indexResults, + results, + parsedArgs, + indexPathName, + indexType + ); formattedResults.queryResults.addAll(List.of(results)); formattedResults.indexResults.add(indexResults); } finally { @@ -401,6 +411,120 @@ public static void main(String[] args) throws Exception { logger.info("Results: \n" + formattedResults); } + /** + * Runs indexing, merge, and search phases using the given directory configuration. + * When {@code dirConfig.shared()} is true, a single directory instance is used for all + * phases. Otherwise, separate directories are used for write and read. + */ + private static void runTestConfiguration( + TestConfiguration testConfiguration, + Path indexPath, + Codec codec, + MergePolicy mergePolicy, + DirectoryTypeConfig dirConfig, + Results indexResults, + Results[] results, + ParsedArgs parsedArgs, + String indexPathName, + String indexType + ) throws Exception { + Directory sharedDir = dirConfig.shared() ? 
dirConfig.factory().create(indexPath) : null; + try { + if (testConfiguration.reindex() || testConfiguration.forceMerge()) { + KnnIndexer knnIndexer = new KnnIndexer( + testConfiguration.docVectors(), + indexPath, + codec, + testConfiguration.indexThreads(), + testConfiguration.vectorEncoding().luceneEncoding, + testConfiguration.dimensions(), + testConfiguration.vectorSpace(), + testConfiguration.numDocs(), + mergePolicy, + testConfiguration.writerBufferSizeInMb(), + testConfiguration.writerMaxBufferedDocs() + ); + if (testConfiguration.reindex() == false && Files.exists(indexPath) == false) { + throw new IllegalArgumentException("Index path does not exist: " + indexPath); + } + if (testConfiguration.reindex()) { + reindex(knnIndexer, indexResults, sharedDir); + } + if (testConfiguration.forceMerge()) { + forceMerge(knnIndexer, indexResults, sharedDir, testConfiguration); + } + } + numSegments(indexPath, indexResults, sharedDir); + if (testConfiguration.queryVectors() != null && testConfiguration.numQueries() > 0) { + Directory readDir = sharedDir != null ? 
sharedDir : dirConfig.factory().create(indexPath); + try { + if (dirConfig.preWarm()) { + KnnSearcher.preWarmDirectory(readDir); + } + runSearches(testConfiguration, indexPath, readDir, results, parsedArgs, indexPathName, indexType); + } finally { + if (sharedDir == null) { + readDir.close(); + } + } + } + } finally { + if (sharedDir != null) { + sharedDir.close(); + } + } + } + + static void reindex(KnnIndexer knnIndexer, Results indexResults, Directory sharedDir) throws Exception { + if (sharedDir != null) { + knnIndexer.createIndex(indexResults, sharedDir); + } else { + knnIndexer.createIndex(indexResults); + } + } + + static void forceMerge(KnnIndexer knnIndexer, Results indexResults, Directory sharedDir, TestConfiguration testConfiguration) + throws Exception { + if (sharedDir != null) { + knnIndexer.forceMerge(indexResults, testConfiguration.forceMergeMaxNumSegments(), sharedDir); + } else { + knnIndexer.forceMerge(indexResults, testConfiguration.forceMergeMaxNumSegments()); + } + } + + static void numSegments(Path indexPath, Results indexResults, Directory sharedDir) throws IOException { + if (sharedDir != null) { + numSegments(sharedDir, indexResults); + } else { + numSegments(indexPath, indexResults); + } + } + + private static void runSearches( + TestConfiguration testConfiguration, + Path indexPath, + Directory dir, + Results[] results, + ParsedArgs parsedArgs, + String indexPathName, + String indexType + ) throws Exception { + if (parsedArgs.warmUpIterations() > 0) { + logger.info("Running the searches for " + parsedArgs.warmUpIterations() + " warm up iterations"); + } + for (int warmUpCount = 0; warmUpCount < parsedArgs.warmUpIterations(); warmUpCount++) { + for (int i = 0; i < results.length; i++) { + var ignoreResults = new Results(indexPathName, indexType, testConfiguration.numDocs()); + KnnSearcher knnSearcher = new KnnSearcher(indexPath, testConfiguration); + knnSearcher.runSearch(ignoreResults, testConfiguration.searchParams().get(i), dir); + } + 
} + for (int i = 0; i < results.length; i++) { + KnnSearcher knnSearcher = new KnnSearcher(indexPath, testConfiguration); + knnSearcher.runSearch(results[i], testConfiguration.searchParams().get(i), dir); + } + } + private static void checkQuantizeBits(TestConfiguration args) { switch (args.indexType()) { case IVF: @@ -431,13 +555,21 @@ private static MergePolicy getMergePolicy(TestConfiguration args) { } static void numSegments(Path indexPath, Results result) throws IOException { - try (FSDirectory dir = FSDirectory.open(indexPath); IndexReader reader = DirectoryReader.open(dir)) { + try (Directory dir = KnnIndexer.getDirectory(indexPath); IndexReader reader = DirectoryReader.open(dir)) { result.numSegments = reader.leaves().size(); } catch (IOException e) { throw new IOException("Failed to get segment count for index at " + indexPath, e); } } + static void numSegments(Directory dir, Results result) throws IOException { + try (IndexReader reader = DirectoryReader.open(dir)) { + result.numSegments = reader.leaves().size(); + } catch (IOException e) { + throw new IOException("Failed to get segment count for dir: " + dir, e); + } + } + static class FormattedResults { List indexResults = new ArrayList<>(); List queryResults = new ArrayList<>(); diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java index 491d3c8e553bd..98d629889feb8 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java @@ -108,6 +108,12 @@ class KnnIndexer { } void createIndex(KnnIndexTester.Results result) throws IOException, InterruptedException, ExecutionException { + try (Directory dir = getDirectory(indexPath)) { + createIndex(result, dir); + } + } + + void createIndex(KnnIndexTester.Results result, Directory dir) throws IOException, InterruptedException, ExecutionException { IndexWriterConfig iwc = new 
IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.CREATE); iwc.setCodec(codec); iwc.setMaxBufferedDocs(writerMaxBufferedDocs); @@ -140,7 +146,7 @@ public boolean isEnabled(String component) { long start = System.nanoTime(); AtomicInteger numDocsIndexed = new AtomicInteger(); - try (Directory dir = getDirectory(indexPath); IndexWriter iw = new IndexWriter(dir, iwc)) { + try (IndexWriter iw = new IndexWriter(dir, iwc)) { for (Path docsPath : this.docsPath) { int dim = this.dim; try (FileChannel in = FileChannel.open(docsPath)) { @@ -223,6 +229,12 @@ public boolean isEnabled(String component) { } void forceMerge(KnnIndexTester.Results results, int maxNumSegments) throws Exception { + try (Directory dir = getDirectory(indexPath)) { + forceMerge(results, maxNumSegments, dir); + } + } + + void forceMerge(KnnIndexTester.Results results, int maxNumSegments, Directory dir) throws Exception { IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.APPEND); iwc.setInfoStream(new PrintStreamInfoStream(System.out) { @Override @@ -234,7 +246,7 @@ public boolean isEnabled(String component) { iwc.setUseCompoundFile(false); logger.info("KnnIndexer: forceMerge in {} into {} segments", indexPath, maxNumSegments); long startNS = System.nanoTime(); - try (IndexWriter iw = new IndexWriter(getDirectory(indexPath), iwc)) { + try (IndexWriter iw = new IndexWriter(dir, iwc)) { iw.forceMerge(maxNumSegments); } long endNS = System.nanoTime(); @@ -252,6 +264,35 @@ static Directory getDirectory(Path indexPath) throws IOException { return dir; } + /** + * Opens a frozen (searchable snapshot) directory for the given index path. 
+ */ + static Directory openFrozenDirectory(Path indexPath) throws IOException { + Path workPath = indexPath.resolveSibling(indexPath.getFileName() + ".snap_work"); + Files.createDirectories(workPath); + logger.info("Opening frozen snapshot directory for index at {} with work path {}", indexPath, workPath); + return openSearchableSnapshotDirectory(indexPath, workPath); + } + + /** + * Creates a directory backed by searchable snapshot infrastructure, wrapping an existing + * Lucene index on disk. Loaded via reflection because the factory resides in the + * searchable-snapshots test artifact (unnamed module) which cannot be directly referenced + * from this named module ({@code org.elasticsearch.test.knn}). + */ + private static Directory openSearchableSnapshotDirectory(Path indexPath, Path workPath) throws IOException { + try { + Class factoryClass = Class.forName("org.elasticsearch.xpack.searchablesnapshots.store.SearchableSnapshotDirectoryFactory"); + var method = factoryClass.getMethod("newDirectoryFromIndex", Path.class, Path.class); + return (Directory) method.invoke(null, indexPath, workPath); + } catch (Exception e) { + throw new IOException( + "Failed to create searchable snapshot directory. 
Ensure the searchable-snapshots test artifact is on the classpath.", + e + ); + } + } + private static BiFunction> getReadAdviceFunc() { return (name, context) -> { if (context.hints().contains(StandardIOBehaviorHint.INSTANCE) || name.endsWith(".cfs")) { diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java index 26a0547015040..b89c80ab66d8b 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java @@ -51,10 +51,13 @@ import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.FixedBitSet; import org.elasticsearch.common.io.Channels; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.PathUtils; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.search.profile.query.QueryProfiler; @@ -123,7 +126,7 @@ class KnnSearcher { this.doPrecondition = testConfiguration.doPrecondition(); } - void runSearch(KnnIndexTester.Results finalResults, SearchParameters searchParameters) throws IOException { + void runSearch(KnnIndexTester.Results finalResults, SearchParameters searchParameters, Directory dir) throws IOException { Query filterQuery = searchParameters.filterSelectivity() < 1f ? 
generateRandomQuery( new Random(searchParameters.seed()), @@ -173,124 +176,122 @@ void runSearch(KnnIndexTester.Results finalResults, SearchParameters searchParam ); KnnIndexer.VectorReader targetReader = KnnIndexer.VectorReader.create(input, dim, vectorEncoding, offsetByteSize); long startNS; - try (Directory dir = KnnIndexer.getDirectory(indexPath)) { - try (DirectoryReader reader = DirectoryReader.open(dir)) { - IndexSearcher searcher = searchParameters.searchThreads() > 1 - ? new IndexSearcher(reader, executorService) - : new IndexSearcher(reader); - byte[] targetBytes = new byte[dim]; - float[] target = new float[dim]; - // warm up - for (int i = 0; i < numQueryVectors; i++) { - if (vectorEncoding.equals(VectorEncoding.BYTE)) { - targetReader.next(targetBytes); - doVectorQuery(targetBytes, searcher, filterQuery, searchParameters); - } else { - targetReader.next(target); - doVectorQuery(target, searcher, filterQuery, searchParameters); - } - } - targetReader.reset(); - final IntConsumer[] queryConsumers = new IntConsumer[searchParameters.numSearchers()]; + try (DirectoryReader reader = DirectoryReader.open(dir)) { + IndexSearcher searcher = searchParameters.searchThreads() > 1 + ? 
new IndexSearcher(reader, executorService) + : new IndexSearcher(reader); + byte[] targetBytes = new byte[dim]; + float[] target = new float[dim]; + // warm up + for (int i = 0; i < numQueryVectors; i++) { if (vectorEncoding.equals(VectorEncoding.BYTE)) { - byte[][] queries = new byte[numQueryVectors][dim]; - for (int i = 0; i < numQueryVectors; i++) { - targetReader.next(queries[i]); - } - for (int s = 0; s < searchParameters.numSearchers(); s++) { - queryConsumers[s] = i -> { - try { - results[i] = doVectorQuery(queries[i], searcher, filterQuery, searchParameters); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - }; - } + targetReader.next(targetBytes); + doVectorQuery(targetBytes, searcher, filterQuery, searchParameters); } else { - float[][] queries = new float[numQueryVectors][dim]; - for (int i = 0; i < numQueryVectors; i++) { - targetReader.next(queries[i]); - } - for (int s = 0; s < searchParameters.numSearchers(); s++) { - queryConsumers[s] = i -> { - try { - results[i] = doVectorQuery(queries[i], searcher, filterQuery, searchParameters); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - }; - } + targetReader.next(target); + doVectorQuery(target, searcher, filterQuery, searchParameters); + } + } + targetReader.reset(); + final IntConsumer[] queryConsumers = new IntConsumer[searchParameters.numSearchers()]; + if (vectorEncoding.equals(VectorEncoding.BYTE)) { + byte[][] queries = new byte[numQueryVectors][dim]; + for (int i = 0; i < numQueryVectors; i++) { + targetReader.next(queries[i]); } - int[][] querySplits = new int[searchParameters.numSearchers()][]; - int queriesPerSearcher = numQueryVectors / searchParameters.numSearchers(); for (int s = 0; s < searchParameters.numSearchers(); s++) { - int start = s * queriesPerSearcher; - int end = (s == searchParameters.numSearchers() - 1) ? 
numQueryVectors : (s + 1) * queriesPerSearcher; - querySplits[s] = new int[end - start]; - for (int i = start; i < end; i++) { - querySplits[s][i - start] = i; - } + queryConsumers[s] = i -> { + try { + results[i] = doVectorQuery(queries[i], searcher, filterQuery, searchParameters); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }; } - targetReader.reset(); - startNS = System.nanoTime(); - KnnIndexTester.ThreadDetails startThreadDetails = new KnnIndexTester.ThreadDetails(); - if (numSearchersExecutor != null) { - // use multiple searchers - var futures = new ArrayList>(); - for (int s = 0; s < searchParameters.numSearchers(); s++) { - int[] split = querySplits[s]; - IntConsumer queryConsumer = queryConsumers[s]; - futures.add(numSearchersExecutor.submit(() -> { - for (int j : split) { - queryConsumer.accept(j); - } - return null; - })); - } - for (Future future : futures) { + } else { + float[][] queries = new float[numQueryVectors][dim]; + for (int i = 0; i < numQueryVectors; i++) { + targetReader.next(queries[i]); + } + for (int s = 0; s < searchParameters.numSearchers(); s++) { + queryConsumers[s] = i -> { try { - future.get(); - } catch (Exception e) { - throw new RuntimeException("Error executing searcher thread", e); + results[i] = doVectorQuery(queries[i], searcher, filterQuery, searchParameters); + } catch (IOException e) { + throw new UncheckedIOException(e); } - } - } else { - // use a single searcher - for (int i = 0; i < numQueryVectors; i++) { - queryConsumers[0].accept(i); - } + }; } - KnnIndexTester.ThreadDetails endThreadDetails = new KnnIndexTester.ThreadDetails(); - elapsed = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS); - long startCPUTimeNS = 0; - long endCPUTimeNS = 0; - for (int i = 0; i < startThreadDetails.threadInfos.length; i++) { - if (startThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher")) { - startCPUTimeNS += startThreadDetails.cpuTimesNS[i]; - } + } + int[][] querySplits = new 
int[searchParameters.numSearchers()][]; + int queriesPerSearcher = numQueryVectors / searchParameters.numSearchers(); + for (int s = 0; s < searchParameters.numSearchers(); s++) { + int start = s * queriesPerSearcher; + int end = (s == searchParameters.numSearchers() - 1) ? numQueryVectors : (s + 1) * queriesPerSearcher; + querySplits[s] = new int[end - start]; + for (int i = start; i < end; i++) { + querySplits[s][i - start] = i; } - - for (int i = 0; i < endThreadDetails.threadInfos.length; i++) { - if (endThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher")) { - endCPUTimeNS += endThreadDetails.cpuTimesNS[i]; + } + targetReader.reset(); + startNS = System.nanoTime(); + KnnIndexTester.ThreadDetails startThreadDetails = new KnnIndexTester.ThreadDetails(); + if (numSearchersExecutor != null) { + // use multiple searchers + var futures = new ArrayList>(); + for (int s = 0; s < searchParameters.numSearchers(); s++) { + int[] split = querySplits[s]; + IntConsumer queryConsumer = queryConsumers[s]; + futures.add(numSearchersExecutor.submit(() -> { + for (int j : split) { + queryConsumer.accept(j); + } + return null; + })); + } + for (Future future : futures) { + try { + future.get(); + } catch (Exception e) { + throw new RuntimeException("Error executing searcher thread", e); } } - totalCpuTimeMS = TimeUnit.NANOSECONDS.toMillis(endCPUTimeNS - startCPUTimeNS); - - // Fetch, validate and write result document ids. 
- StoredFields storedFields = reader.storedFields(); + } else { + // use a single searcher for (int i = 0; i < numQueryVectors; i++) { - totalVisited += results[i].totalHits.value(); - resultIds[i] = getResultIds(results[i], storedFields); + queryConsumers[0].accept(i); + } + } + KnnIndexTester.ThreadDetails endThreadDetails = new KnnIndexTester.ThreadDetails(); + elapsed = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS); + long startCPUTimeNS = 0; + long endCPUTimeNS = 0; + for (int i = 0; i < startThreadDetails.threadInfos.length; i++) { + if (startThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher")) { + startCPUTimeNS += startThreadDetails.cpuTimesNS[i]; + } + } + + for (int i = 0; i < endThreadDetails.threadInfos.length; i++) { + if (endThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher")) { + endCPUTimeNS += endThreadDetails.cpuTimesNS[i]; } - logger.info( - "completed {} searches in {} ms: {} QPS CPU time={}ms", - numQueryVectors, - elapsed, - (1000L * numQueryVectors) / elapsed, - totalCpuTimeMS - ); } + totalCpuTimeMS = TimeUnit.NANOSECONDS.toMillis(endCPUTimeNS - startCPUTimeNS); + + // Fetch, validate and write result document ids. + StoredFields storedFields = reader.storedFields(); + for (int i = 0; i < numQueryVectors; i++) { + totalVisited += results[i].totalHits.value(); + resultIds[i] = getResultIds(results[i], storedFields); + } + logger.info( + "completed {} searches in {} ms: {} QPS CPU time={}ms", + numQueryVectors, + elapsed, + (1000L * numQueryVectors) / elapsed, + totalCpuTimeMS + ); } } logger.info("checking results"); @@ -309,6 +310,35 @@ void runSearch(KnnIndexTester.Results finalResults, SearchParameters searchParam finalResults.earlyTermination = searchParameters.earlyTermination(); } + /** + * Pre-warms the searchable snapshot cache by sequentially reading every + * file in the directory through a single {@link IndexInput} per file. 
+ */ + static void preWarmDirectory(Directory dir) throws IOException { + long startNS = System.nanoTime(); + long totalBytes = 0; + byte[] buf = new byte[64 * 1024]; + for (String file : dir.listAll()) { + long fileLength = dir.fileLength(file); + try (IndexInput in = dir.openInput(file, IOContext.READONCE)) { + long remaining = fileLength; + while (remaining > 0) { + int toRead = (int) Math.min(buf.length, remaining); + in.readBytes(buf, 0, toRead); + remaining -= toRead; + } + } + totalBytes += fileLength; + } + long elapsedMS = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS); + logger.info( + "Pre-warmed searchable snapshot cache: {} across {} files in {} ms", + ByteSizeValue.ofBytes(totalBytes), + dir.listAll().length, + elapsedMS + ); + } + private static Query generateRandomQuery(Random random, Path indexPath, int size, float selectivity, boolean filterCached) throws IOException { FixedBitSet bitSet = new FixedBitSet(size); diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/TestConfiguration.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/TestConfiguration.java index 151eb913715b9..7f41bb42b7316 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/TestConfiguration.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/TestConfiguration.java @@ -69,7 +69,8 @@ record TestConfiguration( boolean doPrecondition, int preconditioningBlockDims, int flatVectorThreshold, - int secondaryClusterSize + int secondaryClusterSize, + String directoryType ) { static final ParseField DATASET_FIELD = new ParseField("dataset"); @@ -110,6 +111,7 @@ record TestConfiguration( static final ParseField FILTER_CACHED = new ParseField("filter_cache"); static final ParseField SEARCH_PARAMS = new ParseField("search_params"); static final ParseField FLAT_VECTOR_THRESHOLD = new ParseField("flat_vector_threshold"); + static final ParseField DIRECTORY_TYPE_FIELD = new ParseField("directory_type"); /** By default, in ES the default writer 
buffer size is 10% of the heap space * (see {@code IndexingMemoryController.INDEX_BUFFER_SIZE_SETTING}). @@ -174,6 +176,7 @@ static TestConfiguration fromXContent(XContentParser parser) throws Exception { PARSER.declareInt(Builder::setMergeWorkers, MERGE_WORKERS_FIELD); PARSER.declareInt(Builder::setFlatVectorThreshold, FLAT_VECTOR_THRESHOLD); PARSER.declareInt(Builder::setSecondaryClusterSize, SECONDARY_CLUSTER_SIZE); + PARSER.declareString(Builder::setDirectoryType, DIRECTORY_TYPE_FIELD); } public int numberOfSearchRuns() { @@ -230,6 +233,11 @@ public static String formattedParameterHelp() { "search_params", "array[object]", "Explicit per-search settings; each object may include search fields like num_candidates, k, and visit_percentage." + ), + new ParameterHelp( + "directory_type", + "string", + "Directory type: default (mmap), frozen (searchable snapshot), or custom types registered by external wrappers." ) ); @@ -303,6 +311,7 @@ static class Builder implements ToXContentObject { private int numMergeWorkers = 1; private int flatVectorThreshold = -1; // -1 mean use default (vectorPerCluster * 3) private int secondaryClusterSize = -1; + private String directoryType = "default"; /** * Elasticsearch does not set this explicitly, and in Lucene this setting is @@ -504,6 +513,11 @@ public Builder setSecondaryClusterSize(int secondaryClusterSize) { return this; } + public Builder setDirectoryType(String directoryType) { + this.directoryType = directoryType.toLowerCase(Locale.ROOT); + return this; + } + /* * Each dataset has a descriptor file, expected to be at gs:////.json, with contents of: { @@ -721,7 +735,8 @@ public TestConfiguration build() throws Exception { doPrecondition, preconditioningBlockDims, flatVectorThreshold, - secondaryClusterSize + secondaryClusterSize, + directoryType ); } @@ -779,6 +794,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(SEARCH_PARAMS.getPreferredName(), searchParams); } 
builder.field(FLAT_VECTOR_THRESHOLD.getPreferredName(), flatVectorThreshold); + builder.field(DIRECTORY_TYPE_FIELD.getPreferredName(), directoryType); return builder.endObject(); } diff --git a/server/src/main/java/org/elasticsearch/index/store/StoreMetricsIndexInput.java b/server/src/main/java/org/elasticsearch/index/store/StoreMetricsIndexInput.java index ede0a44fd7a53..a91a7b335fd97 100644 --- a/server/src/main/java/org/elasticsearch/index/store/StoreMetricsIndexInput.java +++ b/server/src/main/java/org/elasticsearch/index/store/StoreMetricsIndexInput.java @@ -14,14 +14,17 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; import org.apache.lucene.store.RandomAccessInput; +import org.elasticsearch.core.CheckedConsumer; +import org.elasticsearch.core.DirectAccessInput; import org.elasticsearch.simdvec.MemorySegmentAccessInputAccess; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Map; import java.util.Optional; import java.util.Set; -public class StoreMetricsIndexInput extends FilterIndexInput { +public class StoreMetricsIndexInput extends FilterIndexInput implements DirectAccessInput { final PluggableDirectoryMetricsHolder metricHolder; public static IndexInput create(String resourceDescription, IndexInput in, PluggableDirectoryMetricsHolder metricHolder) { @@ -88,6 +91,14 @@ public void prefetch(long offset, long length) throws IOException { in.prefetch(offset, length); } + @Override + public boolean withByteBufferSlice(long offset, long length, CheckedConsumer action) throws IOException { + if (in instanceof DirectAccessInput dai) { + return dai.withByteBufferSlice(offset, length, action); + } + return false; + } + @Override public Optional isLoaded() { return in.isLoaded(); diff --git a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java 
b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java index fa2fe1bc5628c..b3594022c4e3c 100644 --- a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java +++ b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java @@ -33,6 +33,7 @@ import org.elasticsearch.common.util.concurrent.ThrottledTaskRunner; import org.elasticsearch.core.AbstractRefCounted; import org.elasticsearch.core.Assertions; +import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; @@ -43,6 +44,7 @@ import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.store.LuceneFilesExtensions; import org.elasticsearch.monitor.fs.FsProbe; +import org.elasticsearch.nativeaccess.CloseableByteBuffer; import org.elasticsearch.node.NodeRoleSettings; import org.elasticsearch.threadpool.ThreadPool; @@ -1064,6 +1066,35 @@ boolean tryRead(ByteBuffer buf, long offset) throws IOException { } } + /** + * Optimistically try to get a direct ByteBuffer slice from the region. + * The returned {@link CloseableByteBuffer} holds a reference to this region, + * preventing eviction while the buffer is in use. The caller must close it + * when done. 
+ * @return a CloseableByteBuffer wrapping a read-only ByteBuffer slice, or null if not available + */ + CloseableByteBuffer tryGetByteBufferSlice(long offset, int length) { + SharedBytes.IO ioRef = nonVolatileIO(); + if (ioRef != null && tryIncRef()) { + ByteBuffer slice = ioRef.byteBufferSlice(blobCacheService.getRegionRelativePosition(offset), length); + if (slice != null && isEvicted() == false) { + return new CloseableByteBuffer() { + @Override + public ByteBuffer buffer() { + return slice; + } + + @Override + public void close() { + CacheFileRegion.this.decRef(); + } + }; + } + decRef(); + } + return null; + } + /** * Populates a range in cache if the range is not available nor pending to be available in cache. * @@ -1362,6 +1393,49 @@ public boolean tryRead(ByteBuffer buf, long offset) throws IOException { return res; } + CloseableByteBuffer tryGetByteBufferSlice(long offset, int length) { + assert assertOffsetsWithinFileLength(offset, length, this.length); + final int startRegion = getRegion(offset); + final long end = offset + length; + final int endRegion = getEndingRegion(end); + if (startRegion != endRegion) { + return null; + } + var fileRegion = lastAccessedRegion; + if (fileRegion != null && fileRegion.chunk.regionKey.region == startRegion) { + fileRegion.touch(); + } else { + fileRegion = cache.get(cacheKey, this.length, startRegion); + } + final var region = fileRegion.chunk; + if (region.tracker.checkAvailable(end - getRegionStart(startRegion)) == false) { + return null; + } + CloseableByteBuffer slice = region.tryGetByteBufferSlice(offset, length); + if (slice != null) { + lastAccessedRegion = fileRegion; + } + return slice; + } + + /** + * If a direct byte buffer view is available for the given range, passes it + * to {@code action} and returns {@code true}. Otherwise returns + * {@code false} without invoking the action. 
+ */ + public boolean withByteBufferSlice(long offset, int length, CheckedConsumer action) throws IOException { + CloseableByteBuffer cbb = tryGetByteBufferSlice(offset, length); + if (cbb == null) { + return false; + } + try { + action.accept(cbb.buffer()); + return true; + } finally { + cbb.close(); + } + } + public int populateAndRead( final ByteRange rangeToWrite, final ByteRange rangeToRead, diff --git a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBytes.java b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBytes.java index edee430e7f0b0..33a1cb96eaeb7 100644 --- a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBytes.java +++ b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBytes.java @@ -375,6 +375,22 @@ public int read(ByteBuffer dst, int position) throws IOException { return bytesRead; } + /** + * Returns a read-only ByteBuffer slice of the memory-mapped region, + * or {@code null} if not memory-mapped. 
+ * + * @param position the starting position within the region, must be non-negative + * @param length the number of bytes, {@code position + length} must not exceed the region size + * @throws IllegalArgumentException if the position/length are out of bounds + */ + public ByteBuffer byteBufferSlice(int position, int length) { + if (mmap) { + checkOffsets(position, length); + return mappedByteBuffer.buffer().slice(position, length).asReadOnlyBuffer(); + } + return null; + } + @SuppressForbidden(reason = "Use positional writes on purpose") public int write(ByteBuffer src, int position) throws IOException { // check if writes are page size aligned for optimal performance diff --git a/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java b/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java index 7c61da5475fd6..eb5f7959e11ae 100644 --- a/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java +++ b/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java @@ -2280,6 +2280,333 @@ public void fillCacheRange( } } + // Verifies that withByteBufferSlice returns false before data is populated, and provides + // a readable byte buffer with correct content after population. Single region of size(10), file size(8). 
+ public void testWithByteBufferSlice() throws Exception { + final int regionSize = (int) size(10); + final long fileLength = size(8); // fits in a single region + Settings settings = Settings.builder() + .put(NODE_NAME_SETTING.getKey(), "node") + .put(SharedBlobCacheService.SHARED_CACHE_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(size(50)).getStringRep()) + .put(SharedBlobCacheService.SHARED_CACHE_REGION_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(regionSize).getStringRep()) + .put(SharedBlobCacheService.SHARED_CACHE_MMAP.getKey(), true) + .put("path.home", createTempDir()) + .build(); + final DeterministicTaskQueue taskQueue = new DeterministicTaskQueue(); + ExecutorService ioExecutor = Executors.newCachedThreadPool(); + try ( + NodeEnvironment environment = new NodeEnvironment(settings, TestEnvironment.newEnvironment(settings)); + var cacheService = new SharedBlobCacheService( + environment, + settings, + taskQueue.getThreadPool(), + ioExecutor, + BlobCacheMetrics.NOOP + ) + ) { + final var cacheKey = generateCacheKey(); + SharedBlobCacheService.CacheFile cacheFile = cacheService.getCacheFile( + cacheKey, + fileLength, + SharedBlobCacheService.CacheMissHandler.NOOP + ); + + // before populating, withByteBufferSlice should return false (data not available) + assertFalse(cacheFile.withByteBufferSlice(0, 100, slice -> fail("should not be invoked"))); + + // populate the cache with known data + byte[] testData = randomByteArrayOfLength((int) fileLength); + ByteBuffer writeBuffer = ByteBuffer.allocate(SharedBytes.PAGE_SIZE); + final int bytesRead = cacheFile.populateAndRead( + ByteRange.of(0L, fileLength), + ByteRange.of(0L, fileLength), + (channel, pos, relativePos, len) -> len, + (channel, channelPos, streamFactory, relativePos, len, progressUpdater, completionListener) -> { + SharedBytes.copyToCacheFileAligned( + channel, + new java.io.ByteArrayInputStream(testData, relativePos, len), + channelPos, + relativePos, + len, + progressUpdater, + writeBuffer.clear() + 
); + ActionListener.completeWith(completionListener, () -> null); + }, + "test" + ); + assertThat(bytesRead, equalTo((int) fileLength)); + + // now withByteBufferSlice should provide a valid slice + int sliceOffset = randomIntBetween(0, (int) fileLength / 2); + int sliceLength = randomIntBetween(1, (int) fileLength - sliceOffset); + boolean available = cacheFile.withByteBufferSlice(sliceOffset, sliceLength, slice -> { + assertTrue(slice.isReadOnly()); + assertEquals(sliceLength, slice.remaining()); + byte[] sliceData = new byte[sliceLength]; + slice.get(sliceData); + for (int i = 0; i < sliceLength; i++) { + assertEquals(testData[sliceOffset + i], sliceData[i]); + } + }); + assertTrue(available); + } + ioExecutor.shutdown(); + } + + // Verifies that the byte buffer ref held during the callback prevents the region from being + // evicted. 2 regions of size(10), file size(8); eviction pressure is applied inside the callback. + public void testWithByteBufferSlicePreventsEviction() throws Exception { + final int regionSize = (int) size(10); + Settings settings = Settings.builder() + .put(NODE_NAME_SETTING.getKey(), "node") + .put(SharedBlobCacheService.SHARED_CACHE_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(size(20)).getStringRep()) + .put(SharedBlobCacheService.SHARED_CACHE_REGION_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(regionSize).getStringRep()) + .put(SharedBlobCacheService.SHARED_CACHE_MMAP.getKey(), true) + .put("path.home", createTempDir()) + .build(); + final DeterministicTaskQueue taskQueue = new DeterministicTaskQueue(); + ExecutorService ioExecutor = Executors.newCachedThreadPool(); + try ( + NodeEnvironment environment = new NodeEnvironment(settings, TestEnvironment.newEnvironment(settings)); + var cacheService = new SharedBlobCacheService( + environment, + settings, + taskQueue.getThreadPool(), + ioExecutor, + BlobCacheMetrics.NOOP + ) + ) { + assertEquals(2, cacheService.freeRegionCount()); + + // populate region 0 with known data for cacheKey1 + 
final long fileLength = size(8); // fits in one region + final var cacheKey1 = generateCacheKey(); + SharedBlobCacheService.CacheFile cacheFile1 = cacheService.getCacheFile( + cacheKey1, + fileLength, + SharedBlobCacheService.CacheMissHandler.NOOP + ); + byte[] testData = randomByteArrayOfLength((int) fileLength); + ByteBuffer writeBuffer = ByteBuffer.allocate(SharedBytes.PAGE_SIZE); + cacheFile1.populateAndRead( + ByteRange.of(0L, fileLength), + ByteRange.of(0L, fileLength), + (channel, pos, relativePos, len) -> len, + (channel, channelPos, streamFactory, relativePos, len, progressUpdater, completionListener) -> { + SharedBytes.copyToCacheFileAligned( + channel, + new java.io.ByteArrayInputStream(testData, relativePos, len), + channelPos, + relativePos, + len, + progressUpdater, + writeBuffer.clear() + ); + ActionListener.completeWith(completionListener, () -> null); + }, + "test" + ); + + // inside the callback, the ref is held — eviction should not reclaim the region + boolean available = cacheFile1.withByteBufferSlice(0, (int) fileLength, slice -> { + // fill the remaining region with a different key, using up all free regions + final var cacheKey2 = generateCacheKey(); + cacheService.get(cacheKey2, fileLength, 0); + + // now all regions are used; requesting yet another key triggers eviction pressure + final var cacheKey3 = generateCacheKey(); + cacheService.get(cacheKey3, fileLength, 0); + taskQueue.runAllRunnableTasks(); + + // the buffer should still contain the original data (region not evicted while ref held) + byte[] readBack = new byte[(int) fileLength]; + slice.get(readBack); + assertArrayEquals(testData, readBack); + }); + assertTrue(available); + + } + ioExecutor.shutdown(); + } + + // Verifies that withByteBufferSlice returns false and the callback is not invoked after a + // region has been evicted. 2 regions of size(10), file size(8); eviction forced by cache pressure. 
+ public void testWithByteBufferSliceReturnsFalseAfterEviction() throws Exception { + final int regionSize = (int) size(10); + Settings settings = Settings.builder() + .put(NODE_NAME_SETTING.getKey(), "node") + .put(SharedBlobCacheService.SHARED_CACHE_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(size(20)).getStringRep()) + .put(SharedBlobCacheService.SHARED_CACHE_REGION_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(regionSize).getStringRep()) + .put(SharedBlobCacheService.SHARED_CACHE_MMAP.getKey(), true) + .put("path.home", createTempDir()) + .build(); + final DeterministicTaskQueue taskQueue = new DeterministicTaskQueue(); + try ( + NodeEnvironment environment = new NodeEnvironment(settings, TestEnvironment.newEnvironment(settings)); + var cacheService = new SharedBlobCacheService( + environment, + settings, + taskQueue.getThreadPool(), + EsExecutors.DIRECT_EXECUTOR_SERVICE, + BlobCacheMetrics.NOOP + ) + ) { + assertEquals(2, cacheService.freeRegionCount()); + + final long fileLength = size(8); // fits in one region + final var cacheKey1 = generateCacheKey(); + SharedBlobCacheService.CacheFile cacheFile1 = cacheService.getCacheFile( + cacheKey1, + fileLength, + SharedBlobCacheService.CacheMissHandler.NOOP + ); + + // populate the region + byte[] testData = randomByteArrayOfLength((int) fileLength); + ByteBuffer writeBuffer = ByteBuffer.allocate(SharedBytes.PAGE_SIZE); + cacheFile1.populateAndRead( + ByteRange.of(0L, fileLength), + ByteRange.of(0L, fileLength), + (channel, pos, relativePos, len) -> len, + (channel, channelPos, streamFactory, relativePos, len, progressUpdater, completionListener) -> { + SharedBytes.copyToCacheFileAligned( + channel, + new java.io.ByteArrayInputStream(testData, relativePos, len), + channelPos, + relativePos, + len, + progressUpdater, + writeBuffer.clear() + ); + ActionListener.completeWith(completionListener, () -> null); + }, + "test" + ); + + // confirm the slice is accessible before eviction + 
assertTrue(cacheFile1.withByteBufferSlice(0, (int) fileLength, slice -> {})); + + // fill the second region, then request a third key to force eviction of cacheKey1's region + cacheService.get(generateCacheKey(), fileLength, 0); + cacheService.get(generateCacheKey(), fileLength, 0); + taskQueue.runAllRunnableTasks(); + + // after eviction the action must not be invoked and the method must return false + boolean available = cacheFile1.withByteBufferSlice( + 0, + (int) fileLength, + slice -> { fail("action should not be invoked after eviction"); } + ); + assertFalse(available); + } + } + + // Verifies that withByteBufferSlice returns false when the requested range spans multiple + // regions. Regions of size(10), file size(25) spanning 3 regions; slice straddles the boundary. + public void testWithByteBufferSliceCrossRegionReturnsFalse() throws Exception { + final int regionSize = (int) size(10); + final long fileLength = size(25); // spans 3 regions + Settings settings = Settings.builder() + .put(NODE_NAME_SETTING.getKey(), "node") + .put(SharedBlobCacheService.SHARED_CACHE_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(size(50)).getStringRep()) + .put(SharedBlobCacheService.SHARED_CACHE_REGION_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(regionSize).getStringRep()) + .put(SharedBlobCacheService.SHARED_CACHE_MMAP.getKey(), true) + .put("path.home", createTempDir()) + .build(); + final DeterministicTaskQueue taskQueue = new DeterministicTaskQueue(); + try ( + NodeEnvironment environment = new NodeEnvironment(settings, TestEnvironment.newEnvironment(settings)); + var cacheService = new SharedBlobCacheService( + environment, + settings, + taskQueue.getThreadPool(), + taskQueue.getThreadPool().executor(ThreadPool.Names.GENERIC), + BlobCacheMetrics.NOOP + ) + ) { + final var cacheKey = generateCacheKey(); + SharedBlobCacheService.CacheFile cacheFile = cacheService.getCacheFile( + cacheKey, + fileLength, + SharedBlobCacheService.CacheMissHandler.NOOP + ); + + // request a 
slice that spans the region boundary (region 0 -> region 1) + // region 0 covers [0, regionSize), region 1 covers [regionSize, 2*regionSize) + int crossBoundaryOffset = regionSize - 100; + int crossBoundaryLength = 200; // crosses into region 1 + boolean available = cacheFile.withByteBufferSlice(crossBoundaryOffset, crossBoundaryLength, slice -> { + fail("action should not be invoked for cross-region slice"); + }); + assertFalse(available); + } + } + + // Verifies that withByteBufferSlice returns false when mmap is disabled, even after the + // region has been fully populated. Single region of size(10), file size(8), mmap=false. + public void testWithByteBufferSliceNoMmapReturnsFalse() throws Exception { + final int regionSize = (int) size(10); + final long fileLength = size(8); + Settings settings = Settings.builder() + .put(NODE_NAME_SETTING.getKey(), "node") + .put(SharedBlobCacheService.SHARED_CACHE_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(size(50)).getStringRep()) + .put(SharedBlobCacheService.SHARED_CACHE_REGION_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(regionSize).getStringRep()) + .put(SharedBlobCacheService.SHARED_CACHE_MMAP.getKey(), false) + .put("path.home", createTempDir()) + .build(); + final DeterministicTaskQueue taskQueue = new DeterministicTaskQueue(); + ExecutorService ioExecutor = Executors.newCachedThreadPool(); + try ( + NodeEnvironment environment = new NodeEnvironment(settings, TestEnvironment.newEnvironment(settings)); + var cacheService = new SharedBlobCacheService( + environment, + settings, + taskQueue.getThreadPool(), + ioExecutor, + BlobCacheMetrics.NOOP + ) + ) { + final var cacheKey = generateCacheKey(); + SharedBlobCacheService.CacheFile cacheFile = cacheService.getCacheFile( + cacheKey, + fileLength, + SharedBlobCacheService.CacheMissHandler.NOOP + ); + + // populate the cache + byte[] testData = randomByteArrayOfLength((int) fileLength); + ByteBuffer writeBuffer = ByteBuffer.allocate(SharedBytes.PAGE_SIZE); + 
cacheFile.populateAndRead( + ByteRange.of(0L, fileLength), + ByteRange.of(0L, fileLength), + (channel, pos, relativePos, len) -> len, + (channel, channelPos, streamFactory, relativePos, len, progressUpdater, completionListener) -> { + SharedBytes.copyToCacheFileAligned( + channel, + new java.io.ByteArrayInputStream(testData, relativePos, len), + channelPos, + relativePos, + len, + progressUpdater, + writeBuffer.clear() + ); + ActionListener.completeWith(completionListener, () -> null); + }, + "test" + ); + + // without mmap, withByteBufferSlice should return false even with data populated + boolean available = cacheFile.withByteBufferSlice( + 0, + 100, + slice -> { fail("action should not be invoked when mmap is not enabled"); } + ); + assertFalse(available); + } + ioExecutor.shutdown(); + } + private record TestCacheKey(ShardId shardId, String file) implements SharedBlobCacheService.KeyBase {} private static TestCacheKey randomTestCacheKey(ShardId shardId) { diff --git a/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBytesTests.java b/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBytesTests.java index 1f26935e24c83..e4bfef7336cbf 100644 --- a/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBytesTests.java +++ b/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBytesTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.blobcache.shared; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Assertions; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.PathUtils; import org.elasticsearch.env.Environment; @@ -22,6 +23,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.nullValue; public class SharedBytesTests extends ESTestCase { @@ -122,6 +124,105 @@ public void testCopyAllWith0Padding() throws Exception 
{ } } + // Verifies that byteBufferSlice returns a read-only buffer with correct content when mmap + // is enabled. Randomized region count (1-4) and region size (1-16 pages). + public void testByteBufferSliceMmap() throws Exception { + int regions = randomIntBetween(1, 4); + int regionSize = randomIntBetween(1, 16) * SharedBytes.PAGE_SIZE; + var nodeSettings = Settings.builder() + .put(Node.NODE_NAME_SETTING.getKey(), "node") + .put("path.home", createTempDir()) + .putList(Environment.PATH_DATA_SETTING.getKey(), createTempDir().toString()) + .build(); + SharedBytes sharedBytes = null; + try (var nodeEnv = new NodeEnvironment(nodeSettings, TestEnvironment.newEnvironment(nodeSettings))) { + // mmap=true + sharedBytes = new SharedBytes(regions, regionSize, nodeEnv, ignored -> {}, ignored -> {}, true); + int region = randomIntBetween(0, regions - 1); + byte[] randomData = randomByteArrayOfLength(regionSize); + ByteBuffer tempBuffer = ByteBuffer.allocate(regionSize); + SharedBytes.copyToCacheFileAligned( + sharedBytes.getFileChannel(region), + new ByteArrayInputStream(randomData), + 0, + writtenBytesCount -> {}, + tempBuffer + ); + + SharedBytes.IO io = sharedBytes.getFileChannel(region); + + // byteBufferSlice returns a non-null read-only buffer with correct data + int sliceOffset = randomIntBetween(0, regionSize / 2); + int sliceLength = randomIntBetween(1, regionSize - sliceOffset); + ByteBuffer slice = io.byteBufferSlice(sliceOffset, sliceLength); + assertNotNull(slice); + assertTrue(slice.isReadOnly()); + assertEquals(sliceLength, slice.remaining()); + byte[] sliceData = new byte[sliceLength]; + slice.get(sliceData); + for (int i = 0; i < sliceLength; i++) { + assertEquals(randomData[sliceOffset + i], sliceData[i]); + } + } finally { + if (sharedBytes != null) { + sharedBytes.decRef(); + } + } + } + + // Verifies that byteBufferSlice returns null when mmap is disabled. + // Randomized region count (1-4) and region size (1-16 pages), mmap=false. 
+ public void testByteBufferSliceNoMmap() throws Exception { + int regions = randomIntBetween(1, 4); + int regionSize = randomIntBetween(1, 16) * SharedBytes.PAGE_SIZE; + var nodeSettings = Settings.builder() + .put(Node.NODE_NAME_SETTING.getKey(), "node") + .put("path.home", createTempDir()) + .putList(Environment.PATH_DATA_SETTING.getKey(), createTempDir().toString()) + .build(); + SharedBytes sharedBytes = null; + try (var nodeEnv = new NodeEnvironment(nodeSettings, TestEnvironment.newEnvironment(nodeSettings))) { + // mmap=false + sharedBytes = new SharedBytes(regions, regionSize, nodeEnv, ignored -> {}, ignored -> {}, false); + int region = randomIntBetween(0, regions - 1); + SharedBytes.IO io = sharedBytes.getFileChannel(region); + + // byteBufferSlice returns null when not mmap'd + assertThat(io.byteBufferSlice(0, regionSize), nullValue()); + } finally { + if (sharedBytes != null) { + sharedBytes.decRef(); + } + } + } + + // Verifies that byteBufferSlice rejects out-of-bounds requests: offset+length exceeding + // region size, and negative offset. Single region of 4 pages, mmap=true. + public void testByteBufferSliceBoundsCheck() throws Exception { + int regions = 1; + int regionSize = 4 * SharedBytes.PAGE_SIZE; + var nodeSettings = Settings.builder() + .put(Node.NODE_NAME_SETTING.getKey(), "node") + .put("path.home", createTempDir()) + .putList(Environment.PATH_DATA_SETTING.getKey(), createTempDir().toString()) + .build(); + SharedBytes sharedBytes = null; + try (var nodeEnv = new NodeEnvironment(nodeSettings, TestEnvironment.newEnvironment(nodeSettings))) { + sharedBytes = new SharedBytes(regions, regionSize, nodeEnv, ignored -> {}, ignored -> {}, true); + SharedBytes.IO io = sharedBytes.getFileChannel(0); + + var expectedType = Assertions.ENABLED ? 
AssertionError.class : IllegalArgumentException.class; + // position + length exceeds region size + expectThrows(expectedType, () -> io.byteBufferSlice(regionSize - 10, 20)); + // negative position + expectThrows(expectedType, () -> io.byteBufferSlice(-1, 10)); + } finally { + if (sharedBytes != null) { + sharedBytes.decRef(); + } + } + } + /** * Test that mmap'd SharedBytes instances release their mapped memory on close, so that the OS * can reclaim disk space immediately. Without proper unmapping, each iteration leaks the cache diff --git a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/input/FrozenIndexInput.java b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/input/FrozenIndexInput.java index ab8ae11a56ca8..04620a753c4fc 100644 --- a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/input/FrozenIndexInput.java +++ b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/input/FrozenIndexInput.java @@ -17,6 +17,8 @@ import org.elasticsearch.blobcache.common.ByteRange; import org.elasticsearch.blobcache.shared.SharedBlobCacheService; import org.elasticsearch.blobcache.shared.SharedBytes; +import org.elasticsearch.core.CheckedConsumer; +import org.elasticsearch.core.DirectAccessInput; import org.elasticsearch.index.snapshots.blobstore.BlobStoreIndexShardSnapshot.FileInfo; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshots; @@ -28,7 +30,7 @@ import java.io.InputStream; import java.nio.ByteBuffer; -public final class FrozenIndexInput extends MetadataCachingIndexInput { +public final class FrozenIndexInput extends MetadataCachingIndexInput implements DirectAccessInput { private static final Logger logger = LogManager.getLogger(FrozenIndexInput.class); @@ -105,6 +107,11 @@ private 
FrozenIndexInput(FrozenIndexInput input) { this.cacheFile = input.cacheFile.copy(); } + @Override + public boolean withByteBufferSlice(long offset, long length, CheckedConsumer action) throws IOException { + return cacheFile.withByteBufferSlice(offset + this.offset, Math.toIntExact(length), action); + } + @Override protected void readWithoutBlobCache(ByteBuffer b) throws Exception { final long position = getAbsolutePosition(); diff --git a/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/xpack/searchablesnapshots/store/SearchableSnapshotDirectoryFactory.java b/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/xpack/searchablesnapshots/store/SearchableSnapshotDirectoryFactory.java index a3fba29958313..e13ac1639f2b1 100644 --- a/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/xpack/searchablesnapshots/store/SearchableSnapshotDirectoryFactory.java +++ b/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/xpack/searchablesnapshots/store/SearchableSnapshotDirectoryFactory.java @@ -85,11 +85,13 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardCopyOption; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; /** @@ -119,6 +121,196 @@ public static Directory newDirectory(Path path) { return new WriteOnceSnapshotDirectory(path); } + /** + * Returns a read-only {@link Directory} backed by searchable snapshot infrastructure, + * wrapping an existing Lucene index on disk. The original index files are read directly + * via {@link FsBlobContainer} (zero-copy, no heap buffering). 
+ * + * @param indexPath path to the existing Lucene index directory + * @param workPath scratch directory for cache and node environment files + */ + public static Directory newDirectoryFromIndex(Path indexPath, Path workPath) throws IOException { + return new DiskBackedSnapshotDirectory(indexPath, workPath); + } + + // -- DiskBackedSnapshotDirectory -- + + static class DiskBackedSnapshotDirectory extends Directory { + private final SearchableSnapshotDirectory delegate; + + // infrastructure + private final ThreadPool threadPool; + private final CacheService cacheService; + private final ClusterService clusterService; + private final SharedBlobCacheService sharedBlobCacheService; + private final NodeEnvironment nodeEnvironment; + + DiskBackedSnapshotDirectory(Path indexPath, Path workPath) throws IOException { + boolean success = false; + ThreadPool tp = null; + NodeEnvironment ne = null; + ClusterService cs = null; + CacheService cas = null; + SharedBlobCacheService sbcs = null; + try { + // Enumerate index files, build snapshot metadata + List fileInfos = new ArrayList<>(); + long totalSize = 0; + try (FSDirectory fsDir = FSDirectory.open(indexPath)) { + for (String file : fsDir.listAll()) { + if (file.equals("write.lock")) continue; + long fileLength = fsDir.fileLength(file); + totalSize += fileLength; + StoreFileMetadata metadata = new StoreFileMetadata( + file, + fileLength, + "0", + IndexVersion.current().luceneVersion().toString() + ); + fileInfos.add(new BlobStoreIndexShardSnapshot.FileInfo(file, metadata, ByteSizeValue.ofBytes(fileLength + 1))); + } + } + + // FsBlobContainer reads directly from the index directory (blob name = file name) + FsBlobStore blobStore = new FsBlobStore(8192, indexPath, true); + BlobContainer blobContainer = new FsBlobContainer(blobStore, BlobPath.EMPTY, indexPath); + + BlobStoreIndexShardSnapshot snapshot = new BlobStoreIndexShardSnapshot("snapshotId", fileInfos, 0L, 0L, 0, 0L); + + long cacheSize = totalSize; + + SnapshotId 
snapshotId = new SnapshotId("_name", "_uuid"); + IndexId indexId = new IndexId("_name", "_uuid"); + ShardId shardId = new ShardId("_name", "_uuid", 0); + Path topDir = workPath.resolve(shardId.getIndex().getUUID()); + Path shardDir = topDir.resolve(Integer.toString(shardId.getId())); + ShardPath shardPath = new ShardPath(false, shardDir, shardDir, shardId); + Path cacheDir = Files.createDirectories(CacheService.resolveSnapshotCache(shardDir).resolve(snapshotId.getUUID())); + + tp = new SimpleThreadPool("tp", SearchableSnapshots.executorBuilders(Settings.EMPTY)); + ne = newNodeEnvironment(Settings.EMPTY, workPath, cacheSize); + cs = createClusterService(tp, clusterSettings()); + + cas = new CacheService(Settings.EMPTY, cs, tp, new PersistentCache(ne)); + cas.start(); + sbcs = defaultFrozenCacheService(tp, ne, workPath, cacheSize); + + SearchableSnapshotDirectory dir = new SearchableSnapshotDirectory( + () -> blobContainer, + () -> snapshot, + new NoopBlobStoreCacheService(tp), + "_repo", + snapshotId, + indexId, + shardId, + buildIndexSettings(), + () -> 0L, + cas, + cacheDir, + shardPath, + tp, + sbcs + ); + + // loadSnapshot asserts ThreadPool.assertCurrentThreadPool(GENERIC) + SearchableSnapshotRecoveryState recoveryState = createRecoveryState(true); + PlainActionFuture f = new PlainActionFuture<>(); + CountDownLatch latch = new CountDownLatch(1); + tp.generic().execute(() -> { + try { + dir.loadSnapshot(recoveryState, () -> false, f); + } finally { + latch.countDown(); + } + }); + latch.await(); + f.get(); + + this.delegate = dir; + this.threadPool = tp; + this.nodeEnvironment = ne; + this.clusterService = cs; + this.cacheService = cas; + this.sharedBlobCacheService = sbcs; + success = true; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("Interrupted while loading snapshot", e); + } catch (Exception e) { + throw new IOException("Failed to create searchable snapshot directory", e); + } finally { + if (success == 
false) { + IOUtils.closeWhileHandlingException(sbcs, cs, cas, ne); + if (tp != null) { + ThreadPool.terminate(tp, 10, TimeUnit.SECONDS); + } + } + } + } + + @Override + public String[] listAll() throws IOException { + return delegate.listAll(); + } + + @Override + public void deleteFile(String name) { + throw new UnsupportedOperationException("read-only directory"); + } + + @Override + public long fileLength(String name) throws IOException { + return delegate.fileLength(name); + } + + @Override + public IndexOutput createOutput(String name, IOContext context) { + throw new UnsupportedOperationException("read-only directory"); + } + + @Override + public IndexOutput createTempOutput(String prefix, String suffix, IOContext context) { + throw new UnsupportedOperationException("read-only directory"); + } + + @Override + public void sync(Collection names) { + throw new UnsupportedOperationException("read-only directory"); + } + + @Override + public void syncMetaData() { + throw new UnsupportedOperationException("read-only directory"); + } + + @Override + public void rename(String source, String dest) { + throw new UnsupportedOperationException("read-only directory"); + } + + @Override + public IndexInput openInput(String name, IOContext context) throws IOException { + return delegate.openInput(name, context); + } + + @Override + public Lock obtainLock(String name) { + throw new UnsupportedOperationException("read-only directory"); + } + + @Override + public Set getPendingDeletions() { + return Set.of(); + } + + @Override + public void close() throws IOException { + IOUtils.closeWhileHandlingException(sharedBlobCacheService, clusterService, cacheService, nodeEnvironment); + delegate.close(); + ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); + } + } + // -- WriteOnceSnapshotDirectory -- static class WriteOnceSnapshotDirectory extends Directory { @@ -286,7 +478,7 @@ private void materializeSnapshotImpl() throws IOException { ); // load the snapshot so it's ready for 
reads - SearchableSnapshotRecoveryState recoveryState = createRecoveryState(false); + SearchableSnapshotRecoveryState recoveryState = createRecoveryState(true); final PlainActionFuture f = new PlainActionFuture<>(); delegate.loadSnapshot(recoveryState, () -> false, f); try {