diff --git a/docs/changelog/141718.yaml b/docs/changelog/141718.yaml
new file mode 100644
index 0000000000000..2c957dd6bc0ba
--- /dev/null
+++ b/docs/changelog/141718.yaml
@@ -0,0 +1,5 @@
+area: Vector Search
+issues: []
+pr: 141718
+summary: Enable zero-copy SIMD vector scoring on searchable snapshots (frozen tier)
+type: enhancement
diff --git a/libs/core/src/main/java/org/elasticsearch/core/DirectAccessInput.java b/libs/core/src/main/java/org/elasticsearch/core/DirectAccessInput.java
new file mode 100644
index 0000000000000..173e7b0317770
--- /dev/null
+++ b/libs/core/src/main/java/org/elasticsearch/core/DirectAccessInput.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.core;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+/**
+ * An optional interface that an IndexInput can implement to provide direct
+ * access to the underlying data as a {@link ByteBuffer}. This enables
+ * zero-copy access to memory-mapped data for SIMD-accelerated vector scoring.
+ *
+ *
+ * <p>The byte buffer is passed to the caller's action and is only valid for
+ * the duration of that call. All ref-counting and resource releases, if any,
+ * are handled internally.
+ */
+public interface DirectAccessInput {
+
+ /**
+ * If a direct byte buffer view is available for the given range, passes it
+ * to {@code action} and returns {@code true}. Otherwise returns
+ * {@code false} without invoking the action.
+ *
+ *
+ * <p>The byte buffer is read-only and valid only for the duration of the
+ * action. Callers must not retain references to it after the action returns.
+ *
+ * @param offset the byte offset within the input
+ * @param length the number of bytes requested
+ * @param action the action to perform with the byte buffer
+ * @return {@code true} if a buffer was available and the action was invoked
+ */
+ boolean withByteBufferSlice(long offset, long length, CheckedConsumer<ByteBuffer, IOException> action) throws IOException;
+}
diff --git a/libs/native/src/main/java/module-info.java b/libs/native/src/main/java/module-info.java
index dc74ecb5c6329..cb22d59fb54b4 100644
--- a/libs/native/src/main/java/module-info.java
+++ b/libs/native/src/main/java/module-info.java
@@ -20,8 +20,10 @@
to
org.elasticsearch.server,
org.elasticsearch.blobcache,
+ org.elasticsearch.searchablesnapshots,
org.elasticsearch.simdvec,
- org.elasticsearch.systemd;
+ org.elasticsearch.systemd,
+ org.elasticsearch.xpack.stateless;
uses NativeLibraryProvider;
diff --git a/libs/simdvec/src/main/java/module-info.java b/libs/simdvec/src/main/java/module-info.java
index 8cf2bac6a9504..deafe7f482bdd 100644
--- a/libs/simdvec/src/main/java/module-info.java
+++ b/libs/simdvec/src/main/java/module-info.java
@@ -26,9 +26,10 @@
* at runtime is selected by the multi-release classloader.
*/
module org.elasticsearch.simdvec {
+ requires org.elasticsearch.base;
+ requires org.elasticsearch.logging;
requires org.elasticsearch.nativeaccess;
requires org.apache.lucene.core;
- requires org.elasticsearch.logging;
exports org.elasticsearch.simdvec to org.elasticsearch.server;
}
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/MemorySegmentAccessInputAccess.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/MemorySegmentAccessInputAccess.java
index 85618edc8da42..29e11725e2e2b 100644
--- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/MemorySegmentAccessInputAccess.java
+++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/MemorySegmentAccessInputAccess.java
@@ -15,8 +15,10 @@
import java.util.Objects;
public interface MemorySegmentAccessInputAccess {
+ /** Returns the underlying {@link MemorySegmentAccessInput}, or {@code null} if not available. */
MemorySegmentAccessInput get();
+ /** Unwraps to the underlying {@link MemorySegmentAccessInput} if available, otherwise returns the input unchanged. */
static IndexInput unwrap(IndexInput input) {
MemorySegmentAccessInput memorySeg = input instanceof MemorySegmentAccessInputAccess msaia ? msaia.get() : null;
return Objects.requireNonNullElse((IndexInput) memorySeg, input);
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
index a4acb878acdee..80c1d4fb7eae7 100644
--- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
+++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
@@ -33,6 +33,14 @@ public static ESVectorizationProvider getInstance() {
/** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */
public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension, int bulkSize) throws IOException;
+ /**
+ * Create a new {@link ESNextOSQVectorsScorer} for the given {@link IndexInput}.
+ * The input should be unwrapped before calling this method. If the input is
+ * still a {@code FilterIndexInput} that does not implement
+ * {@code MemorySegmentAccessInput} or {@code DirectAccessInput}, an
+ * {@link IllegalArgumentException} is thrown. Non-wrapper inputs (e.g.
+ * {@code ByteBuffersIndexInput}) are accepted and use a heap-copy fallback.
+ */
public abstract ESNextOSQVectorsScorer newESNextOSQVectorsScorer(
IndexInput input,
byte queryBits,
@@ -42,7 +50,10 @@ public abstract ESNextOSQVectorsScorer newESNextOSQVectorsScorer(
int bulkSize
) throws IOException;
- /** Create a new {@link ES92Int7VectorsScorer} for the given {@link IndexInput}. */
+ /**
+ * Create a new {@link ES92Int7VectorsScorer} for the given {@link IndexInput}.
+ * See {@link #newESNextOSQVectorsScorer} for input type requirements.
+ */
public abstract ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension, int bulkSize) throws IOException;
// visible for tests
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/IndexInputUtils.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/IndexInputUtils.java
new file mode 100644
index 0000000000000..41985352a1af0
--- /dev/null
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/IndexInputUtils.java
@@ -0,0 +1,114 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.simdvec.internal;
+
+import org.apache.lucene.store.FilterIndexInput;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.MemorySegmentAccessInput;
+import org.elasticsearch.core.CheckedFunction;
+import org.elasticsearch.core.DirectAccessInput;
+
+import java.io.IOException;
+import java.lang.foreign.MemorySegment;
+import java.util.function.IntFunction;
+
+/**
+ * Utility for obtaining a {@link MemorySegment} view of data in an
+ * {@link IndexInput} and passing it to a caller-supplied action. The
+ * segment may come from a {@link MemorySegmentAccessInput} (mmap),
+ * a direct {@link java.nio.ByteBuffer} view (e.g. blob-cache), or a
+ * heap copy as a last resort.
+ *
+ * All resource management (ref-counting, buffer release) is handled
+ * internally — callers never see a closeable resource.
+ */
+public final class IndexInputUtils {
+
+ private IndexInputUtils() {}
+
+ /**
+ * Obtains a memory segment for the next {@code length} bytes of the
+ * index input, passes it to {@code action}, and returns the result.
+ * The position of the index input is advanced by {@code length}.
+ *
+ *
+ * <p>This method first tries to obtain a slice via
+ * {@link MemorySegmentAccessInput#segmentSliceOrNull}. If that
+ * returns {@code null}, it tries a direct {@link java.nio.ByteBuffer}
+ * view via {@link DirectAccessInput}. As a last resort it copies the
+ * data onto the heap using a byte array obtained from
+ * {@code scratchSupplier}.
+ *
+ *
+ * <p>The memory segment passed to {@code action} is valid only for
+ * the duration of the call. Callers must not retain references to it.
+ *
+ * @param in the index input positioned at the data to read
+ * @param length the number of bytes to read
+ * @param scratchSupplier supplies a byte array of at least the requested
+ * length, used only on the heap-copy fallback path
+ * @param action the function to apply to the memory segment
+ * @return the result of applying {@code action}
+ */
+ public static <R> R withSlice(
+ IndexInput in,
+ long length,
+ IntFunction<byte[]> scratchSupplier,
+ CheckedFunction<MemorySegment, R, IOException> action
+ ) throws IOException {
+ checkInputType(in);
+ if (in instanceof MemorySegmentAccessInput msai) {
+ long offset = in.getFilePointer();
+ MemorySegment slice = msai.segmentSliceOrNull(offset, length);
+ if (slice != null) {
+ in.skipBytes(length);
+ return action.apply(slice);
+ }
+ }
+ if (in instanceof DirectAccessInput dai) {
+ long offset = in.getFilePointer();
+ @SuppressWarnings("unchecked")
+ R[] result = (R[]) new Object[1];
+ boolean available = dai.withByteBufferSlice(offset, length, bb -> {
+ in.skipBytes(length);
+ result[0] = action.apply(MemorySegment.ofBuffer(bb));
+ });
+ if (available) {
+ return result[0];
+ }
+ }
+ return action.apply(copyOnHeap(in, Math.toIntExact(length), scratchSupplier));
+ }
+
+ /**
+ * Checks that a {@link FilterIndexInput} wrapper also implements
+ * {@link MemorySegmentAccessInput} or {@link DirectAccessInput},
+ * so that zero-copy access is preserved through the wrapper chain.
+ */
+ public static void checkInputType(IndexInput in) {
+ if (in instanceof FilterIndexInput && (in instanceof MemorySegmentAccessInput || in instanceof DirectAccessInput) == false) {
+ throw new IllegalArgumentException(
+ "IndexInput is a FilterIndexInput ("
+ + in.getClass().getName()
+ + ") that does not implement MemorySegmentAccessInput or DirectAccessInput. "
+ + "Ensure the wrapper implements DirectAccessInput or is unwrapped before constructing the scorer."
+ );
+ }
+ }
+
+ /**
+ * Reads the given number of bytes from the current position of the
+ * given IndexInput into a heap-backed memory segment. The returned
+ * segment is sliced to exactly {@code bytesToRead} bytes, even if
+ * the underlying array is larger.
+ */
+ private static MemorySegment copyOnHeap(IndexInput in, int bytesToRead, IntFunction<byte[]> scratchSupplier) throws IOException {
+ byte[] buf = scratchSupplier.apply(bytesToRead);
+ in.readBytes(buf, 0, bytesToRead);
+ return MemorySegment.ofArray(buf).asSlice(0, bytesToRead);
+ }
+}
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java
index 117aad9a85290..c2a009473ce83 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java
@@ -12,13 +12,12 @@
import org.apache.lucene.store.IndexInput;
import java.io.IOException;
-import java.lang.foreign.MemorySegment;
/** Panamized scorer for 7-bit quantized vectors stored as an {@link IndexInput}. **/
public final class MemorySegmentES92Int7VectorsScorer extends MemorySegmentES92PanamaInt7VectorsScorer {
- public MemorySegmentES92Int7VectorsScorer(IndexInput in, int dimensions, int bulkSize, MemorySegment memorySegment) {
- super(in, dimensions, bulkSize, memorySegment);
+ public MemorySegmentES92Int7VectorsScorer(IndexInput in, int dimensions, int bulkSize) {
+ super(in, dimensions, bulkSize);
}
@Override
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92PanamaInt7VectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92PanamaInt7VectorsScorer.java
index 535c6a3b8eb89..51ce3b9d4a01c 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92PanamaInt7VectorsScorer.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92PanamaInt7VectorsScorer.java
@@ -24,6 +24,7 @@
import java.io.IOException;
import java.lang.foreign.MemorySegment;
+import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;
import static java.nio.ByteOrder.LITTLE_ENDIAN;
@@ -58,44 +59,55 @@ abstract class MemorySegmentES92PanamaInt7VectorsScorer extends ES92Int7VectorsS
INT_SPECIES = VectorSpecies.of(int.class, VectorShape.forBitSize(VECTOR_BITSIZE));
}
- protected final MemorySegment memorySegment;
+ private byte[] scratch;
- protected MemorySegmentES92PanamaInt7VectorsScorer(IndexInput in, int dimensions, int bulkSize, MemorySegment memorySegment) {
+ protected MemorySegmentES92PanamaInt7VectorsScorer(IndexInput in, int dimensions, int bulkSize) {
super(in, dimensions, bulkSize);
- this.memorySegment = memorySegment;
+ IndexInputUtils.checkInputType(in);
+ }
+
+ protected byte[] getScratch(int len) {
+ if (scratch == null || scratch.length < len) {
+ scratch = new byte[len];
+ }
+ return scratch;
}
protected long panamaInt7DotProduct(byte[] q) throws IOException {
assert dimensions == q.length;
+ return IndexInputUtils.withSlice(in, dimensions, this::getScratch, segment -> panamaInt7DotProductImpl(q, segment, dimensions));
+ }
+
+ private static long panamaInt7DotProductImpl(byte[] q, MemorySegment segment, int dimensions) {
int i = 0;
int res = 0;
// only vectorize if we'll at least enter the loop a single time
if (dimensions >= 16) {
// compute vectorized dot product consistent with VPDPBUSD instruction
if (VECTOR_BITSIZE >= 512) {
- i += BYTE_SPECIES_128.loopBound(dimensions);
- res += dotProductBody512(q, i);
+ int limit = BYTE_SPECIES_128.loopBound(dimensions);
+ res += dotProductBody512Impl(q, segment, 0, limit);
+ i = limit;
} else if (VECTOR_BITSIZE == 256) {
- i += BYTE_SPECIES_64.loopBound(dimensions);
- res += dotProductBody256(q, i);
+ int limit = BYTE_SPECIES_64.loopBound(dimensions);
+ res += dotProductBody256Impl(q, segment, 0, limit);
+ i = limit;
} else {
// tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
- i += BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length());
- res += dotProductBody128(q, i);
+ int limit = BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length());
+ res += dotProductBody128Impl(q, segment, 0, limit);
+ i = limit;
}
- // scalar tail
- while (i < dimensions) {
- res += in.readByte() * q[i++];
- }
- return res;
- } else {
- return super.int7DotProduct(q);
}
+ // scalar tail
+ for (; i < dimensions; i++) {
+ res += segment.get(ValueLayout.JAVA_BYTE, i) * q[i];
+ }
+ return res;
}
- private int dotProductBody512(byte[] q, int limit) throws IOException {
+ private static int dotProductBody512Impl(byte[] q, MemorySegment memorySegment, long offset, int limit) {
IntVector acc = IntVector.zero(INT_SPECIES_512);
- long offset = in.getFilePointer();
for (int i = 0; i < limit; i += BYTE_SPECIES_128.length()) {
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i);
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN);
@@ -109,15 +121,11 @@ private int dotProductBody512(byte[] q, int limit) throws IOException {
Vector<Integer> prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0);
acc = acc.add(prod32);
}
-
- in.seek(offset + limit); // advance the input stream
- // reduce
return acc.reduceLanes(ADD);
}
- private int dotProductBody256(byte[] q, int limit) throws IOException {
+ private static int dotProductBody256Impl(byte[] q, MemorySegment memorySegment, long offset, int limit) {
IntVector acc = IntVector.zero(INT_SPECIES_256);
- long offset = in.getFilePointer();
for (int i = 0; i < limit; i += BYTE_SPECIES_64.length()) {
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
@@ -127,14 +135,11 @@ private int dotProductBody256(byte[] q, int limit) throws IOException {
Vector<Integer> vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0);
acc = acc.add(va32.mul(vb32));
}
- in.seek(offset + limit);
- // reduce
return acc.reduceLanes(ADD);
}
- private int dotProductBody128(byte[] q, int limit) throws IOException {
+ private static int dotProductBody128Impl(byte[] q, MemorySegment memorySegment, long offset, int limit) {
IntVector acc = IntVector.zero(IntVector.SPECIES_128);
- long offset = in.getFilePointer();
// 4 bytes at a time (re-loading half the vector each time!)
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
// load 8 bytes
@@ -149,109 +154,72 @@ private int dotProductBody128(byte[] q, int limit) throws IOException {
// 32-bit add
acc = acc.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0));
}
- in.seek(offset + limit);
- // reduce
return acc.reduceLanes(ADD);
}
protected void panamaInt7DotProductBulk(byte[] q, int count, float[] scores) throws IOException {
assert dimensions == q.length;
+ IndexInputUtils.withSlice(in, (long) dimensions * count, this::getScratch, segment -> {
+ panamaInt7DotProductBulkImpl(q, segment, dimensions, count, scores);
+ return null;
+ });
+ }
+
+ private static void panamaInt7DotProductBulkImpl(byte[] q, MemorySegment memorySegment, int dimensions, int count, float[] scores) {
// only vectorize if we'll at least enter the loop a single time
if (dimensions >= 16) {
// compute vectorized dot product consistent with VPDPBUSD instruction
if (VECTOR_BITSIZE >= 512) {
- dotProductBody512Bulk(q, count, scores);
+ dotProductBulkVectorized512(q, memorySegment, dimensions, count, scores);
} else if (VECTOR_BITSIZE == 256) {
- dotProductBody256Bulk(q, count, scores);
+ dotProductBulkVectorized256(q, memorySegment, dimensions, count, scores);
} else {
// tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
- dotProductBody128Bulk(q, count, scores);
+ dotProductBulkVectorized128(q, memorySegment, dimensions, count, scores);
}
} else {
- super.int7DotProductBulk(q, count, scores);
+ for (int iter = 0; iter < count; iter++) {
+ long base = (long) iter * dimensions;
+ long res = 0;
+ for (int i = 0; i < dimensions; i++) {
+ res += memorySegment.get(ValueLayout.JAVA_BYTE, base + i) * q[i];
+ }
+ scores[iter] = res;
+ }
}
}
- private void dotProductBody512Bulk(byte[] q, int count, float[] scores) throws IOException {
+ private static void dotProductBulkVectorized512(byte[] q, MemorySegment memorySegment, int dimensions, int count, float[] scores) {
int limit = BYTE_SPECIES_128.loopBound(dimensions);
for (int iter = 0; iter < count; iter++) {
- IntVector acc = IntVector.zero(INT_SPECIES_512);
- long offset = in.getFilePointer();
- int i = 0;
- for (; i < limit; i += BYTE_SPECIES_128.length()) {
- ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i);
- ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN);
-
- // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
- Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0);
- Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0);
- Vector<Short> prod16 = va16.mul(vb16);
-
- // 32-bit add
- Vector<Integer> prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0);
- acc = acc.add(prod32);
- }
-
- in.seek(offset + limit); // advance the input stream
- // reduce
- long res = acc.reduceLanes(ADD);
- for (; i < dimensions; i++) {
- res += in.readByte() * q[i];
+ long base = (long) iter * dimensions;
+ long res = dotProductBody512Impl(q, memorySegment, base, limit);
+ for (int i = limit; i < dimensions; i++) {
+ res += memorySegment.get(ValueLayout.JAVA_BYTE, base + i) * q[i];
}
scores[iter] = res;
}
}
- private void dotProductBody256Bulk(byte[] q, int count, float[] scores) throws IOException {
+ private static void dotProductBulkVectorized256(byte[] q, MemorySegment memorySegment, int dimensions, int count, float[] scores) {
int limit = BYTE_SPECIES_128.loopBound(dimensions);
for (int iter = 0; iter < count; iter++) {
- IntVector acc = IntVector.zero(INT_SPECIES_256);
- long offset = in.getFilePointer();
- int i = 0;
- for (; i < limit; i += BYTE_SPECIES_64.length()) {
- ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
- ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
-
- // 32-bit multiply and add into accumulator
- Vector<Integer> va32 = va8.convertShape(B2I, INT_SPECIES_256, 0);
- Vector<Integer> vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0);
- acc = acc.add(va32.mul(vb32));
- }
- in.seek(offset + limit);
- // reduce
- long res = acc.reduceLanes(ADD);
- for (; i < dimensions; i++) {
- res += in.readByte() * q[i];
+ long base = (long) iter * dimensions;
+ long res = dotProductBody256Impl(q, memorySegment, base, limit);
+ for (int i = limit; i < dimensions; i++) {
+ res += memorySegment.get(ValueLayout.JAVA_BYTE, base + i) * q[i];
}
scores[iter] = res;
}
}
- private void dotProductBody128Bulk(byte[] q, int count, float[] scores) throws IOException {
+ private static void dotProductBulkVectorized128(byte[] q, MemorySegment memorySegment, int dimensions, int count, float[] scores) {
int limit = BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length());
for (int iter = 0; iter < count; iter++) {
- IntVector acc = IntVector.zero(IntVector.SPECIES_128);
- long offset = in.getFilePointer();
- // 4 bytes at a time (re-loading half the vector each time!)
- int i = 0;
- for (; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
- // load 8 bytes
- ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
- ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
-
- // process first "half" only: 16-bit multiply
- Vector<Short> va16 = va8.convert(B2S, 0);
- Vector<Short> vb16 = vb8.convert(B2S, 0);
- Vector<Short> prod16 = va16.mul(vb16);
-
- // 32-bit add
- acc = acc.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0));
- }
- in.seek(offset + limit);
- // reduce
- long res = acc.reduceLanes(ADD);
- for (; i < dimensions; i++) {
- res += in.readByte() * q[i];
+ long base = (long) iter * dimensions;
+ long res = dotProductBody128Impl(q, memorySegment, base, limit);
+ for (int i = limit; i < dimensions; i++) {
+ res += memorySegment.get(ValueLayout.JAVA_BYTE, base + i) * q[i];
}
scores[iter] = res;
}
@@ -267,30 +235,55 @@ protected void applyCorrectionsBulk(
float[] scores,
int bulkSize
) throws IOException {
+ IndexInputUtils.withSlice(in, 16L * bulkSize, this::getScratch, memorySegment -> {
+ applyCorrectionsBulkImpl(
+ memorySegment,
+ queryAdditionalCorrection,
+ similarityFunction,
+ centroidDp,
+ scores,
+ bulkSize,
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum,
+ dimensions
+ );
+ return null;
+ });
+ }
+
+ private static void applyCorrectionsBulkImpl(
+ MemorySegment memorySegment,
+ float queryAdditionalCorrection,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float[] scores,
+ int bulkSize,
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum,
+ int dimensions
+ ) {
int limit = FLOAT_SPECIES.loopBound(bulkSize);
int i = 0;
- long offset = in.getFilePointer();
float ay = queryLowerInterval;
float ly = (queryUpperInterval - ay) * SEVEN_BIT_SCALE;
float y1 = queryComponentSum;
for (; i < limit; i += FLOAT_SPECIES.length()) {
- var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
- var lx = FloatVector.fromMemorySegment(
- FLOAT_SPECIES,
- memorySegment,
- offset + 4 * bulkSize + i * Float.BYTES,
- ByteOrder.LITTLE_ENDIAN
- ).sub(ax).mul(SEVEN_BIT_SCALE);
+ var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
+ var lx = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, 4 * bulkSize + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN)
+ .sub(ax)
+ .mul(SEVEN_BIT_SCALE);
var targetComponentSums = IntVector.fromMemorySegment(
INT_SPECIES,
memorySegment,
- offset + 8 * bulkSize + i * Integer.BYTES,
+ 8 * bulkSize + i * Integer.BYTES,
ByteOrder.LITTLE_ENDIAN
).convert(VectorOperators.I2F, 0);
var additionalCorrections = FloatVector.fromMemorySegment(
FLOAT_SPECIES,
memorySegment,
- offset + 12 * bulkSize + i * Float.BYTES,
+ 12 * bulkSize + i * Float.BYTES,
ByteOrder.LITTLE_ENDIAN
);
var qcDist = FloatVector.fromArray(FLOAT_SPECIES, scores, i);
@@ -328,17 +321,11 @@ protected void applyCorrectionsBulk(
var floatVectorMask = FLOAT_SPECIES.indexInRange(i, bulkSize);
var intVectorMask = INT_SPECIES.indexInRange(i, bulkSize);
- var ax = FloatVector.fromMemorySegment(
- FLOAT_SPECIES,
- memorySegment,
- offset + (long) i * Float.BYTES,
- ByteOrder.LITTLE_ENDIAN,
- floatVectorMask
- );
+ var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, (long) i * Float.BYTES, LITTLE_ENDIAN, floatVectorMask);
var upper = FloatVector.fromMemorySegment(
FLOAT_SPECIES,
memorySegment,
- offset + 4L * bulkSize + (long) i * Float.BYTES,
+ 4L * bulkSize + (long) i * Float.BYTES,
ByteOrder.LITTLE_ENDIAN,
floatVectorMask
);
@@ -347,7 +334,7 @@ protected void applyCorrectionsBulk(
var targetComponentSums = IntVector.fromMemorySegment(
INT_SPECIES,
memorySegment,
- offset + 8L * bulkSize + (long) i * Integer.BYTES,
+ 8L * bulkSize + (long) i * Integer.BYTES,
ByteOrder.LITTLE_ENDIAN,
intVectorMask
).convert(VectorOperators.I2F, 0);
@@ -355,7 +342,7 @@ protected void applyCorrectionsBulk(
var additionalCorrections = FloatVector.fromMemorySegment(
FLOAT_SPECIES,
memorySegment,
- offset + 12L * bulkSize + (long) i * Float.BYTES,
+ 12L * bulkSize + (long) i * Float.BYTES,
ByteOrder.LITTLE_ENDIAN,
floatVectorMask
);
@@ -385,6 +372,6 @@ protected void applyCorrectionsBulk(
}
}
}
- in.seek(offset + 16L * bulkSize);
}
+
}
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
index 8a010c23abe99..ac79f585f08c5 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
@@ -40,6 +40,14 @@ public static ESVectorizationProvider getInstance() {
/** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */
public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension, int bulkSize) throws IOException;
+ /**
+ * Create a new {@link ESNextOSQVectorsScorer} for the given {@link IndexInput}.
+ * The input should be unwrapped before calling this method. If the input is
+ * still a {@code FilterIndexInput} that does not implement
+ * {@code MemorySegmentAccessInput} or {@code DirectAccessInput}, an
+ * {@link IllegalArgumentException} is thrown. Non-wrapper inputs (e.g.
+ * {@code ByteBuffersIndexInput}) are accepted and use a heap-copy fallback.
+ */
public abstract ESNextOSQVectorsScorer newESNextOSQVectorsScorer(
IndexInput input,
byte queryBits,
@@ -49,7 +57,10 @@ public abstract ESNextOSQVectorsScorer newESNextOSQVectorsScorer(
int bulkSize
) throws IOException;
- /** Create a new {@link ES92Int7VectorsScorer} for the given {@link IndexInput}. */
+ /**
+ * Create a new {@link ES92Int7VectorsScorer} for the given {@link IndexInput}.
+ * See {@link #newESNextOSQVectorsScorer} for input type requirements.
+ */
public abstract ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension, int bulkSize) throws IOException;
// visible for tests
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSBitToInt4ESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSBitToInt4ESNextOSQVectorsScorer.java
index d8fa469c41c05..7d4124c4adf6a 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSBitToInt4ESNextOSQVectorsScorer.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSBitToInt4ESNextOSQVectorsScorer.java
@@ -18,6 +18,7 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.simdvec.internal.IndexInputUtils;
import java.io.IOException;
import java.lang.foreign.Arena;
@@ -33,8 +34,8 @@
/** Panamized scorer for quantized vectors stored as a {@link MemorySegment}. */
final class MSBitToInt4ESNextOSQVectorsScorer extends MemorySegmentESNextOSQVectorsScorer.MemorySegmentScorer {
- MSBitToInt4ESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, int bulkSize, MemorySegment memorySegment) {
- super(in, dimensions, dataLength, bulkSize, memorySegment);
+ MSBitToInt4ESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, int bulkSize) {
+ super(in, dimensions, dataLength, bulkSize);
}
@Override
@@ -56,9 +57,10 @@ public long quantizeScore(byte[] q) throws IOException {
}
private long nativeQuantizeScore(byte[] q) throws IOException {
- long offset = in.getFilePointer();
- var datasetMemorySegment = memorySegment.asSlice(offset, length);
+ return IndexInputUtils.withSlice(in, length, this::getScratch, segment -> nativeQuantizeScoreImpl(q, segment, length));
+ }
+ private static long nativeQuantizeScoreImpl(byte[] q, MemorySegment datasetMemorySegment, int length) {
final long qScore;
if (SUPPORTS_HEAP_SEGMENTS) {
var queryMemorySegment = MemorySegment.ofArray(q);
@@ -70,29 +72,31 @@ private long nativeQuantizeScore(byte[] q) throws IOException {
qScore = dotProductD1Q4(datasetMemorySegment, queryMemorySegment, length);
}
}
- in.skipBytes(length);
return qScore;
}
private long quantizeScore256(byte[] q) throws IOException {
+ return IndexInputUtils.withSlice(in, length, this::getScratch, segment -> quantizeScore256Impl(q, segment, length));
+ }
+
+ private static long quantizeScore256Impl(byte[] q, MemorySegment memorySegment, int length) {
long subRet0 = 0;
long subRet1 = 0;
long subRet2 = 0;
long subRet3 = 0;
int i = 0;
- long offset = in.getFilePointer();
if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) {
int limit = ByteVector.SPECIES_256.loopBound(length);
var sum0 = LongVector.zero(LONG_SPECIES_256);
var sum1 = LongVector.zero(LONG_SPECIES_256);
var sum2 = LongVector.zero(LONG_SPECIES_256);
var sum3 = LongVector.zero(LONG_SPECIES_256);
- for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) {
+ for (; i < limit; i += ByteVector.SPECIES_256.length()) {
var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs();
var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs();
var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs();
var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs();
- var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, i, ByteOrder.LITTLE_ENDIAN);
sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
@@ -110,12 +114,12 @@ private long quantizeScore256(byte[] q) throws IOException {
var sum2 = LongVector.zero(LONG_SPECIES_128);
var sum3 = LongVector.zero(LONG_SPECIES_128);
int limit = ByteVector.SPECIES_128.loopBound(length);
- for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) {
+ for (; i < limit; i += ByteVector.SPECIES_128.length()) {
var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs();
var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs();
var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs();
var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs();
- var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, i, ByteOrder.LITTLE_ENDIAN);
sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
@@ -127,23 +131,22 @@ private long quantizeScore256(byte[] q) throws IOException {
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
}
// process scalar tail
- in.seek(offset);
for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
- final long value = in.readLong();
+ final long value = memorySegment.get(LAYOUT_LE_LONG, i);
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
}
for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
- final int value = in.readInt();
+ final int value = memorySegment.get(LAYOUT_LE_INT, i);
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
}
for (; i < length; i++) {
- int dValue = in.readByte() & 0xFF;
+ final int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, i) & 0xFF;
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
@@ -153,20 +156,22 @@ private long quantizeScore256(byte[] q) throws IOException {
}
private long quantizeScore128(byte[] q) throws IOException {
+ return IndexInputUtils.withSlice(in, length, this::getScratch, segment -> quantizeScore128Impl(q, segment, length));
+ }
+
+ private static long quantizeScore128Impl(byte[] q, MemorySegment memorySegment, int length) {
long subRet0 = 0;
long subRet1 = 0;
long subRet2 = 0;
long subRet3 = 0;
int i = 0;
- long offset = in.getFilePointer();
-
var sum0 = IntVector.zero(INT_SPECIES_128);
var sum1 = IntVector.zero(INT_SPECIES_128);
var sum2 = IntVector.zero(INT_SPECIES_128);
var sum3 = IntVector.zero(INT_SPECIES_128);
int limit = ByteVector.SPECIES_128.loopBound(length);
- for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) {
- var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ for (; i < limit; i += ByteVector.SPECIES_128.length()) {
+ var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, i, ByteOrder.LITTLE_ENDIAN);
var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts();
var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts();
var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts();
@@ -181,23 +186,22 @@ private long quantizeScore128(byte[] q) throws IOException {
subRet2 += sum2.reduceLanes(VectorOperators.ADD);
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
// process scalar tail
- in.seek(offset);
for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
- final long value = in.readLong();
+ final long value = memorySegment.get(LAYOUT_LE_LONG, i);
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
}
for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
- final int value = in.readInt();
+ final int value = memorySegment.get(LAYOUT_LE_INT, i);
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
}
for (; i < length; i++) {
- int dValue = in.readByte() & 0xFF;
+ final int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, i) & 0xFF;
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
@@ -240,24 +244,29 @@ public boolean quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOE
}
private void nativeQuantizeScoreBulk(MemorySegment querySegment, int count, MemorySegment scoresSegment) throws IOException {
- long initialOffset = in.getFilePointer();
var datasetLengthInBytes = (long) length * count;
- MemorySegment datasetSegment = memorySegment.asSlice(initialOffset, datasetLengthInBytes);
-
- dotProductD1Q4Bulk(datasetSegment, querySegment, length, count, scoresSegment);
-
- in.skipBytes(datasetLengthInBytes);
+ IndexInputUtils.withSlice(in, datasetLengthInBytes, this::getScratch, datasetSegment -> {
+ dotProductD1Q4Bulk(datasetSegment, querySegment, length, count, scoresSegment);
+ return null;
+ });
}
private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IOException {
+ var datasetLengthInBytes = (long) length * count;
+ IndexInputUtils.withSlice(in, datasetLengthInBytes, this::getScratch, segment -> {
+ quantizeScore128BulkImpl(q, segment, length, count, scores);
+ return null;
+ });
+ }
+
+ private static void quantizeScore128BulkImpl(byte[] q, MemorySegment memorySegment, int length, int count, float[] scores) {
+ int offset = 0;
for (int iter = 0; iter < count; iter++) {
long subRet0 = 0;
long subRet1 = 0;
long subRet2 = 0;
long subRet3 = 0;
int i = 0;
- long offset = in.getFilePointer();
-
var sum0 = IntVector.zero(INT_SPECIES_128);
var sum1 = IntVector.zero(INT_SPECIES_128);
var sum2 = IntVector.zero(INT_SPECIES_128);
@@ -279,23 +288,22 @@ private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IO
subRet2 += sum2.reduceLanes(VectorOperators.ADD);
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
// process scalar tail
- in.seek(offset);
- for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
- final long value = in.readLong();
+ for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES, offset += Long.BYTES) {
+ final long value = memorySegment.get(LAYOUT_LE_LONG, offset);
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
}
- for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
- final int value = in.readInt();
+ for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES, offset += Integer.BYTES) {
+ final int value = memorySegment.get(LAYOUT_LE_INT, offset);
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
}
- for (; i < length; i++) {
- int dValue = in.readByte() & 0xFF;
+ for (; i < length; i++, offset++) {
+ final int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, offset) & 0xFF;
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
@@ -306,13 +314,21 @@ private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IO
}
private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IOException {
+ var datasetLengthInBytes = (long) length * count;
+ IndexInputUtils.withSlice(in, datasetLengthInBytes, this::getScratch, segment -> {
+ quantizeScore256BulkImpl(q, segment, length, count, scores);
+ return null;
+ });
+ }
+
+ private static void quantizeScore256BulkImpl(byte[] q, MemorySegment memorySegment, int length, int count, float[] scores) {
+ int offset = 0;
for (int iter = 0; iter < count; iter++) {
long subRet0 = 0;
long subRet1 = 0;
long subRet2 = 0;
long subRet3 = 0;
int i = 0;
- long offset = in.getFilePointer();
if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) {
int limit = ByteVector.SPECIES_256.loopBound(length);
var sum0 = LongVector.zero(LONG_SPECIES_256);
@@ -359,23 +375,22 @@ private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IO
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
}
// process scalar tail
- in.seek(offset);
- for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
- final long value = in.readLong();
+ for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES, offset += Long.BYTES) {
+ final long value = memorySegment.get(LAYOUT_LE_LONG, offset);
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
}
- for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
- final int value = in.readInt();
+ for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES, offset += Integer.BYTES) {
+ final int value = memorySegment.get(LAYOUT_LE_INT, offset);
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
}
- for (; i < length; i++) {
- int dValue = in.readByte() & 0xFF;
+ for (; i < length; i++, offset++) {
+ final int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, offset) & 0xFF;
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
@@ -476,24 +491,25 @@ private float nativeApplyCorrectionsBulk(
MemorySegment scoresSegment,
int bulkSize
) throws IOException {
- long offset = in.getFilePointer();
-
- final float maxScore = ScoreCorrections.nativeApplyCorrectionsBulk(
- similarityFunction,
- memorySegment.asSlice(offset),
- bulkSize,
- dimensions,
- queryLowerInterval,
- queryUpperInterval,
- queryComponentSum,
- queryAdditionalCorrection,
- FOUR_BIT_SCALE,
- ONE_BIT_SCALE,
- centroidDp,
- scoresSegment
+ return IndexInputUtils.withSlice(
+ in,
+ 16L * bulkSize,
+ this::getScratch,
+ seg -> ScoreCorrections.nativeApplyCorrectionsBulk(
+ similarityFunction,
+ seg,
+ bulkSize,
+ dimensions,
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum,
+ queryAdditionalCorrection,
+ FOUR_BIT_SCALE,
+ ONE_BIT_SCALE,
+ centroidDp,
+ scoresSegment
+ )
);
- in.seek(offset + 16L * bulkSize);
- return maxScore;
}
private float applyCorrections128Bulk(
@@ -506,31 +522,59 @@ private float applyCorrections128Bulk(
float[] scores,
int bulkSize
) throws IOException {
+ return IndexInputUtils.withSlice(
+ in,
+ 16L * bulkSize,
+ this::getScratch,
+ seg -> applyCorrections128BulkImpl(
+ seg,
+ queryAdditionalCorrection,
+ similarityFunction,
+ centroidDp,
+ scores,
+ bulkSize,
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum
+ )
+ );
+ }
+
+ private float applyCorrections128BulkImpl(
+ MemorySegment memorySegment,
+ float queryAdditionalCorrection,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float[] scores,
+ int bulkSize,
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum
+ ) {
int limit = FLOAT_SPECIES_128.loopBound(bulkSize);
int i = 0;
- long offset = in.getFilePointer();
float ay = queryLowerInterval;
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
float y1 = queryComponentSum;
float maxScore = Float.NEGATIVE_INFINITY;
for (; i < limit; i += FLOAT_SPECIES_128.length()) {
- var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
+ var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var lx = FloatVector.fromMemorySegment(
FLOAT_SPECIES_128,
memorySegment,
- offset + 4L * bulkSize + i * Float.BYTES,
+ 4L * bulkSize + i * Float.BYTES,
ByteOrder.LITTLE_ENDIAN
).sub(ax);
var targetComponentSums = IntVector.fromMemorySegment(
INT_SPECIES_128,
memorySegment,
- offset + 8L * bulkSize + i * Integer.BYTES,
+ 8L * bulkSize + i * Integer.BYTES,
ByteOrder.LITTLE_ENDIAN
).convert(VectorOperators.I2F, 0);
var additionalCorrections = FloatVector.fromMemorySegment(
FLOAT_SPECIES_128,
memorySegment,
- offset + 12L * bulkSize + i * Float.BYTES,
+ 12L * bulkSize + i * Float.BYTES,
ByteOrder.LITTLE_ENDIAN
);
var qcDist = FloatVector.fromArray(FLOAT_SPECIES_128, scores, i);
@@ -568,6 +612,7 @@ private float applyCorrections128Bulk(
}
if (limit < bulkSize) {
maxScore = applyCorrectionsIndividually(
+ memorySegment,
queryAdditionalCorrection,
similarityFunction,
centroidDp,
@@ -575,14 +620,13 @@ private float applyCorrections128Bulk(
scores,
bulkSize,
limit,
- offset,
+ i,
ay,
ly,
y1,
maxScore
);
}
- in.seek(offset + 16L * bulkSize);
return maxScore;
}
@@ -596,31 +640,59 @@ private float applyCorrections256Bulk(
float[] scores,
int bulkSize
) throws IOException {
+ return IndexInputUtils.withSlice(
+ in,
+ 16L * bulkSize,
+ this::getScratch,
+ memorySegment -> applyCorrections256BulkImpl(
+ memorySegment,
+ queryAdditionalCorrection,
+ similarityFunction,
+ centroidDp,
+ scores,
+ bulkSize,
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum
+ )
+ );
+ }
+
+ private float applyCorrections256BulkImpl(
+ MemorySegment memorySegment,
+ float queryAdditionalCorrection,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float[] scores,
+ int bulkSize,
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum
+ ) {
int limit = FLOAT_SPECIES_256.loopBound(bulkSize);
int i = 0;
- long offset = in.getFilePointer();
float ay = queryLowerInterval;
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
float y1 = queryComponentSum;
float maxScore = Float.NEGATIVE_INFINITY;
for (; i < limit; i += FLOAT_SPECIES_256.length()) {
- var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
+ var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var lx = FloatVector.fromMemorySegment(
FLOAT_SPECIES_256,
memorySegment,
- offset + 4L * bulkSize + i * Float.BYTES,
+ 4L * bulkSize + i * Float.BYTES,
ByteOrder.LITTLE_ENDIAN
).sub(ax);
var targetComponentSums = IntVector.fromMemorySegment(
INT_SPECIES_256,
memorySegment,
- offset + 8L * bulkSize + i * Integer.BYTES,
+ 8L * bulkSize + i * Integer.BYTES,
ByteOrder.LITTLE_ENDIAN
).convert(VectorOperators.I2F, 0);
var additionalCorrections = FloatVector.fromMemorySegment(
FLOAT_SPECIES_256,
memorySegment,
- offset + 12L * bulkSize + i * Float.BYTES,
+ 12L * bulkSize + i * Float.BYTES,
ByteOrder.LITTLE_ENDIAN
);
var qcDist = FloatVector.fromArray(FLOAT_SPECIES_256, scores, i);
@@ -658,6 +730,7 @@ private float applyCorrections256Bulk(
}
if (limit < bulkSize) {
maxScore = applyCorrectionsIndividually(
+ memorySegment,
queryAdditionalCorrection,
similarityFunction,
centroidDp,
@@ -665,14 +738,13 @@ private float applyCorrections256Bulk(
scores,
bulkSize,
limit,
- offset,
+ i,
ay,
ly,
y1,
maxScore
);
}
- in.seek(offset + 16L * bulkSize);
return maxScore;
}
}
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSD7Q7ESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSD7Q7ESNextOSQVectorsScorer.java
index 8ab4c52b1272e..604df54c78019 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSD7Q7ESNextOSQVectorsScorer.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSD7Q7ESNextOSQVectorsScorer.java
@@ -20,9 +20,9 @@ final class MSD7Q7ESNextOSQVectorsScorer extends MemorySegmentESNextOSQVectorsSc
private final MemorySegmentES92Int7VectorsScorer int7Scorer;
- MSD7Q7ESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, int bulkSize, MemorySegment memorySegment) {
- super(in, dimensions, dataLength, bulkSize, memorySegment);
- this.int7Scorer = new MemorySegmentES92Int7VectorsScorer(in, dimensions, bulkSize, memorySegment);
+ MSD7Q7ESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, int bulkSize) {
+ super(in, dimensions, dataLength, bulkSize);
+ this.int7Scorer = new MemorySegmentES92Int7VectorsScorer(in, dimensions, bulkSize);
}
@Override
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSDibitToInt4ESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSDibitToInt4ESNextOSQVectorsScorer.java
index 57e48f83f1661..ce7ab96c936ee 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSDibitToInt4ESNextOSQVectorsScorer.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSDibitToInt4ESNextOSQVectorsScorer.java
@@ -18,6 +18,7 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.simdvec.internal.IndexInputUtils;
import java.io.IOException;
import java.lang.foreign.Arena;
@@ -33,8 +34,8 @@
/** Panamized scorer for quantized vectors stored as a {@link MemorySegment}. */
final class MSDibitToInt4ESNextOSQVectorsScorer extends MemorySegmentESNextOSQVectorsScorer.MemorySegmentScorer {
- MSDibitToInt4ESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, int bulkSize, MemorySegment memorySegment) {
- super(in, dimensions, dataLength, bulkSize, memorySegment);
+ MSDibitToInt4ESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, int bulkSize) {
+ super(in, dimensions, dataLength, bulkSize);
}
@Override
@@ -56,9 +57,10 @@ public long quantizeScore(byte[] q) throws IOException {
}
private long nativeQuantizeScore(byte[] q) throws IOException {
- long offset = in.getFilePointer();
- var datasetMemorySegment = memorySegment.asSlice(offset, length);
+ return IndexInputUtils.withSlice(in, length, this::getScratch, segment -> nativeQuantizeScoreImpl(q, segment, length));
+ }
+ private static long nativeQuantizeScoreImpl(byte[] q, MemorySegment datasetMemorySegment, int length) {
final long qScore;
if (SUPPORTS_HEAP_SEGMENTS) {
var queryMemorySegment = MemorySegment.ofArray(q);
@@ -70,7 +72,6 @@ private long nativeQuantizeScore(byte[] q) throws IOException {
qScore = dotProductD2Q4(datasetMemorySegment, queryMemorySegment, length);
}
}
- in.skipBytes(length);
return qScore;
}
@@ -87,25 +88,28 @@ private long quantizeScore128DibitToInt4(byte[] q) throws IOException {
}
private long quantizeScore256(byte[] q) throws IOException {
+ int size = length / 2;
+ return IndexInputUtils.withSlice(in, size, this::getScratch, segment -> quantizeScore256Impl(q, segment, size));
+ }
+
+ private static long quantizeScore256Impl(byte[] q, MemorySegment memorySegment, int size) {
long subRet0 = 0;
long subRet1 = 0;
long subRet2 = 0;
long subRet3 = 0;
int i = 0;
- long offset = in.getFilePointer();
- int size = length / 2;
if (size >= ByteVector.SPECIES_256.vectorByteSize() * 2) {
int limit = ByteVector.SPECIES_256.loopBound(size);
var sum0 = LongVector.zero(LONG_SPECIES_256);
var sum1 = LongVector.zero(LONG_SPECIES_256);
var sum2 = LongVector.zero(LONG_SPECIES_256);
var sum3 = LongVector.zero(LONG_SPECIES_256);
- for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) {
+ for (; i < limit; i += ByteVector.SPECIES_256.length()) {
var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs();
var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + size).reinterpretAsLongs();
var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + size * 2).reinterpretAsLongs();
var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + size * 3).reinterpretAsLongs();
- var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, i, ByteOrder.LITTLE_ENDIAN);
sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
@@ -123,12 +127,12 @@ private long quantizeScore256(byte[] q) throws IOException {
var sum2 = LongVector.zero(LONG_SPECIES_128);
var sum3 = LongVector.zero(LONG_SPECIES_128);
int limit = ByteVector.SPECIES_128.loopBound(size);
- for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) {
+ for (; i < limit; i += ByteVector.SPECIES_128.length()) {
var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs();
var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size).reinterpretAsLongs();
var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size * 2).reinterpretAsLongs();
var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size * 3).reinterpretAsLongs();
- var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, i, ByteOrder.LITTLE_ENDIAN);
sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
@@ -140,23 +144,22 @@ private long quantizeScore256(byte[] q) throws IOException {
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
}
// process scalar tail
- in.seek(offset);
for (final int upperBound = size & -Long.BYTES; i < upperBound; i += Long.BYTES) {
- final long value = in.readLong();
+ final long value = memorySegment.get(LAYOUT_LE_LONG, i);
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + size) & value);
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * size) & value);
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * size) & value);
}
for (final int upperBound = size & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
- final int value = in.readInt();
+ final int value = memorySegment.get(LAYOUT_LE_INT, i);
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + size) & value);
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * size) & value);
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * size) & value);
}
for (; i < size; i++) {
- int dValue = in.readByte() & 0xFF;
+ int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, i) & 0xFF;
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
subRet1 += Integer.bitCount((q[i + size] & dValue) & 0xFF);
subRet2 += Integer.bitCount((q[i + 2 * size] & dValue) & 0xFF);
@@ -166,21 +169,24 @@ private long quantizeScore256(byte[] q) throws IOException {
}
private long quantizeScore128(byte[] q) throws IOException {
+ int size = length / 2;
+ return IndexInputUtils.withSlice(in, size, this::getScratch, segment -> quantizeScore128Impl(q, segment, size));
+ }
+
+ private static long quantizeScore128Impl(byte[] q, MemorySegment memorySegment, int size) {
long subRet0 = 0;
long subRet1 = 0;
long subRet2 = 0;
long subRet3 = 0;
int i = 0;
- long offset = in.getFilePointer();
var sum0 = IntVector.zero(INT_SPECIES_128);
var sum1 = IntVector.zero(INT_SPECIES_128);
var sum2 = IntVector.zero(INT_SPECIES_128);
var sum3 = IntVector.zero(INT_SPECIES_128);
- int size = length / 2;
int limit = ByteVector.SPECIES_128.loopBound(size);
- for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) {
- var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ for (; i < limit; i += ByteVector.SPECIES_128.length()) {
+ var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, i, ByteOrder.LITTLE_ENDIAN);
var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts();
var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size).reinterpretAsInts();
var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size * 2).reinterpretAsInts();
@@ -195,23 +201,22 @@ private long quantizeScore128(byte[] q) throws IOException {
subRet2 += sum2.reduceLanes(VectorOperators.ADD);
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
// process scalar tail
- in.seek(offset);
for (final int upperBound = size & -Long.BYTES; i < upperBound; i += Long.BYTES) {
- final long value = in.readLong();
+ final long value = memorySegment.get(LAYOUT_LE_LONG, i);
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + size) & value);
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * size) & value);
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * size) & value);
}
for (final int upperBound = size & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
- final int value = in.readInt();
+ final int value = memorySegment.get(LAYOUT_LE_INT, i);
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + size) & value);
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * size) & value);
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * size) & value);
}
for (; i < size; i++) {
- int dValue = in.readByte() & 0xFF;
+ int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, i) & 0xFF;
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
subRet1 += Integer.bitCount((q[i + size] & dValue) & 0xFF);
subRet2 += Integer.bitCount((q[i + 2 * size] & dValue) & 0xFF);
@@ -254,13 +259,11 @@ public boolean quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOE
}
private void nativeQuantizeScoreBulk(MemorySegment queryMemorySegment, int count, MemorySegment scoresSegment) throws IOException {
- long initialOffset = in.getFilePointer();
var datasetLengthInBytes = (long) length * count;
- MemorySegment datasetSegment = memorySegment.asSlice(initialOffset, datasetLengthInBytes);
-
- dotProductD2Q4Bulk(datasetSegment, queryMemorySegment, length, count, scoresSegment);
-
- in.skipBytes(datasetLengthInBytes);
+ IndexInputUtils.withSlice(in, datasetLengthInBytes, this::getScratch, datasetSegment -> {
+ dotProductD2Q4Bulk(datasetSegment, queryMemorySegment, length, count, scoresSegment);
+ return null;
+ });
}
private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IOException {
@@ -366,24 +369,25 @@ private float nativeApplyCorrectionsBulk(
MemorySegment scoresSegment,
int bulkSize
) throws IOException {
- long offset = in.getFilePointer();
-
- final float maxScore = ScoreCorrections.nativeApplyCorrectionsBulk(
- similarityFunction,
- memorySegment.asSlice(offset),
- bulkSize,
- dimensions,
- queryLowerInterval,
- queryUpperInterval,
- queryComponentSum,
- queryAdditionalCorrection,
- FOUR_BIT_SCALE,
- TWO_BIT_SCALE,
- centroidDp,
- scoresSegment
+ return IndexInputUtils.withSlice(
+ in,
+ 16L * bulkSize,
+ this::getScratch,
+ seg -> ScoreCorrections.nativeApplyCorrectionsBulk(
+ similarityFunction,
+ seg,
+ bulkSize,
+ dimensions,
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum,
+ queryAdditionalCorrection,
+ FOUR_BIT_SCALE,
+ TWO_BIT_SCALE,
+ centroidDp,
+ scoresSegment
+ )
);
- in.seek(offset + 16L * bulkSize);
- return maxScore;
}
private float applyCorrections128Bulk(
@@ -396,31 +400,59 @@ private float applyCorrections128Bulk(
float[] scores,
int bulkSize
) throws IOException {
+ return IndexInputUtils.withSlice(
+ in,
+ 16L * bulkSize,
+ this::getScratch,
+ memorySegment -> applyCorrections128BulkImpl(
+ memorySegment,
+ queryAdditionalCorrection,
+ similarityFunction,
+ centroidDp,
+ scores,
+ bulkSize,
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum
+ )
+ );
+ }
+
+ private float applyCorrections128BulkImpl(
+ MemorySegment memorySegment,
+ float queryAdditionalCorrection,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float[] scores,
+ int bulkSize,
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum
+ ) {
int limit = FLOAT_SPECIES_128.loopBound(bulkSize);
int i = 0;
- long offset = in.getFilePointer();
float ay = queryLowerInterval;
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
float y1 = queryComponentSum;
float maxScore = Float.NEGATIVE_INFINITY;
for (; i < limit; i += FLOAT_SPECIES_128.length()) {
- var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
+ var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var lx = FloatVector.fromMemorySegment(
FLOAT_SPECIES_128,
memorySegment,
- offset + 4L * bulkSize + i * Float.BYTES,
+ 4L * bulkSize + i * Float.BYTES,
ByteOrder.LITTLE_ENDIAN
).sub(ax).mul(TWO_BIT_SCALE);
var targetComponentSums = IntVector.fromMemorySegment(
INT_SPECIES_128,
memorySegment,
- offset + 8L * bulkSize + i * Integer.BYTES,
+ 8L * bulkSize + i * Integer.BYTES,
ByteOrder.LITTLE_ENDIAN
).convert(VectorOperators.I2F, 0);
var additionalCorrections = FloatVector.fromMemorySegment(
FLOAT_SPECIES_128,
memorySegment,
- offset + 12L * bulkSize + i * Float.BYTES,
+ 12L * bulkSize + i * Float.BYTES,
ByteOrder.LITTLE_ENDIAN
);
var qcDist = FloatVector.fromArray(FLOAT_SPECIES_128, scores, i);
@@ -457,8 +489,8 @@ private float applyCorrections128Bulk(
}
}
if (limit < bulkSize) {
- // missing vectors to score
maxScore = applyCorrectionsIndividually(
+ memorySegment,
queryAdditionalCorrection,
similarityFunction,
centroidDp,
@@ -466,14 +498,13 @@ private float applyCorrections128Bulk(
scores,
bulkSize,
limit,
- offset,
+ 0,
ay,
ly,
y1,
maxScore
);
}
- in.seek(offset + 16L * bulkSize);
return maxScore;
}
@@ -487,31 +518,59 @@ private float applyCorrections256Bulk(
float[] scores,
int bulkSize
) throws IOException {
+ return IndexInputUtils.withSlice(
+ in,
+ 16L * bulkSize,
+ this::getScratch,
+ memorySegment -> applyCorrections256BulkImpl(
+ memorySegment,
+ queryAdditionalCorrection,
+ similarityFunction,
+ centroidDp,
+ scores,
+ bulkSize,
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum
+ )
+ );
+ }
+
+ private float applyCorrections256BulkImpl(
+ MemorySegment memorySegment,
+ float queryAdditionalCorrection,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float[] scores,
+ int bulkSize,
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum
+ ) {
int limit = FLOAT_SPECIES_256.loopBound(bulkSize);
int i = 0;
- long offset = in.getFilePointer();
float ay = queryLowerInterval;
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
float y1 = queryComponentSum;
float maxScore = Float.NEGATIVE_INFINITY;
for (; i < limit; i += FLOAT_SPECIES_256.length()) {
- var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
+ var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var lx = FloatVector.fromMemorySegment(
FLOAT_SPECIES_256,
memorySegment,
- offset + 4L * bulkSize + i * Float.BYTES,
+ 4L * bulkSize + i * Float.BYTES,
ByteOrder.LITTLE_ENDIAN
).sub(ax).mul(TWO_BIT_SCALE);
var targetComponentSums = IntVector.fromMemorySegment(
INT_SPECIES_256,
memorySegment,
- offset + 8L * bulkSize + i * Integer.BYTES,
+ 8L * bulkSize + i * Integer.BYTES,
ByteOrder.LITTLE_ENDIAN
).convert(VectorOperators.I2F, 0);
var additionalCorrections = FloatVector.fromMemorySegment(
FLOAT_SPECIES_256,
memorySegment,
- offset + 12L * bulkSize + i * Float.BYTES,
+ 12L * bulkSize + i * Float.BYTES,
ByteOrder.LITTLE_ENDIAN
);
var qcDist = FloatVector.fromArray(FLOAT_SPECIES_256, scores, i);
@@ -550,6 +609,7 @@ private float applyCorrections256Bulk(
if (limit < bulkSize) {
// missing vectors to score
maxScore = applyCorrectionsIndividually(
+ memorySegment,
queryAdditionalCorrection,
similarityFunction,
centroidDp,
@@ -557,14 +617,13 @@ private float applyCorrections256Bulk(
scores,
bulkSize,
limit,
- offset,
+ 0,
ay,
ly,
y1,
maxScore
);
}
- in.seek(offset + 16L * bulkSize);
return maxScore;
}
}
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSInt4SymmetricESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSInt4SymmetricESNextOSQVectorsScorer.java
index 948596a9c9a16..6b0fa715263f1 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSInt4SymmetricESNextOSQVectorsScorer.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSInt4SymmetricESNextOSQVectorsScorer.java
@@ -18,6 +18,7 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.simdvec.internal.IndexInputUtils;
import java.io.IOException;
import java.lang.foreign.Arena;
@@ -33,8 +34,8 @@
/** Panamized scorer for quantized vectors stored as a {@link MemorySegment}. */
final class MSInt4SymmetricESNextOSQVectorsScorer extends MemorySegmentESNextOSQVectorsScorer.MemorySegmentScorer {
- MSInt4SymmetricESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, int bulkSize, MemorySegment memorySegment) {
- super(in, dimensions, dataLength, bulkSize, memorySegment);
+ MSInt4SymmetricESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, int bulkSize) {
+ super(in, dimensions, dataLength, bulkSize);
}
@Override
@@ -56,9 +57,10 @@ public long quantizeScore(byte[] q) throws IOException {
}
private long nativeQuantizeScore(byte[] q) throws IOException {
- long offset = in.getFilePointer();
- var datasetMemorySegment = memorySegment.asSlice(offset, length);
+ return IndexInputUtils.withSlice(in, length, this::getScratch, segment -> nativeQuantizeScoreImpl(q, segment, length));
+ }
+ private static long nativeQuantizeScoreImpl(byte[] q, MemorySegment datasetMemorySegment, int length) {
final long qScore;
if (SUPPORTS_HEAP_SEGMENTS) {
var queryMemorySegment = MemorySegment.ofArray(q);
@@ -70,7 +72,6 @@ private long nativeQuantizeScore(byte[] q) throws IOException {
qScore = dotProductD4Q4(datasetMemorySegment, queryMemorySegment, length);
}
}
- in.skipBytes(length);
return qScore;
}
@@ -91,25 +92,28 @@ private long quantizeScoreSymmetric256(byte[] q) throws IOException {
}
private long quantizeScore256(byte[] q) throws IOException {
+ int size = length / 4;
+ return IndexInputUtils.withSlice(in, size, this::getScratch, segment -> quantizeScore256Impl(q, segment, size));
+ }
+
+ private static long quantizeScore256Impl(byte[] q, MemorySegment memorySegment, int size) {
long subRet0 = 0;
long subRet1 = 0;
long subRet2 = 0;
long subRet3 = 0;
int i = 0;
- long offset = in.getFilePointer();
- int size = length / 4;
if (size >= ByteVector.SPECIES_256.vectorByteSize() * 2) {
int limit = ByteVector.SPECIES_256.loopBound(size);
var sum0 = LongVector.zero(LONG_SPECIES_256);
var sum1 = LongVector.zero(LONG_SPECIES_256);
var sum2 = LongVector.zero(LONG_SPECIES_256);
var sum3 = LongVector.zero(LONG_SPECIES_256);
- for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) {
+ for (; i < limit; i += ByteVector.SPECIES_256.length()) {
var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs();
var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + size).reinterpretAsLongs();
var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + size * 2).reinterpretAsLongs();
var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + size * 3).reinterpretAsLongs();
- var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, i, ByteOrder.LITTLE_ENDIAN);
sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
@@ -127,12 +131,12 @@ private long quantizeScore256(byte[] q) throws IOException {
var sum2 = LongVector.zero(LONG_SPECIES_128);
var sum3 = LongVector.zero(LONG_SPECIES_128);
int limit = ByteVector.SPECIES_128.loopBound(size);
- for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) {
+ for (; i < limit; i += ByteVector.SPECIES_128.length()) {
var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs();
var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size).reinterpretAsLongs();
var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size * 2).reinterpretAsLongs();
var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size * 3).reinterpretAsLongs();
- var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, i, ByteOrder.LITTLE_ENDIAN);
sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
@@ -144,23 +148,22 @@ private long quantizeScore256(byte[] q) throws IOException {
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
}
// process scalar tail
- in.seek(offset);
for (final int upperBound = size & -Long.BYTES; i < upperBound; i += Long.BYTES) {
- final long value = in.readLong();
+ final long value = memorySegment.get(LAYOUT_LE_LONG, i);
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + size) & value);
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * size) & value);
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * size) & value);
}
for (final int upperBound = size & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
- final int value = in.readInt();
+ final int value = memorySegment.get(LAYOUT_LE_INT, i);
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + size) & value);
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * size) & value);
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * size) & value);
}
for (; i < size; i++) {
- int dValue = in.readByte() & 0xFF;
+ int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, i) & 0xFF;
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
subRet1 += Integer.bitCount((q[i + size] & dValue) & 0xFF);
subRet2 += Integer.bitCount((q[i + 2 * size] & dValue) & 0xFF);
@@ -170,21 +173,24 @@ private long quantizeScore256(byte[] q) throws IOException {
}
private long quantizeScore128(byte[] q) throws IOException {
+ int size = length / 4;
+ return IndexInputUtils.withSlice(in, size, this::getScratch, segment -> quantizeScore128Impl(q, segment, size));
+ }
+
+ private static long quantizeScore128Impl(byte[] q, MemorySegment memorySegment, int size) {
long subRet0 = 0;
long subRet1 = 0;
long subRet2 = 0;
long subRet3 = 0;
int i = 0;
- long offset = in.getFilePointer();
var sum0 = IntVector.zero(INT_SPECIES_128);
var sum1 = IntVector.zero(INT_SPECIES_128);
var sum2 = IntVector.zero(INT_SPECIES_128);
var sum3 = IntVector.zero(INT_SPECIES_128);
- int size = length / 4;
int limit = ByteVector.SPECIES_128.loopBound(size);
- for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) {
- var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ for (; i < limit; i += ByteVector.SPECIES_128.length()) {
+ var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, i, ByteOrder.LITTLE_ENDIAN);
var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts();
var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size).reinterpretAsInts();
var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size * 2).reinterpretAsInts();
@@ -199,23 +205,22 @@ private long quantizeScore128(byte[] q) throws IOException {
subRet2 += sum2.reduceLanes(VectorOperators.ADD);
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
// process scalar tail
- in.seek(offset);
for (final int upperBound = size & -Long.BYTES; i < upperBound; i += Long.BYTES) {
- final long value = in.readLong();
+ final long value = memorySegment.get(LAYOUT_LE_LONG, i);
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + size) & value);
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * size) & value);
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * size) & value);
}
for (final int upperBound = size & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
- final int value = in.readInt();
+ final int value = memorySegment.get(LAYOUT_LE_INT, i);
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + size) & value);
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * size) & value);
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * size) & value);
}
for (; i < size; i++) {
- int dValue = in.readByte() & 0xFF;
+ int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, i) & 0xFF;
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
subRet1 += Integer.bitCount((q[i + size] & dValue) & 0xFF);
subRet2 += Integer.bitCount((q[i + 2 * size] & dValue) & 0xFF);
@@ -259,13 +264,11 @@ public boolean quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOE
}
private void nativeQuantizeScoreBulk(MemorySegment queryMemorySegment, int count, MemorySegment scoresSegment) throws IOException {
- long initialOffset = in.getFilePointer();
var datasetLengthInBytes = (long) length * count;
- MemorySegment datasetSegment = memorySegment.asSlice(initialOffset, datasetLengthInBytes);
-
- dotProductD4Q4Bulk(datasetSegment, queryMemorySegment, length, count, scoresSegment);
-
- in.skipBytes(datasetLengthInBytes);
+ IndexInputUtils.withSlice(in, datasetLengthInBytes, this::getScratch, datasetSegment -> {
+ dotProductD4Q4Bulk(datasetSegment, queryMemorySegment, length, count, scoresSegment);
+ return null;
+ });
}
private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IOException {
@@ -371,24 +374,25 @@ private float nativeApplyCorrectionsBulk(
MemorySegment scoresSegment,
int bulkSize
) throws IOException {
- long offset = in.getFilePointer();
-
- final float maxScore = ScoreCorrections.nativeApplyCorrectionsBulk(
- similarityFunction,
- memorySegment.asSlice(offset),
- bulkSize,
- dimensions,
- queryLowerInterval,
- queryUpperInterval,
- queryComponentSum,
- queryAdditionalCorrection,
- FOUR_BIT_SCALE,
- FOUR_BIT_SCALE,
- centroidDp,
- scoresSegment
+ return IndexInputUtils.withSlice(
+ in,
+ 16L * bulkSize,
+ this::getScratch,
+ memorySegment -> ScoreCorrections.nativeApplyCorrectionsBulk(
+ similarityFunction,
+ memorySegment,
+ bulkSize,
+ dimensions,
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum,
+ queryAdditionalCorrection,
+ FOUR_BIT_SCALE,
+ FOUR_BIT_SCALE,
+ centroidDp,
+ scoresSegment
+ )
);
- in.seek(offset + 16L * bulkSize);
- return maxScore;
}
private float applyCorrections128Bulk(
@@ -401,36 +405,62 @@ private float applyCorrections128Bulk(
float[] scores,
int bulkSize
) throws IOException {
+ return IndexInputUtils.withSlice(
+ in,
+ 16L * bulkSize,
+ this::getScratch,
+ memorySegment -> applyCorrections128BulkImpl(
+ memorySegment,
+ queryAdditionalCorrection,
+ similarityFunction,
+ centroidDp,
+ scores,
+ bulkSize,
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum
+ )
+ );
+ }
+
+ private float applyCorrections128BulkImpl(
+ MemorySegment memorySegment,
+ float queryAdditionalCorrection,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float[] scores,
+ int bulkSize,
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum
+ ) {
int limit = FLOAT_SPECIES_128.loopBound(bulkSize);
int i = 0;
- long offset = in.getFilePointer();
float ay = queryLowerInterval;
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
float y1 = queryComponentSum;
float maxScore = Float.NEGATIVE_INFINITY;
for (; i < limit; i += FLOAT_SPECIES_128.length()) {
- var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
+ var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var lx = FloatVector.fromMemorySegment(
FLOAT_SPECIES_128,
memorySegment,
- offset + 4L * bulkSize + i * Float.BYTES,
+ 4L * bulkSize + i * Float.BYTES,
ByteOrder.LITTLE_ENDIAN
).sub(ax).mul(FOUR_BIT_SCALE);
var targetComponentSums = IntVector.fromMemorySegment(
INT_SPECIES_128,
memorySegment,
- offset + 8L * bulkSize + i * Integer.BYTES,
+ 8L * bulkSize + i * Integer.BYTES,
ByteOrder.LITTLE_ENDIAN
).convert(VectorOperators.I2F, 0);
var additionalCorrections = FloatVector.fromMemorySegment(
FLOAT_SPECIES_128,
memorySegment,
- offset + 12L * bulkSize + i * Float.BYTES,
+ 12L * bulkSize + i * Float.BYTES,
ByteOrder.LITTLE_ENDIAN
);
var qcDist = FloatVector.fromArray(FLOAT_SPECIES_128, scores, i);
- // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly *
- // qcDist;
var res1 = ax.mul(ay).mul(dimensions);
var res2 = lx.mul(ay).mul(targetComponentSums);
var res3 = ax.mul(ly).mul(y1);
@@ -464,6 +494,7 @@ private float applyCorrections128Bulk(
if (limit < bulkSize) {
// missing vectors to score
maxScore = applyCorrectionsIndividually(
+ memorySegment,
queryAdditionalCorrection,
similarityFunction,
centroidDp,
@@ -471,14 +502,13 @@ private float applyCorrections128Bulk(
scores,
bulkSize,
limit,
- offset,
+ 0,
ay,
ly,
y1,
maxScore
);
}
- in.seek(offset + 16L * bulkSize);
return maxScore;
}
@@ -492,31 +522,59 @@ private float applyCorrections256Bulk(
float[] scores,
int bulkSize
) throws IOException {
+ return IndexInputUtils.withSlice(
+ in,
+ 16L * bulkSize,
+ this::getScratch,
+ memorySegment -> applyCorrections256BulkImpl(
+ memorySegment,
+ queryAdditionalCorrection,
+ similarityFunction,
+ centroidDp,
+ scores,
+ bulkSize,
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum
+ )
+ );
+ }
+
+ private float applyCorrections256BulkImpl(
+ MemorySegment memorySegment,
+ float queryAdditionalCorrection,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float[] scores,
+ int bulkSize,
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum
+ ) {
int limit = FLOAT_SPECIES_256.loopBound(bulkSize);
int i = 0;
- long offset = in.getFilePointer();
float ay = queryLowerInterval;
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
float y1 = queryComponentSum;
float maxScore = Float.NEGATIVE_INFINITY;
for (; i < limit; i += FLOAT_SPECIES_256.length()) {
- var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
+ var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var lx = FloatVector.fromMemorySegment(
FLOAT_SPECIES_256,
memorySegment,
- offset + 4L * bulkSize + i * Float.BYTES,
+ 4L * bulkSize + i * Float.BYTES,
ByteOrder.LITTLE_ENDIAN
).sub(ax).mul(FOUR_BIT_SCALE);
var targetComponentSums = IntVector.fromMemorySegment(
INT_SPECIES_256,
memorySegment,
- offset + 8L * bulkSize + i * Integer.BYTES,
+ 8L * bulkSize + i * Integer.BYTES,
ByteOrder.LITTLE_ENDIAN
).convert(VectorOperators.I2F, 0);
var additionalCorrections = FloatVector.fromMemorySegment(
FLOAT_SPECIES_256,
memorySegment,
- offset + 12L * bulkSize + i * Float.BYTES,
+ 12L * bulkSize + i * Float.BYTES,
ByteOrder.LITTLE_ENDIAN
);
var qcDist = FloatVector.fromArray(FLOAT_SPECIES_256, scores, i);
@@ -555,6 +613,7 @@ private float applyCorrections256Bulk(
if (limit < bulkSize) {
// missing vectors to score
maxScore = applyCorrectionsIndividually(
+ memorySegment,
queryAdditionalCorrection,
similarityFunction,
centroidDp,
@@ -562,14 +621,13 @@ private float applyCorrections256Bulk(
scores,
bulkSize,
limit,
- offset,
+ 0,
ay,
ly,
y1,
maxScore
);
}
- in.seek(offset + 16L * bulkSize);
return maxScore;
}
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java
index 6d923ff40f7f9..cc0f4f52a56d7 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java
@@ -21,9 +21,11 @@
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
+import org.elasticsearch.simdvec.internal.IndexInputUtils;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
+import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
@@ -46,11 +48,20 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
private static final VectorSpecies FLOAT_SPECIES_128 = FloatVector.SPECIES_128;
private static final VectorSpecies FLOAT_SPECIES_256 = FloatVector.SPECIES_256;
- private final MemorySegment memorySegment;
+ static final ValueLayout.OfLong LAYOUT_LE_LONG = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
+ static final ValueLayout.OfInt LAYOUT_LE_INT = ValueLayout.JAVA_INT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
- public MemorySegmentES91OSQVectorsScorer(IndexInput in, int dimensions, int bulkSize, MemorySegment memorySegment) {
+ private byte[] scratch;
+
+ public MemorySegmentES91OSQVectorsScorer(IndexInput in, int dimensions, int bulkSize) {
super(in, dimensions, bulkSize);
- this.memorySegment = memorySegment;
+ }
+
+ private byte[] getScratch(int len) {
+ if (scratch == null || scratch.length < len) {
+ scratch = new byte[len];
+ }
+ return scratch;
}
@Override
@@ -68,24 +79,27 @@ public long quantizeScore(byte[] q) throws IOException {
}
private long quantizeScore256(byte[] q) throws IOException {
+ return IndexInputUtils.withSlice(in, length, this::getScratch, segment -> quantizeScore256Impl(q, segment, length));
+ }
+
+ private static long quantizeScore256Impl(byte[] q, MemorySegment memorySegment, int length) {
long subRet0 = 0;
long subRet1 = 0;
long subRet2 = 0;
long subRet3 = 0;
int i = 0;
- long offset = in.getFilePointer();
if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) {
int limit = ByteVector.SPECIES_256.loopBound(length);
var sum0 = LongVector.zero(LONG_SPECIES_256);
var sum1 = LongVector.zero(LONG_SPECIES_256);
var sum2 = LongVector.zero(LONG_SPECIES_256);
var sum3 = LongVector.zero(LONG_SPECIES_256);
- for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) {
+ for (; i < limit; i += ByteVector.SPECIES_256.length()) {
var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs();
var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs();
var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs();
var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs();
- var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, i, ByteOrder.LITTLE_ENDIAN);
sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
@@ -103,12 +117,12 @@ private long quantizeScore256(byte[] q) throws IOException {
var sum2 = LongVector.zero(LONG_SPECIES_128);
var sum3 = LongVector.zero(LONG_SPECIES_128);
int limit = ByteVector.SPECIES_128.loopBound(length);
- for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) {
+ for (; i < limit; i += ByteVector.SPECIES_128.length()) {
var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs();
var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs();
var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs();
var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs();
- var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, i, ByteOrder.LITTLE_ENDIAN);
sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
@@ -120,23 +134,22 @@ private long quantizeScore256(byte[] q) throws IOException {
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
}
// process scalar tail
- in.seek(offset);
for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
- final long value = in.readLong();
+ final long value = memorySegment.get(LAYOUT_LE_LONG, i);
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
}
for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
- final int value = in.readInt();
+ final int value = memorySegment.get(LAYOUT_LE_INT, i);
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
}
for (; i < length; i++) {
- int dValue = in.readByte() & 0xFF;
+ int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, i) & 0xFF;
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
@@ -146,20 +159,23 @@ private long quantizeScore256(byte[] q) throws IOException {
}
private long quantizeScore128(byte[] q) throws IOException {
+ return IndexInputUtils.withSlice(in, length, this::getScratch, segment -> quantizeScore128Impl(q, segment, length));
+ }
+
+ private static long quantizeScore128Impl(byte[] q, MemorySegment memorySegment, int length) {
long subRet0 = 0;
long subRet1 = 0;
long subRet2 = 0;
long subRet3 = 0;
int i = 0;
- long offset = in.getFilePointer();
var sum0 = IntVector.zero(INT_SPECIES_128);
var sum1 = IntVector.zero(INT_SPECIES_128);
var sum2 = IntVector.zero(INT_SPECIES_128);
var sum3 = IntVector.zero(INT_SPECIES_128);
int limit = ByteVector.SPECIES_128.loopBound(length);
- for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) {
- var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ for (; i < limit; i += ByteVector.SPECIES_128.length()) {
+ var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, i, ByteOrder.LITTLE_ENDIAN);
var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts();
var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts();
var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts();
@@ -174,23 +190,22 @@ private long quantizeScore128(byte[] q) throws IOException {
subRet2 += sum2.reduceLanes(VectorOperators.ADD);
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
// process scalar tail
- in.seek(offset);
for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
- final long value = in.readLong();
+ final long value = memorySegment.get(LAYOUT_LE_LONG, i);
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
}
for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
- final int value = in.readInt();
+ final int value = memorySegment.get(LAYOUT_LE_INT, i);
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
}
for (; i < length; i++) {
- int dValue = in.readByte() & 0xFF;
+ int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, i) & 0xFF;
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
@@ -216,13 +231,21 @@ public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOExce
}
private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IOException {
+ var datasetLengthInBytes = (long) length * count;
+ IndexInputUtils.withSlice(in, datasetLengthInBytes, this::getScratch, segment -> {
+ quantizeScore128BulkImpl(q, count, scores, segment, length);
+ return null;
+ });
+ }
+
+ private static void quantizeScore128BulkImpl(byte[] q, int count, float[] scores, MemorySegment memorySegment, int length) {
+ long offset = 0L;
for (int iter = 0; iter < count; iter++) {
long subRet0 = 0;
long subRet1 = 0;
long subRet2 = 0;
long subRet3 = 0;
int i = 0;
- long offset = in.getFilePointer();
var sum0 = IntVector.zero(INT_SPECIES_128);
var sum1 = IntVector.zero(INT_SPECIES_128);
@@ -245,23 +268,22 @@ private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IO
subRet2 += sum2.reduceLanes(VectorOperators.ADD);
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
// process scalar tail
- in.seek(offset);
- for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
- final long value = in.readLong();
+ for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES, offset += Long.BYTES) {
+ final long value = memorySegment.get(LAYOUT_LE_LONG, offset);
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
}
- for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
- final int value = in.readInt();
+ for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES, offset += Integer.BYTES) {
+ final int value = memorySegment.get(LAYOUT_LE_INT, offset);
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
}
- for (; i < length; i++) {
- int dValue = in.readByte() & 0xFF;
+ for (; i < length; i++, offset++) {
+ int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, offset) & 0xFF;
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
@@ -272,13 +294,21 @@ private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IO
}
private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IOException {
+ var datasetLengthInBytes = (long) length * count;
+ IndexInputUtils.withSlice(in, datasetLengthInBytes, this::getScratch, segment -> {
+ quantizeScore256BulkImpl(q, count, scores, segment, length);
+ return null;
+ });
+ }
+
+ private static void quantizeScore256BulkImpl(byte[] q, int count, float[] scores, MemorySegment memorySegment, int length) {
+ long offset = 0L;
for (int iter = 0; iter < count; iter++) {
long subRet0 = 0;
long subRet1 = 0;
long subRet2 = 0;
long subRet3 = 0;
int i = 0;
- long offset = in.getFilePointer();
if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) {
int limit = ByteVector.SPECIES_256.loopBound(length);
var sum0 = LongVector.zero(LONG_SPECIES_256);
@@ -325,23 +355,22 @@ private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IO
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
}
// process scalar tail
- in.seek(offset);
- for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
- final long value = in.readLong();
+ for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES, offset += Long.BYTES) {
+ final long value = memorySegment.get(LAYOUT_LE_LONG, offset);
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
}
- for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
- final int value = in.readInt();
+ for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES, offset += Integer.BYTES) {
+ final int value = memorySegment.get(LAYOUT_LE_INT, offset);
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
}
- for (; i < length; i++) {
- int dValue = in.readByte() & 0xFF;
+ for (; i < length; i++, offset++) {
+ int dValue = memorySegment.get(ValueLayout.JAVA_BYTE, offset) & 0xFF;
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
@@ -366,26 +395,38 @@ public float scoreBulk(
// 128 / 8 == 16
if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
- return score256Bulk(
- q,
- queryLowerInterval,
- queryUpperInterval,
- queryComponentSum,
- queryAdditionalCorrection,
- similarityFunction,
- centroidDp,
- scores
+ return IndexInputUtils.withSlice(
+ in,
+ (14L + length) * this.bulkSize,
+ this::getScratch,
+ segment -> score256Bulk(
+ segment,
+ q,
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum,
+ queryAdditionalCorrection,
+ similarityFunction,
+ centroidDp,
+ scores
+ )
);
} else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
- return score128Bulk(
- q,
- queryLowerInterval,
- queryUpperInterval,
- queryComponentSum,
- queryAdditionalCorrection,
- similarityFunction,
- centroidDp,
- scores
+ return IndexInputUtils.withSlice(
+ in,
+ (14L + length) * this.bulkSize,
+ this::getScratch,
+ segment -> score128Bulk(
+ segment,
+ q,
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum,
+ queryAdditionalCorrection,
+ similarityFunction,
+ centroidDp,
+ scores
+ )
);
}
}
@@ -402,6 +443,7 @@ public float scoreBulk(
}
private float score128Bulk(
+ MemorySegment memorySegment,
byte[] q,
float queryLowerInterval,
float queryUpperInterval,
@@ -410,11 +452,11 @@ private float score128Bulk(
VectorSimilarityFunction similarityFunction,
float centroidDp,
float[] scores
- ) throws IOException {
- quantizeScore128Bulk(q, this.bulkSize, scores);
+ ) {
+ quantizeScore128BulkImpl(q, this.bulkSize, scores, memorySegment, length);
int limit = FLOAT_SPECIES_128.loopBound(this.bulkSize);
int i = 0;
- long offset = in.getFilePointer();
+ long offset = (long) length * this.bulkSize;
float ay = queryLowerInterval;
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
float y1 = queryComponentSum;
@@ -472,11 +514,11 @@ private float score128Bulk(
}
}
}
- in.seek(offset + 14L * this.bulkSize);
return maxScore;
}
private float score256Bulk(
+ MemorySegment memorySegment,
byte[] q,
float queryLowerInterval,
float queryUpperInterval,
@@ -485,11 +527,11 @@ private float score256Bulk(
VectorSimilarityFunction similarityFunction,
float centroidDp,
float[] scores
- ) throws IOException {
- quantizeScore256Bulk(q, this.bulkSize, scores);
+ ) {
+ quantizeScore256BulkImpl(q, this.bulkSize, scores, memorySegment, length);
int limit = FLOAT_SPECIES_256.loopBound(this.bulkSize);
int i = 0;
- long offset = in.getFilePointer();
+ long offset = (long) length * this.bulkSize;
float ay = queryLowerInterval;
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
float y1 = queryComponentSum;
@@ -547,7 +589,6 @@ private float score256Bulk(
}
}
}
- in.seek(offset + 14L * this.bulkSize);
return maxScore;
}
}
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java
index 9fcf70ebb7c92..434c31004e5ff 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java
@@ -16,9 +16,12 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.core.DirectAccessInput;
import org.elasticsearch.nativeaccess.NativeAccess;
import org.elasticsearch.simdvec.ESNextOSQVectorsScorer;
+import org.elasticsearch.simdvec.internal.IndexInputUtils;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
@@ -39,18 +42,17 @@ public MemorySegmentESNextOSQVectorsScorer(
byte indexBits,
int dimensions,
int dataLength,
- int bulkSize,
- MemorySegment memorySegment
+ int bulkSize
) {
super(in, queryBits, indexBits, dimensions, dataLength);
if (queryBits == 4 && indexBits == 1) {
- this.scorer = new MSBitToInt4ESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize, memorySegment);
+ this.scorer = new MSBitToInt4ESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize);
} else if (queryBits == 4 && indexBits == 4) {
- this.scorer = new MSInt4SymmetricESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize, memorySegment);
+ this.scorer = new MSInt4SymmetricESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize);
} else if (queryBits == 4 && indexBits == 2) {
- this.scorer = new MSDibitToInt4ESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize, memorySegment);
+ this.scorer = new MSDibitToInt4ESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize);
} else if (queryBits == 7 && indexBits == 7) {
- this.scorer = new MSD7Q7ESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize, memorySegment);
+ this.scorer = new MSD7Q7ESNextOSQVectorsScorer(in, dimensions, dataLength, bulkSize);
} else {
throw new IllegalArgumentException("Unsupported query/index bits combination: " + queryBits + "/" + indexBits);
}
@@ -171,20 +173,48 @@ abstract static sealed class MemorySegmentScorer permits MSBitToInt4ESNextOSQVec
static final VectorSpecies<Float> FLOAT_SPECIES_128 = FloatVector.SPECIES_128;
static final VectorSpecies<Float> FLOAT_SPECIES_256 = FloatVector.SPECIES_256;
- protected final MemorySegment memorySegment;
+ static final ValueLayout.OfLong LAYOUT_LE_LONG = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
+ static final ValueLayout.OfInt LAYOUT_LE_INT = ValueLayout.JAVA_INT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
+
protected final IndexInput in;
protected final int length;
protected final int dimensions;
protected final int bulkSize;
- MemorySegmentScorer(IndexInput in, int dimensions, int dataLength, int bulkSize, MemorySegment segment) {
+ private byte[] scratch;
+
+ /**
+ * Creates a new MemorySegmentScorer. The index input must be a
+ * {@link MemorySegmentAccessInput} or {@link DirectAccessInput};
+ * otherwise an {@link IllegalArgumentException} is thrown.
+ *
+ * Memory segment access is handled by
+ * {@link org.elasticsearch.simdvec.internal.IndexInputUtils#withSlice
+ * IndexInputUtils.withSlice}, which probes the index input for
+ * {@link MemorySegmentAccessInput} /
+ * {@link DirectAccessInput} support and
+ * falls back to a heap copy when neither is available.
+ *
+ * @param in the index input
+ * @param dimensions the vector dimensions
+ * @param dataLength the length in bytes, per data vector
+ * @param bulkSize the number of vectors per bulk
+ */
+ MemorySegmentScorer(IndexInput in, int dimensions, int dataLength, int bulkSize) {
+ IndexInputUtils.checkInputType(in);
this.in = in;
this.length = dataLength;
this.dimensions = dimensions;
- this.memorySegment = segment;
this.bulkSize = bulkSize;
}
+ protected byte[] getScratch(int len) {
+ if (scratch == null || scratch.length < len) {
+ scratch = new byte[len];
+ }
+ return scratch;
+ }
+
abstract long quantizeScore(byte[] q) throws IOException;
abstract boolean quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException;
@@ -225,6 +255,7 @@ abstract float scoreBulk(
) throws IOException;
protected float applyCorrectionsIndividually(
+ MemorySegment memorySegment,
float queryAdditionalCorrection,
VectorSimilarityFunction similarityFunction,
float centroidDp,
@@ -286,4 +317,5 @@ protected float applyCorrectionsIndividually(
return maxScore;
}
}
+
}
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java
index a87e3626210ce..534b737c96984 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java
@@ -11,7 +11,6 @@
import org.apache.lucene.store.FilterIndexInput;
import org.apache.lucene.store.IndexInput;
-import org.apache.lucene.store.MemorySegmentAccessInput;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
import org.elasticsearch.simdvec.ESNextOSQVectorsScorer;
@@ -19,7 +18,6 @@
import org.elasticsearch.simdvec.internal.MemorySegmentES92Int7VectorsScorer;
import java.io.IOException;
-import java.lang.foreign.MemorySegment;
final class PanamaESVectorizationProvider extends ESVectorizationProvider {
@@ -42,43 +40,30 @@ public ESNextOSQVectorsScorer newESNextOSQVectorsScorer(
int dimension,
int dataLength,
int bulkSize
- ) throws IOException {
- IndexInput unwrappedInput = FilterIndexInput.unwrapOnlyTest(input);
- unwrappedInput = MemorySegmentAccessInputAccess.unwrap(unwrappedInput);
+ ) {
if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS
- && unwrappedInput instanceof MemorySegmentAccessInput msai
&& ((queryBits == 4 && (indexBits == 1 || indexBits == 2 || indexBits == 4)) || (queryBits == 7 && indexBits == 7))) {
- MemorySegment ms = msai.segmentSliceOrNull(0, unwrappedInput.length());
- if (ms != null) {
- return new MemorySegmentESNextOSQVectorsScorer(unwrappedInput, queryBits, indexBits, dimension, dataLength, bulkSize, ms);
- }
+ IndexInput unwrappedInput = FilterIndexInput.unwrapOnlyTest(input);
+ unwrappedInput = MemorySegmentAccessInputAccess.unwrap(unwrappedInput);
+ return new MemorySegmentESNextOSQVectorsScorer(unwrappedInput, queryBits, indexBits, dimension, dataLength, bulkSize);
}
return new ESNextOSQVectorsScorer(input, queryBits, indexBits, dimension, dataLength, bulkSize);
}
@Override
public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension, int bulkSize) throws IOException {
- IndexInput unwrappedInput = FilterIndexInput.unwrapOnlyTest(input);
- unwrappedInput = MemorySegmentAccessInputAccess.unwrap(unwrappedInput);
- if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS && unwrappedInput instanceof MemorySegmentAccessInput msai) {
- MemorySegment ms = msai.segmentSliceOrNull(0, unwrappedInput.length());
- if (ms != null) {
- return new MemorySegmentES91OSQVectorsScorer(unwrappedInput, dimension, bulkSize, ms);
- }
+ if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
+ IndexInput unwrappedInput = FilterIndexInput.unwrapOnlyTest(input);
+ unwrappedInput = MemorySegmentAccessInputAccess.unwrap(unwrappedInput);
+ return new MemorySegmentES91OSQVectorsScorer(unwrappedInput, dimension, bulkSize);
}
return new OnHeapES91OSQVectorsScorer(input, dimension, bulkSize);
}
@Override
- public ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension, int bulkSize) throws IOException {
+ public ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension, int bulkSize) {
IndexInput unwrappedInput = FilterIndexInput.unwrapOnlyTest(input);
unwrappedInput = MemorySegmentAccessInputAccess.unwrap(unwrappedInput);
- if (unwrappedInput instanceof MemorySegmentAccessInput msai) {
- MemorySegment ms = msai.segmentSliceOrNull(0, unwrappedInput.length());
- if (ms != null) {
- return new MemorySegmentES92Int7VectorsScorer(unwrappedInput, dimension, bulkSize, ms);
- }
- }
- return new ES92Int7VectorsScorer(input, dimension, bulkSize);
+ return new MemorySegmentES92Int7VectorsScorer(unwrappedInput, dimension, bulkSize);
}
}
diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java
index aa4f5c355923e..c08afba09c2be 100644
--- a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java
+++ b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java
@@ -20,8 +20,8 @@ public final class MemorySegmentES92Int7VectorsScorer extends MemorySegmentES92P
private static final boolean NATIVE_SUPPORTED = NativeAccess.instance().getVectorSimilarityFunctions().isPresent();
- public MemorySegmentES92Int7VectorsScorer(IndexInput in, int dimensions, int bulkSize, MemorySegment memorySegment) {
- super(in, dimensions, bulkSize, memorySegment);
+ public MemorySegmentES92Int7VectorsScorer(IndexInput in, int dimensions, int bulkSize) {
+ super(in, dimensions, bulkSize);
}
@Override
@@ -37,23 +37,22 @@ public long int7DotProduct(byte[] q) throws IOException {
} else {
return panamaInt7DotProduct(q);
}
-
}
private long nativeInt7DotProduct(byte[] q) throws IOException {
- final MemorySegment segment = memorySegment.asSlice(in.getFilePointer(), dimensions);
- final MemorySegment querySegment = MemorySegment.ofArray(q);
- final long res = Similarities.dotProductI7u(segment, querySegment, dimensions);
- in.skipBytes(dimensions);
- return res;
+ return IndexInputUtils.withSlice(in, dimensions, this::getScratch, segment -> {
+ final MemorySegment querySegment = MemorySegment.ofArray(q);
+ return Similarities.dotProductI7u(segment, querySegment, dimensions);
+ });
}
private void nativeInt7DotProductBulk(byte[] q, int count, float[] scores) throws IOException {
- final MemorySegment scoresSegment = MemorySegment.ofArray(scores);
- final MemorySegment segment = memorySegment.asSlice(in.getFilePointer(), dimensions * count);
- final MemorySegment querySegment = MemorySegment.ofArray(q);
- Similarities.dotProductI7uBulk(segment, querySegment, dimensions, count, scoresSegment);
- in.skipBytes(dimensions * count);
+ IndexInputUtils.withSlice(in, (long) dimensions * count, this::getScratch, segment -> {
+ final MemorySegment scoresSegment = MemorySegment.ofArray(scores);
+ final MemorySegment querySegment = MemorySegment.ofArray(q);
+ Similarities.dotProductI7uBulk(segment, querySegment, dimensions, count, scoresSegment);
+ return null;
+ });
}
@Override
diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ESNextOSQVectorsScorerTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ESNextOSQVectorsScorerTests.java
index eb8bd664d0196..0407f55471529 100644
--- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ESNextOSQVectorsScorerTests.java
+++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ESNextOSQVectorsScorerTests.java
@@ -109,6 +109,7 @@ public void testQuantizeScore() throws Exception {
assertEquals(in.getFilePointer(), slice.getFilePointer());
}
assertEquals((long) length * numVectors, slice.getFilePointer());
+ assertEquals((long) length * numVectors, in.getFilePointer());
}
}
}
diff --git a/libs/simdvec/src/test21/java/org/elasticsearch/simdvec/internal/IndexInputUtilsTests.java b/libs/simdvec/src/test21/java/org/elasticsearch/simdvec/internal/IndexInputUtilsTests.java
new file mode 100644
index 0000000000000..5edfaa16a7af7
--- /dev/null
+++ b/libs/simdvec/src/test21/java/org/elasticsearch/simdvec/internal/IndexInputUtilsTests.java
@@ -0,0 +1,185 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.simdvec.internal;
+
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.FilterIndexInput;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.store.MMapDirectory;
+import org.apache.lucene.store.MemorySegmentAccessInput;
+import org.apache.lucene.store.NIOFSDirectory;
+import org.elasticsearch.core.CheckedConsumer;
+import org.elasticsearch.core.DirectAccessInput;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.lang.foreign.MemorySegment;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+
+/**
+ * Tests that {@link IndexInputUtils#withSlice} correctly handles all
+ * three input types: {@link MemorySegmentAccessInput} (mmap),
+ * {@link DirectAccessInput} (byte-buffer), and plain {@link IndexInput}
+ * (heap-copy fallback).
+ */
+public class IndexInputUtilsTests extends ESTestCase {
+
+ private static final String FILE_NAME = "test.bin";
+
+ // -- withSlice path tests -------------------------------------------------
+
+ public void testWithSliceMemorySegmentAccessInput() throws Exception {
+ byte[] data = randomByteArrayOfLength(256);
+ try (Directory dir = new MMapDirectory(createTempDir())) {
+ writeData(dir, data);
+ try (IndexInput in = dir.openInput(FILE_NAME, IOContext.DEFAULT)) {
+ assertTrue(in instanceof MemorySegmentAccessInput);
+ verifyWithSlice(in, data);
+ }
+ }
+ }
+
+ public void testWithSliceDirectAccessInput() throws Exception {
+ byte[] data = randomByteArrayOfLength(256);
+ try (Directory dir = new NIOFSDirectory(createTempDir())) {
+ writeData(dir, data);
+ try (IndexInput rawIn = dir.openInput(FILE_NAME, IOContext.DEFAULT)) {
+ IndexInput in = new DirectAccessWrapper("dai", rawIn, data);
+ assertTrue(in instanceof DirectAccessInput);
+ verifyWithSlice(in, data);
+ }
+ }
+ }
+
+ public void testWithSlicePlainIndexInput() throws Exception {
+ byte[] data = randomByteArrayOfLength(256);
+ try (Directory dir = new NIOFSDirectory(createTempDir())) {
+ writeData(dir, data);
+ try (IndexInput in = dir.openInput(FILE_NAME, IOContext.DEFAULT)) {
+ assertFalse(in instanceof MemorySegmentAccessInput);
+ assertFalse(in instanceof DirectAccessInput);
+ verifyWithSlice(in, data);
+ }
+ }
+ }
+
+ // -- constructor validation tests -----------------------------------------
+
+ public void testES92ConstructorAcceptsPlainInput() throws Exception {
+ byte[] data = randomByteArrayOfLength(256);
+ try (Directory dir = new NIOFSDirectory(createTempDir())) {
+ writeData(dir, data);
+ try (IndexInput in = dir.openInput(FILE_NAME, IOContext.DEFAULT)) {
+ new MemorySegmentES92Int7VectorsScorer(in, 64, 16);
+ }
+ }
+ }
+
+ public void testES92ConstructorAcceptsMMapInput() throws Exception {
+ byte[] data = randomByteArrayOfLength(256);
+ try (Directory dir = new MMapDirectory(createTempDir())) {
+ writeData(dir, data);
+ try (IndexInput in = dir.openInput(FILE_NAME, IOContext.DEFAULT)) {
+ new MemorySegmentES92Int7VectorsScorer(in, 64, 16);
+ }
+ }
+ }
+
+ public void testES92ConstructorAcceptsDirectAccessInput() throws Exception {
+ byte[] data = randomByteArrayOfLength(256);
+ try (Directory dir = new NIOFSDirectory(createTempDir())) {
+ writeData(dir, data);
+ try (IndexInput rawIn = dir.openInput(FILE_NAME, IOContext.DEFAULT)) {
+ IndexInput in = new DirectAccessWrapper("dai", rawIn, data);
+ new MemorySegmentES92Int7VectorsScorer(in, 64, 16);
+ }
+ }
+ }
+
+ public void testES92ConstructorRejectsUnwrappedFilterIndexInput() throws Exception {
+ byte[] data = randomByteArrayOfLength(256);
+ try (Directory dir = new NIOFSDirectory(createTempDir())) {
+ writeData(dir, data);
+ try (IndexInput rawIn = dir.openInput(FILE_NAME, IOContext.DEFAULT)) {
+ IndexInput wrapped = new FilterIndexInput("plain-wrapper", rawIn) {};
+ expectThrows(IllegalArgumentException.class, () -> new MemorySegmentES92Int7VectorsScorer(wrapped, 64, 16));
+ }
+ }
+ }
+
+ // -- helpers --------------------------------------------------------------
+
+ private void verifyWithSlice(IndexInput in, byte[] expectedData) throws IOException {
+ int firstChunk = 64;
+ byte[] result1 = IndexInputUtils.withSlice(in, firstChunk, byte[]::new, segment -> {
+ byte[] buf = new byte[(int) segment.byteSize()];
+ MemorySegment.ofArray(buf).copyFrom(segment);
+ return buf;
+ });
+ assertArrayEquals(Arrays.copyOfRange(expectedData, 0, firstChunk), result1);
+ assertEquals(firstChunk, in.getFilePointer());
+
+ int secondChunk = 128;
+ byte[] result2 = IndexInputUtils.withSlice(in, secondChunk, byte[]::new, segment -> {
+ byte[] buf = new byte[(int) segment.byteSize()];
+ MemorySegment.ofArray(buf).copyFrom(segment);
+ return buf;
+ });
+ assertArrayEquals(Arrays.copyOfRange(expectedData, firstChunk, firstChunk + secondChunk), result2);
+ assertEquals(firstChunk + secondChunk, in.getFilePointer());
+
+ int remaining = expectedData.length - firstChunk - secondChunk;
+ byte[] result3 = IndexInputUtils.withSlice(in, remaining, byte[]::new, segment -> {
+ byte[] buf = new byte[(int) segment.byteSize()];
+ MemorySegment.ofArray(buf).copyFrom(segment);
+ return buf;
+ });
+ assertArrayEquals(Arrays.copyOfRange(expectedData, firstChunk + secondChunk, expectedData.length), result3);
+ assertEquals(expectedData.length, in.getFilePointer());
+ }
+
+ private static void writeData(Directory dir, byte[] data) throws IOException {
+ try (IndexOutput out = dir.createOutput(FILE_NAME, IOContext.DEFAULT)) {
+ out.writeBytes(data, 0, data.length);
+ }
+ }
+
+ /**
+ * Wraps an existing IndexInput with DirectAccessInput support,
+ * serving byte-buffer slices from the provided data array.
+ */
+ static class DirectAccessWrapper extends FilterIndexInput implements DirectAccessInput {
+ private final byte[] data;
+
+ DirectAccessWrapper(String resourceDescription, IndexInput delegate, byte[] data) {
+ super(resourceDescription, delegate);
+ this.data = data;
+ }
+
+ @Override
+ public boolean withByteBufferSlice(long offset, long length, CheckedConsumer<ByteBuffer, IOException> action) throws IOException {
+ ByteBuffer bb = ByteBuffer.wrap(data, (int) offset, (int) length).asReadOnlyBuffer();
+ action.accept(bb);
+ return true;
+ }
+
+ @Override
+ public IndexInput clone() {
+ return new DirectAccessWrapper("clone", in.clone(), data);
+ }
+
+ @Override
+ public IndexInput slice(String sliceDescription, long offset, long length) throws IOException {
+ return new DirectAccessWrapper(sliceDescription, in.slice(sliceDescription, offset, length), data);
+ }
+ }
+}
diff --git a/qa/vector/build.gradle b/qa/vector/build.gradle
index c078b80ed04f9..95390d3a68de0 100644
--- a/qa/vector/build.gradle
+++ b/qa/vector/build.gradle
@@ -80,6 +80,11 @@ dependencies {
implementation project(':libs:logging')
implementation project(':server')
implementation project(':libs:gpu-codec')
+ implementation(project(':x-pack:plugin:searchable-snapshots')) {
+ capabilities {
+ requireCapability("org.elasticsearch.plugin:searchable-snapshots-test-artifacts")
+ }
+ }
testImplementation project(":libs:x-content")
testImplementation project(":test:framework")
@@ -133,48 +138,6 @@ tasks.named("thirdPartyAudit").configure {
'com.google.appengine.api.urlfetch.URLFetchService',
'com.google.appengine.api.urlfetch.URLFetchServiceFactory',
- // optional apache http client dependencies
- 'org.apache.http.ConnectionReuseStrategy',
- 'org.apache.http.Header',
- 'org.apache.http.HttpEntity',
- 'org.apache.http.HttpEntityEnclosingRequest',
- 'org.apache.http.HttpHost',
- 'org.apache.http.HttpRequest',
- 'org.apache.http.HttpResponse',
- 'org.apache.http.HttpVersion',
- 'org.apache.http.RequestLine',
- 'org.apache.http.StatusLine',
- 'org.apache.http.client.AuthenticationHandler',
- 'org.apache.http.client.HttpClient',
- 'org.apache.http.client.HttpRequestRetryHandler',
- 'org.apache.http.client.RedirectHandler',
- 'org.apache.http.client.RequestDirector',
- 'org.apache.http.client.UserTokenHandler',
- 'org.apache.http.client.methods.HttpEntityEnclosingRequestBase',
- 'org.apache.http.client.methods.HttpRequestBase',
- 'org.apache.http.config.Registry',
- 'org.apache.http.config.RegistryBuilder',
- 'org.apache.http.conn.ClientConnectionManager',
- 'org.apache.http.conn.ConnectionKeepAliveStrategy',
- 'org.apache.http.conn.params.ConnManagerParams',
- 'org.apache.http.conn.params.ConnRouteParams',
- 'org.apache.http.conn.routing.HttpRoutePlanner',
- 'org.apache.http.conn.scheme.PlainSocketFactory',
- 'org.apache.http.conn.scheme.SchemeRegistry',
- 'org.apache.http.conn.socket.PlainConnectionSocketFactory',
- 'org.apache.http.conn.ssl.SSLSocketFactory',
- 'org.apache.http.conn.ssl.X509HostnameVerifier',
- 'org.apache.http.entity.AbstractHttpEntity',
- 'org.apache.http.impl.client.DefaultHttpClient',
- 'org.apache.http.impl.client.HttpClientBuilder',
- 'org.apache.http.impl.conn.PoolingHttpClientConnectionManager',
- 'org.apache.http.params.HttpConnectionParams',
- 'org.apache.http.params.HttpParams',
- 'org.apache.http.params.HttpProtocolParams',
- 'org.apache.http.protocol.HttpContext',
- 'org.apache.http.protocol.HttpProcessor',
- 'org.apache.http.protocol.HttpRequestExecutor',
-
// grpc/proto stuff
'com.google.api.gax.grpc.GrpcCallContext',
'com.google.api.gax.grpc.GrpcCallSettings',
@@ -219,7 +182,7 @@ tasks.named("thirdPartyAudit").configure {
'io.opentelemetry.sdk.metrics.export.PeriodicMetricReaderBuilder',
'io.opentelemetry.sdk.resources.Resource',
)
-
+
if (buildParams.graalVmRuntime == false) {
ignoreMissingClasses(
'org.graalvm.nativeimage.hosted.Feature',
diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java
index 7452eb328fd8c..7346b77abd46d 100644
--- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java
+++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java
@@ -25,7 +25,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.TieredMergePolicy;
-import org.apache.lucene.store.FSDirectory;
+import org.apache.lucene.store.Directory;
import org.apache.lucene.util.NamedThreadFactory;
import org.elasticsearch.cli.ProcessInfo;
import org.elasticsearch.common.Strings;
@@ -63,6 +63,7 @@
import java.util.Locale;
import java.util.Map;
import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -122,6 +123,44 @@ enum MergePolicyType {
LOG_DOC
}
+ /**
+ * Factory that creates a directory for a given index path.
+ */
+ @FunctionalInterface
+ interface DirectoryFactory {
+ Directory create(Path indexPath) throws IOException;
+ }
+
+ record DirectoryTypeConfig(DirectoryFactory factory, boolean shared, boolean preWarm) {}
+
+ private static final Map<String, DirectoryTypeConfig> directoryTypeRegistry = new ConcurrentHashMap<>();
+
+ static {
+ directoryTypeRegistry.put("default", new DirectoryTypeConfig(KnnIndexer::getDirectory, false, false));
+ directoryTypeRegistry.put("frozen", new DirectoryTypeConfig(KnnIndexer::openFrozenDirectory, false, true));
+ }
+
+ /**
+ * Registers a custom directory type that can be referenced via {@code "directory_type"}
+ * in the test configuration JSON.
+ *
+ * @param name the name used in configuration (e.g. "serverless")
+ * @param factory creates a Directory for the given index path
+ * @param shared if true, a single directory instance is used for both write and read phases
+ * @param preWarm if true, the directory is pre-warmed before search
+ */
+ static void registerDirectoryType(String name, DirectoryFactory factory, boolean shared, boolean preWarm) {
+ directoryTypeRegistry.put(name, new DirectoryTypeConfig(factory, shared, preWarm));
+ }
+
+ static DirectoryTypeConfig getDirectoryTypeConfig(String name) {
+ DirectoryTypeConfig config = directoryTypeRegistry.get(name);
+ if (config == null) {
+ throw new IllegalArgumentException("Unknown directory_type: '" + name + "'. Known types: " + directoryTypeRegistry.keySet());
+ }
+ return config;
+ }
+
private static String formatIndexPath(TestConfiguration args) {
List<String> suffix = new ArrayList<>();
switch (args.indexType()) {
@@ -347,49 +386,20 @@ public static void main(String[] args) throws Exception {
Codec codec = createCodec(testConfiguration, exec);
Path indexPath = PathUtils.get(indexPathName);
MergePolicy mergePolicy = getMergePolicy(testConfiguration);
- if (testConfiguration.reindex() || testConfiguration.forceMerge()) {
- KnnIndexer knnIndexer = new KnnIndexer(
- testConfiguration.docVectors(),
- indexPath,
- codec,
- testConfiguration.indexThreads(),
- testConfiguration.vectorEncoding().luceneEncoding,
- testConfiguration.dimensions(),
- testConfiguration.vectorSpace(),
- testConfiguration.numDocs(),
- mergePolicy,
- testConfiguration.writerBufferSizeInMb(),
- testConfiguration.writerMaxBufferedDocs()
- );
- if (testConfiguration.reindex() == false && Files.exists(indexPath) == false) {
- throw new IllegalArgumentException("Index path does not exist: " + indexPath);
- }
- if (testConfiguration.reindex()) {
- knnIndexer.createIndex(indexResults);
- }
- if (testConfiguration.forceMerge()) {
- knnIndexer.forceMerge(indexResults, testConfiguration.forceMergeMaxNumSegments());
- }
- }
- numSegments(indexPath, indexResults);
- if (testConfiguration.queryVectors() != null && testConfiguration.numQueries() > 0) {
- if (parsedArgs.warmUpIterations() > 0) {
- logger.info("Running the searches for " + parsedArgs.warmUpIterations() + " warm up iterations");
- }
- // Warm up
- for (int warmUpCount = 0; warmUpCount < parsedArgs.warmUpIterations(); warmUpCount++) {
- for (int i = 0; i < results.length; i++) {
- var ignoreResults = new Results(indexPathName, indexType, testConfiguration.numDocs());
- KnnSearcher knnSearcher = new KnnSearcher(indexPath, testConfiguration);
- knnSearcher.runSearch(ignoreResults, testConfiguration.searchParams().get(i));
- }
- }
-
- for (int i = 0; i < results.length; i++) {
- KnnSearcher knnSearcher = new KnnSearcher(indexPath, testConfiguration);
- knnSearcher.runSearch(results[i], testConfiguration.searchParams().get(i));
- }
- }
+ DirectoryTypeConfig dirConfig = getDirectoryTypeConfig(testConfiguration.directoryType());
+
+ runTestConfiguration(
+ testConfiguration,
+ indexPath,
+ codec,
+ mergePolicy,
+ dirConfig,
+ indexResults,
+ results,
+ parsedArgs,
+ indexPathName,
+ indexType
+ );
formattedResults.queryResults.addAll(List.of(results));
formattedResults.indexResults.add(indexResults);
} finally {
@@ -401,6 +411,120 @@ public static void main(String[] args) throws Exception {
logger.info("Results: \n" + formattedResults);
}
+ /**
+ * Runs indexing, merge, and search phases using the given directory configuration.
+ * When {@code dirConfig.shared()} is true, a single directory instance is used for all
+ * phases. Otherwise, separate directories are used for write and read.
+ */
+ private static void runTestConfiguration(
+ TestConfiguration testConfiguration,
+ Path indexPath,
+ Codec codec,
+ MergePolicy mergePolicy,
+ DirectoryTypeConfig dirConfig,
+ Results indexResults,
+ Results[] results,
+ ParsedArgs parsedArgs,
+ String indexPathName,
+ String indexType
+ ) throws Exception {
+ Directory sharedDir = dirConfig.shared() ? dirConfig.factory().create(indexPath) : null;
+ try {
+ if (testConfiguration.reindex() || testConfiguration.forceMerge()) {
+ KnnIndexer knnIndexer = new KnnIndexer(
+ testConfiguration.docVectors(),
+ indexPath,
+ codec,
+ testConfiguration.indexThreads(),
+ testConfiguration.vectorEncoding().luceneEncoding,
+ testConfiguration.dimensions(),
+ testConfiguration.vectorSpace(),
+ testConfiguration.numDocs(),
+ mergePolicy,
+ testConfiguration.writerBufferSizeInMb(),
+ testConfiguration.writerMaxBufferedDocs()
+ );
+ if (testConfiguration.reindex() == false && Files.exists(indexPath) == false) {
+ throw new IllegalArgumentException("Index path does not exist: " + indexPath);
+ }
+ if (testConfiguration.reindex()) {
+ reindex(knnIndexer, indexResults, sharedDir);
+ }
+ if (testConfiguration.forceMerge()) {
+ forceMerge(knnIndexer, indexResults, sharedDir, testConfiguration);
+ }
+ }
+ numSegments(indexPath, indexResults, sharedDir);
+ if (testConfiguration.queryVectors() != null && testConfiguration.numQueries() > 0) {
+ Directory readDir = sharedDir != null ? sharedDir : dirConfig.factory().create(indexPath);
+ try {
+ if (dirConfig.preWarm()) {
+ KnnSearcher.preWarmDirectory(readDir);
+ }
+ runSearches(testConfiguration, indexPath, readDir, results, parsedArgs, indexPathName, indexType);
+ } finally {
+ if (sharedDir == null) {
+ readDir.close();
+ }
+ }
+ }
+ } finally {
+ if (sharedDir != null) {
+ sharedDir.close();
+ }
+ }
+ }
+
+ static void reindex(KnnIndexer knnIndexer, Results indexResults, Directory sharedDir) throws Exception {
+ if (sharedDir != null) {
+ knnIndexer.createIndex(indexResults, sharedDir);
+ } else {
+ knnIndexer.createIndex(indexResults);
+ }
+ }
+
+ static void forceMerge(KnnIndexer knnIndexer, Results indexResults, Directory sharedDir, TestConfiguration testConfiguration)
+ throws Exception {
+ if (sharedDir != null) {
+ knnIndexer.forceMerge(indexResults, testConfiguration.forceMergeMaxNumSegments(), sharedDir);
+ } else {
+ knnIndexer.forceMerge(indexResults, testConfiguration.forceMergeMaxNumSegments());
+ }
+ }
+
+ static void numSegments(Path indexPath, Results indexResults, Directory sharedDir) throws IOException {
+ if (sharedDir != null) {
+ numSegments(sharedDir, indexResults);
+ } else {
+ numSegments(indexPath, indexResults);
+ }
+ }
+
+ private static void runSearches(
+ TestConfiguration testConfiguration,
+ Path indexPath,
+ Directory dir,
+ Results[] results,
+ ParsedArgs parsedArgs,
+ String indexPathName,
+ String indexType
+ ) throws Exception {
+ if (parsedArgs.warmUpIterations() > 0) {
+ logger.info("Running the searches for " + parsedArgs.warmUpIterations() + " warm up iterations");
+ }
+ for (int warmUpCount = 0; warmUpCount < parsedArgs.warmUpIterations(); warmUpCount++) {
+ for (int i = 0; i < results.length; i++) {
+ var ignoreResults = new Results(indexPathName, indexType, testConfiguration.numDocs());
+ KnnSearcher knnSearcher = new KnnSearcher(indexPath, testConfiguration);
+ knnSearcher.runSearch(ignoreResults, testConfiguration.searchParams().get(i), dir);
+ }
+ }
+ for (int i = 0; i < results.length; i++) {
+ KnnSearcher knnSearcher = new KnnSearcher(indexPath, testConfiguration);
+ knnSearcher.runSearch(results[i], testConfiguration.searchParams().get(i), dir);
+ }
+ }
+
private static void checkQuantizeBits(TestConfiguration args) {
switch (args.indexType()) {
case IVF:
@@ -431,13 +555,21 @@ private static MergePolicy getMergePolicy(TestConfiguration args) {
}
static void numSegments(Path indexPath, Results result) throws IOException {
- try (FSDirectory dir = FSDirectory.open(indexPath); IndexReader reader = DirectoryReader.open(dir)) {
+ try (Directory dir = KnnIndexer.getDirectory(indexPath); IndexReader reader = DirectoryReader.open(dir)) {
result.numSegments = reader.leaves().size();
} catch (IOException e) {
throw new IOException("Failed to get segment count for index at " + indexPath, e);
}
}
+ static void numSegments(Directory dir, Results result) throws IOException {
+ try (IndexReader reader = DirectoryReader.open(dir)) {
+ result.numSegments = reader.leaves().size();
+ } catch (IOException e) {
+ throw new IOException("Failed to get segment count for dir: " + dir, e);
+ }
+ }
+
static class FormattedResults {
List indexResults = new ArrayList<>();
List queryResults = new ArrayList<>();
diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java
index 491d3c8e553bd..98d629889feb8 100644
--- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java
+++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java
@@ -108,6 +108,12 @@ class KnnIndexer {
}
void createIndex(KnnIndexTester.Results result) throws IOException, InterruptedException, ExecutionException {
+ try (Directory dir = getDirectory(indexPath)) {
+ createIndex(result, dir);
+ }
+ }
+
+ void createIndex(KnnIndexTester.Results result, Directory dir) throws IOException, InterruptedException, ExecutionException {
IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.CREATE);
iwc.setCodec(codec);
iwc.setMaxBufferedDocs(writerMaxBufferedDocs);
@@ -140,7 +146,7 @@ public boolean isEnabled(String component) {
long start = System.nanoTime();
AtomicInteger numDocsIndexed = new AtomicInteger();
- try (Directory dir = getDirectory(indexPath); IndexWriter iw = new IndexWriter(dir, iwc)) {
+ try (IndexWriter iw = new IndexWriter(dir, iwc)) {
for (Path docsPath : this.docsPath) {
int dim = this.dim;
try (FileChannel in = FileChannel.open(docsPath)) {
@@ -223,6 +229,12 @@ public boolean isEnabled(String component) {
}
void forceMerge(KnnIndexTester.Results results, int maxNumSegments) throws Exception {
+ try (Directory dir = getDirectory(indexPath)) {
+ forceMerge(results, maxNumSegments, dir);
+ }
+ }
+
+ void forceMerge(KnnIndexTester.Results results, int maxNumSegments, Directory dir) throws Exception {
IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.APPEND);
iwc.setInfoStream(new PrintStreamInfoStream(System.out) {
@Override
@@ -234,7 +246,7 @@ public boolean isEnabled(String component) {
iwc.setUseCompoundFile(false);
logger.info("KnnIndexer: forceMerge in {} into {} segments", indexPath, maxNumSegments);
long startNS = System.nanoTime();
- try (IndexWriter iw = new IndexWriter(getDirectory(indexPath), iwc)) {
+ try (IndexWriter iw = new IndexWriter(dir, iwc)) {
iw.forceMerge(maxNumSegments);
}
long endNS = System.nanoTime();
@@ -252,6 +264,35 @@ static Directory getDirectory(Path indexPath) throws IOException {
return dir;
}
+ /**
+ * Opens a frozen (searchable snapshot) directory for the given index path.
+ */
+ static Directory openFrozenDirectory(Path indexPath) throws IOException {
+ Path workPath = indexPath.resolveSibling(indexPath.getFileName() + ".snap_work");
+ Files.createDirectories(workPath);
+ logger.info("Opening frozen snapshot directory for index at {} with work path {}", indexPath, workPath);
+ return openSearchableSnapshotDirectory(indexPath, workPath);
+ }
+
+ /**
+ * Creates a directory backed by searchable snapshot infrastructure, wrapping an existing
+ * Lucene index on disk. Loaded via reflection because the factory resides in the
+ * searchable-snapshots test artifact (unnamed module) which cannot be directly referenced
+ * from this named module ({@code org.elasticsearch.test.knn}).
+ */
+ private static Directory openSearchableSnapshotDirectory(Path indexPath, Path workPath) throws IOException {
+ try {
+ Class> factoryClass = Class.forName("org.elasticsearch.xpack.searchablesnapshots.store.SearchableSnapshotDirectoryFactory");
+ var method = factoryClass.getMethod("newDirectoryFromIndex", Path.class, Path.class);
+ return (Directory) method.invoke(null, indexPath, workPath);
+ } catch (Exception e) {
+ throw new IOException(
+ "Failed to create searchable snapshot directory. Ensure the searchable-snapshots test artifact is on the classpath.",
+ e
+ );
+ }
+ }
+
private static BiFunction> getReadAdviceFunc() {
return (name, context) -> {
if (context.hints().contains(StandardIOBehaviorHint.INSTANCE) || name.endsWith(".cfs")) {
diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java
index 26a0547015040..b89c80ab66d8b 100644
--- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java
+++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java
@@ -51,10 +51,13 @@
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.FixedBitSet;
import org.elasticsearch.common.io.Channels;
+import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.PathUtils;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.search.profile.query.QueryProfiler;
@@ -123,7 +126,7 @@ class KnnSearcher {
this.doPrecondition = testConfiguration.doPrecondition();
}
- void runSearch(KnnIndexTester.Results finalResults, SearchParameters searchParameters) throws IOException {
+ void runSearch(KnnIndexTester.Results finalResults, SearchParameters searchParameters, Directory dir) throws IOException {
Query filterQuery = searchParameters.filterSelectivity() < 1f
? generateRandomQuery(
new Random(searchParameters.seed()),
@@ -173,124 +176,122 @@ void runSearch(KnnIndexTester.Results finalResults, SearchParameters searchParam
);
KnnIndexer.VectorReader targetReader = KnnIndexer.VectorReader.create(input, dim, vectorEncoding, offsetByteSize);
long startNS;
- try (Directory dir = KnnIndexer.getDirectory(indexPath)) {
- try (DirectoryReader reader = DirectoryReader.open(dir)) {
- IndexSearcher searcher = searchParameters.searchThreads() > 1
- ? new IndexSearcher(reader, executorService)
- : new IndexSearcher(reader);
- byte[] targetBytes = new byte[dim];
- float[] target = new float[dim];
- // warm up
- for (int i = 0; i < numQueryVectors; i++) {
- if (vectorEncoding.equals(VectorEncoding.BYTE)) {
- targetReader.next(targetBytes);
- doVectorQuery(targetBytes, searcher, filterQuery, searchParameters);
- } else {
- targetReader.next(target);
- doVectorQuery(target, searcher, filterQuery, searchParameters);
- }
- }
- targetReader.reset();
- final IntConsumer[] queryConsumers = new IntConsumer[searchParameters.numSearchers()];
+ try (DirectoryReader reader = DirectoryReader.open(dir)) {
+ IndexSearcher searcher = searchParameters.searchThreads() > 1
+ ? new IndexSearcher(reader, executorService)
+ : new IndexSearcher(reader);
+ byte[] targetBytes = new byte[dim];
+ float[] target = new float[dim];
+ // warm up
+ for (int i = 0; i < numQueryVectors; i++) {
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
- byte[][] queries = new byte[numQueryVectors][dim];
- for (int i = 0; i < numQueryVectors; i++) {
- targetReader.next(queries[i]);
- }
- for (int s = 0; s < searchParameters.numSearchers(); s++) {
- queryConsumers[s] = i -> {
- try {
- results[i] = doVectorQuery(queries[i], searcher, filterQuery, searchParameters);
- } catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- };
- }
+ targetReader.next(targetBytes);
+ doVectorQuery(targetBytes, searcher, filterQuery, searchParameters);
} else {
- float[][] queries = new float[numQueryVectors][dim];
- for (int i = 0; i < numQueryVectors; i++) {
- targetReader.next(queries[i]);
- }
- for (int s = 0; s < searchParameters.numSearchers(); s++) {
- queryConsumers[s] = i -> {
- try {
- results[i] = doVectorQuery(queries[i], searcher, filterQuery, searchParameters);
- } catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- };
- }
+ targetReader.next(target);
+ doVectorQuery(target, searcher, filterQuery, searchParameters);
+ }
+ }
+ targetReader.reset();
+ final IntConsumer[] queryConsumers = new IntConsumer[searchParameters.numSearchers()];
+ if (vectorEncoding.equals(VectorEncoding.BYTE)) {
+ byte[][] queries = new byte[numQueryVectors][dim];
+ for (int i = 0; i < numQueryVectors; i++) {
+ targetReader.next(queries[i]);
}
- int[][] querySplits = new int[searchParameters.numSearchers()][];
- int queriesPerSearcher = numQueryVectors / searchParameters.numSearchers();
for (int s = 0; s < searchParameters.numSearchers(); s++) {
- int start = s * queriesPerSearcher;
- int end = (s == searchParameters.numSearchers() - 1) ? numQueryVectors : (s + 1) * queriesPerSearcher;
- querySplits[s] = new int[end - start];
- for (int i = start; i < end; i++) {
- querySplits[s][i - start] = i;
- }
+ queryConsumers[s] = i -> {
+ try {
+ results[i] = doVectorQuery(queries[i], searcher, filterQuery, searchParameters);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ };
}
- targetReader.reset();
- startNS = System.nanoTime();
- KnnIndexTester.ThreadDetails startThreadDetails = new KnnIndexTester.ThreadDetails();
- if (numSearchersExecutor != null) {
- // use multiple searchers
- var futures = new ArrayList>();
- for (int s = 0; s < searchParameters.numSearchers(); s++) {
- int[] split = querySplits[s];
- IntConsumer queryConsumer = queryConsumers[s];
- futures.add(numSearchersExecutor.submit(() -> {
- for (int j : split) {
- queryConsumer.accept(j);
- }
- return null;
- }));
- }
- for (Future future : futures) {
+ } else {
+ float[][] queries = new float[numQueryVectors][dim];
+ for (int i = 0; i < numQueryVectors; i++) {
+ targetReader.next(queries[i]);
+ }
+ for (int s = 0; s < searchParameters.numSearchers(); s++) {
+ queryConsumers[s] = i -> {
try {
- future.get();
- } catch (Exception e) {
- throw new RuntimeException("Error executing searcher thread", e);
+ results[i] = doVectorQuery(queries[i], searcher, filterQuery, searchParameters);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
}
- }
- } else {
- // use a single searcher
- for (int i = 0; i < numQueryVectors; i++) {
- queryConsumers[0].accept(i);
- }
+ };
}
- KnnIndexTester.ThreadDetails endThreadDetails = new KnnIndexTester.ThreadDetails();
- elapsed = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS);
- long startCPUTimeNS = 0;
- long endCPUTimeNS = 0;
- for (int i = 0; i < startThreadDetails.threadInfos.length; i++) {
- if (startThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher")) {
- startCPUTimeNS += startThreadDetails.cpuTimesNS[i];
- }
+ }
+ int[][] querySplits = new int[searchParameters.numSearchers()][];
+ int queriesPerSearcher = numQueryVectors / searchParameters.numSearchers();
+ for (int s = 0; s < searchParameters.numSearchers(); s++) {
+ int start = s * queriesPerSearcher;
+ int end = (s == searchParameters.numSearchers() - 1) ? numQueryVectors : (s + 1) * queriesPerSearcher;
+ querySplits[s] = new int[end - start];
+ for (int i = start; i < end; i++) {
+ querySplits[s][i - start] = i;
}
-
- for (int i = 0; i < endThreadDetails.threadInfos.length; i++) {
- if (endThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher")) {
- endCPUTimeNS += endThreadDetails.cpuTimesNS[i];
+ }
+ targetReader.reset();
+ startNS = System.nanoTime();
+ KnnIndexTester.ThreadDetails startThreadDetails = new KnnIndexTester.ThreadDetails();
+ if (numSearchersExecutor != null) {
+ // use multiple searchers
+ var futures = new ArrayList>();
+ for (int s = 0; s < searchParameters.numSearchers(); s++) {
+ int[] split = querySplits[s];
+ IntConsumer queryConsumer = queryConsumers[s];
+ futures.add(numSearchersExecutor.submit(() -> {
+ for (int j : split) {
+ queryConsumer.accept(j);
+ }
+ return null;
+ }));
+ }
+ for (Future future : futures) {
+ try {
+ future.get();
+ } catch (Exception e) {
+ throw new RuntimeException("Error executing searcher thread", e);
}
}
- totalCpuTimeMS = TimeUnit.NANOSECONDS.toMillis(endCPUTimeNS - startCPUTimeNS);
-
- // Fetch, validate and write result document ids.
- StoredFields storedFields = reader.storedFields();
+ } else {
+ // use a single searcher
for (int i = 0; i < numQueryVectors; i++) {
- totalVisited += results[i].totalHits.value();
- resultIds[i] = getResultIds(results[i], storedFields);
+ queryConsumers[0].accept(i);
+ }
+ }
+ KnnIndexTester.ThreadDetails endThreadDetails = new KnnIndexTester.ThreadDetails();
+ elapsed = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS);
+ long startCPUTimeNS = 0;
+ long endCPUTimeNS = 0;
+ for (int i = 0; i < startThreadDetails.threadInfos.length; i++) {
+ if (startThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher")) {
+ startCPUTimeNS += startThreadDetails.cpuTimesNS[i];
+ }
+ }
+
+ for (int i = 0; i < endThreadDetails.threadInfos.length; i++) {
+ if (endThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher")) {
+ endCPUTimeNS += endThreadDetails.cpuTimesNS[i];
}
- logger.info(
- "completed {} searches in {} ms: {} QPS CPU time={}ms",
- numQueryVectors,
- elapsed,
- (1000L * numQueryVectors) / elapsed,
- totalCpuTimeMS
- );
}
+ totalCpuTimeMS = TimeUnit.NANOSECONDS.toMillis(endCPUTimeNS - startCPUTimeNS);
+
+ // Fetch, validate and write result document ids.
+ StoredFields storedFields = reader.storedFields();
+ for (int i = 0; i < numQueryVectors; i++) {
+ totalVisited += results[i].totalHits.value();
+ resultIds[i] = getResultIds(results[i], storedFields);
+ }
+ logger.info(
+ "completed {} searches in {} ms: {} QPS CPU time={}ms",
+ numQueryVectors,
+ elapsed,
+ (1000L * numQueryVectors) / elapsed,
+ totalCpuTimeMS
+ );
}
}
logger.info("checking results");
@@ -309,6 +310,35 @@ void runSearch(KnnIndexTester.Results finalResults, SearchParameters searchParam
finalResults.earlyTermination = searchParameters.earlyTermination();
}
+ /**
+ * Pre-warms the searchable snapshot cache by sequentially reading every
+ * file in the directory through a single {@link IndexInput} per file.
+ */
+ static void preWarmDirectory(Directory dir) throws IOException {
+ long startNS = System.nanoTime();
+ long totalBytes = 0;
+ byte[] buf = new byte[64 * 1024];
+ for (String file : dir.listAll()) {
+ long fileLength = dir.fileLength(file);
+ try (IndexInput in = dir.openInput(file, IOContext.READONCE)) {
+ long remaining = fileLength;
+ while (remaining > 0) {
+ int toRead = (int) Math.min(buf.length, remaining);
+ in.readBytes(buf, 0, toRead);
+ remaining -= toRead;
+ }
+ }
+ totalBytes += fileLength;
+ }
+ long elapsedMS = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS);
+ logger.info(
+ "Pre-warmed searchable snapshot cache: {} across {} files in {} ms",
+ ByteSizeValue.ofBytes(totalBytes),
+ dir.listAll().length,
+ elapsedMS
+ );
+ }
+
private static Query generateRandomQuery(Random random, Path indexPath, int size, float selectivity, boolean filterCached)
throws IOException {
FixedBitSet bitSet = new FixedBitSet(size);
diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/TestConfiguration.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/TestConfiguration.java
index 151eb913715b9..7f41bb42b7316 100644
--- a/qa/vector/src/main/java/org/elasticsearch/test/knn/TestConfiguration.java
+++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/TestConfiguration.java
@@ -69,7 +69,8 @@ record TestConfiguration(
boolean doPrecondition,
int preconditioningBlockDims,
int flatVectorThreshold,
- int secondaryClusterSize
+ int secondaryClusterSize,
+ String directoryType
) {
static final ParseField DATASET_FIELD = new ParseField("dataset");
@@ -110,6 +111,7 @@ record TestConfiguration(
static final ParseField FILTER_CACHED = new ParseField("filter_cache");
static final ParseField SEARCH_PARAMS = new ParseField("search_params");
static final ParseField FLAT_VECTOR_THRESHOLD = new ParseField("flat_vector_threshold");
+ static final ParseField DIRECTORY_TYPE_FIELD = new ParseField("directory_type");
/** By default, in ES the default writer buffer size is 10% of the heap space
* (see {@code IndexingMemoryController.INDEX_BUFFER_SIZE_SETTING}).
@@ -174,6 +176,7 @@ static TestConfiguration fromXContent(XContentParser parser) throws Exception {
PARSER.declareInt(Builder::setMergeWorkers, MERGE_WORKERS_FIELD);
PARSER.declareInt(Builder::setFlatVectorThreshold, FLAT_VECTOR_THRESHOLD);
PARSER.declareInt(Builder::setSecondaryClusterSize, SECONDARY_CLUSTER_SIZE);
+ PARSER.declareString(Builder::setDirectoryType, DIRECTORY_TYPE_FIELD);
}
public int numberOfSearchRuns() {
@@ -230,6 +233,11 @@ public static String formattedParameterHelp() {
"search_params",
"array[object]",
"Explicit per-search settings; each object may include search fields like num_candidates, k, and visit_percentage."
+ ),
+ new ParameterHelp(
+ "directory_type",
+ "string",
+ "Directory type: default (mmap), frozen (searchable snapshot), or custom types registered by external wrappers."
)
);
@@ -303,6 +311,7 @@ static class Builder implements ToXContentObject {
private int numMergeWorkers = 1;
private int flatVectorThreshold = -1; // -1 mean use default (vectorPerCluster * 3)
private int secondaryClusterSize = -1;
+ private String directoryType = "default";
/**
* Elasticsearch does not set this explicitly, and in Lucene this setting is
@@ -504,6 +513,11 @@ public Builder setSecondaryClusterSize(int secondaryClusterSize) {
return this;
}
+ public Builder setDirectoryType(String directoryType) {
+ this.directoryType = directoryType.toLowerCase(Locale.ROOT);
+ return this;
+ }
+
/*
* Each dataset has a descriptor file, expected to be at gs:////.json, with contents of:
{
@@ -721,7 +735,8 @@ public TestConfiguration build() throws Exception {
doPrecondition,
preconditioningBlockDims,
flatVectorThreshold,
- secondaryClusterSize
+ secondaryClusterSize,
+ directoryType
);
}
@@ -779,6 +794,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(SEARCH_PARAMS.getPreferredName(), searchParams);
}
builder.field(FLAT_VECTOR_THRESHOLD.getPreferredName(), flatVectorThreshold);
+ builder.field(DIRECTORY_TYPE_FIELD.getPreferredName(), directoryType);
return builder.endObject();
}
diff --git a/server/src/main/java/org/elasticsearch/index/store/StoreMetricsIndexInput.java b/server/src/main/java/org/elasticsearch/index/store/StoreMetricsIndexInput.java
index ede0a44fd7a53..a91a7b335fd97 100644
--- a/server/src/main/java/org/elasticsearch/index/store/StoreMetricsIndexInput.java
+++ b/server/src/main/java/org/elasticsearch/index/store/StoreMetricsIndexInput.java
@@ -14,14 +14,17 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.store.RandomAccessInput;
+import org.elasticsearch.core.CheckedConsumer;
+import org.elasticsearch.core.DirectAccessInput;
import org.elasticsearch.simdvec.MemorySegmentAccessInputAccess;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
-public class StoreMetricsIndexInput extends FilterIndexInput {
+public class StoreMetricsIndexInput extends FilterIndexInput implements DirectAccessInput {
final PluggableDirectoryMetricsHolder metricHolder;
public static IndexInput create(String resourceDescription, IndexInput in, PluggableDirectoryMetricsHolder metricHolder) {
@@ -88,6 +91,14 @@ public void prefetch(long offset, long length) throws IOException {
in.prefetch(offset, length);
}
+ @Override
+ public boolean withByteBufferSlice(long offset, long length, CheckedConsumer action) throws IOException {
+ if (in instanceof DirectAccessInput dai) {
+ return dai.withByteBufferSlice(offset, length, action);
+ }
+ return false;
+ }
+
@Override
public Optional isLoaded() {
return in.isLoaded();
diff --git a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java
index fa2fe1bc5628c..b3594022c4e3c 100644
--- a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java
+++ b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java
@@ -33,6 +33,7 @@
import org.elasticsearch.common.util.concurrent.ThrottledTaskRunner;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.Assertions;
+import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
@@ -43,6 +44,7 @@
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.index.store.LuceneFilesExtensions;
import org.elasticsearch.monitor.fs.FsProbe;
+import org.elasticsearch.nativeaccess.CloseableByteBuffer;
import org.elasticsearch.node.NodeRoleSettings;
import org.elasticsearch.threadpool.ThreadPool;
@@ -1064,6 +1066,35 @@ boolean tryRead(ByteBuffer buf, long offset) throws IOException {
}
}
+ /**
+ * Optimistically tries to get a direct ByteBuffer slice from the region.
+ * The returned {@link CloseableByteBuffer} holds a reference to this region,
+ * preventing eviction while the buffer is in use. The caller must close it
+ * when done.
+ * @return a CloseableByteBuffer wrapping a read-only ByteBuffer slice, or null if not available
+ */
+ CloseableByteBuffer tryGetByteBufferSlice(long offset, int length) {
+ SharedBytes.IO ioRef = nonVolatileIO();
+ if (ioRef != null && tryIncRef()) {
+ ByteBuffer slice = ioRef.byteBufferSlice(blobCacheService.getRegionRelativePosition(offset), length);
+ if (slice != null && isEvicted() == false) {
+ return new CloseableByteBuffer() {
+ @Override
+ public ByteBuffer buffer() {
+ return slice;
+ }
+
+ @Override
+ public void close() {
+ CacheFileRegion.this.decRef();
+ }
+ };
+ }
+ decRef();
+ }
+ return null;
+ }
+
/**
* Populates a range in cache if the range is not available nor pending to be available in cache.
*
@@ -1362,6 +1393,49 @@ public boolean tryRead(ByteBuffer buf, long offset) throws IOException {
return res;
}
+ CloseableByteBuffer tryGetByteBufferSlice(long offset, int length) {
+ assert assertOffsetsWithinFileLength(offset, length, this.length);
+ final int startRegion = getRegion(offset);
+ final long end = offset + length;
+ final int endRegion = getEndingRegion(end);
+ if (startRegion != endRegion) {
+ return null;
+ }
+ var fileRegion = lastAccessedRegion;
+ if (fileRegion != null && fileRegion.chunk.regionKey.region == startRegion) {
+ fileRegion.touch();
+ } else {
+ fileRegion = cache.get(cacheKey, this.length, startRegion);
+ }
+ final var region = fileRegion.chunk;
+ if (region.tracker.checkAvailable(end - getRegionStart(startRegion)) == false) {
+ return null;
+ }
+ CloseableByteBuffer slice = region.tryGetByteBufferSlice(offset, length);
+ if (slice != null) {
+ lastAccessedRegion = fileRegion;
+ }
+ return slice;
+ }
+
+ /**
+ * If a direct byte buffer view is available for the given range, passes it
+ * to {@code action} and returns {@code true}. Otherwise returns
+ * {@code false} without invoking the action.
+ */
+ public boolean withByteBufferSlice(long offset, int length, CheckedConsumer action) throws IOException {
+ CloseableByteBuffer cbb = tryGetByteBufferSlice(offset, length);
+ if (cbb == null) {
+ return false;
+ }
+ try {
+ action.accept(cbb.buffer());
+ return true;
+ } finally {
+ cbb.close();
+ }
+ }
+
public int populateAndRead(
final ByteRange rangeToWrite,
final ByteRange rangeToRead,
diff --git a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBytes.java b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBytes.java
index edee430e7f0b0..33a1cb96eaeb7 100644
--- a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBytes.java
+++ b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBytes.java
@@ -375,6 +375,22 @@ public int read(ByteBuffer dst, int position) throws IOException {
return bytesRead;
}
+ /**
+ * Returns a read-only ByteBuffer slice of the memory-mapped region,
+ * or {@code null} if not memory-mapped.
+ *
+ * @param position the starting position within the region; must be non-negative
+ * @param length the number of bytes; {@code position + length} must not exceed the region size
+ * @throws IllegalArgumentException if the position/length are out of bounds
+ */
+ public ByteBuffer byteBufferSlice(int position, int length) {
+ if (mmap) {
+ checkOffsets(position, length);
+ return mappedByteBuffer.buffer().slice(position, length).asReadOnlyBuffer();
+ }
+ return null;
+ }
+
@SuppressForbidden(reason = "Use positional writes on purpose")
public int write(ByteBuffer src, int position) throws IOException {
// check if writes are page size aligned for optimal performance
diff --git a/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java b/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java
index 7c61da5475fd6..eb5f7959e11ae 100644
--- a/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java
+++ b/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java
@@ -2280,6 +2280,333 @@ public void fillCacheRange(
}
}
+ // Verifies that withByteBufferSlice returns false before data is populated, and provides
+ // a readable byte buffer with correct content after population. Single region of size(10), file size(8).
+ public void testWithByteBufferSlice() throws Exception {
+ final int regionSize = (int) size(10);
+ final long fileLength = size(8); // fits in a single region
+ Settings settings = Settings.builder()
+ .put(NODE_NAME_SETTING.getKey(), "node")
+ .put(SharedBlobCacheService.SHARED_CACHE_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(size(50)).getStringRep())
+ .put(SharedBlobCacheService.SHARED_CACHE_REGION_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(regionSize).getStringRep())
+ .put(SharedBlobCacheService.SHARED_CACHE_MMAP.getKey(), true)
+ .put("path.home", createTempDir())
+ .build();
+ final DeterministicTaskQueue taskQueue = new DeterministicTaskQueue();
+ ExecutorService ioExecutor = Executors.newCachedThreadPool();
+ try (
+ NodeEnvironment environment = new NodeEnvironment(settings, TestEnvironment.newEnvironment(settings));
+ var cacheService = new SharedBlobCacheService(
+ environment,
+ settings,
+ taskQueue.getThreadPool(),
+ ioExecutor,
+ BlobCacheMetrics.NOOP
+ )
+ ) {
+ final var cacheKey = generateCacheKey();
+ SharedBlobCacheService.CacheFile cacheFile = cacheService.getCacheFile(
+ cacheKey,
+ fileLength,
+ SharedBlobCacheService.CacheMissHandler.NOOP
+ );
+
+ // before populating, withByteBufferSlice should return false (data not available)
+ assertFalse(cacheFile.withByteBufferSlice(0, 100, slice -> fail("should not be invoked")));
+
+ // populate the cache with known data
+ byte[] testData = randomByteArrayOfLength((int) fileLength);
+ ByteBuffer writeBuffer = ByteBuffer.allocate(SharedBytes.PAGE_SIZE);
+ final int bytesRead = cacheFile.populateAndRead(
+ ByteRange.of(0L, fileLength),
+ ByteRange.of(0L, fileLength),
+ // range reader: simply reports the requested length as read
+ (channel, pos, relativePos, len) -> len,
+ // range writer: copies testData into the cache file, page-aligned
+ (channel, channelPos, streamFactory, relativePos, len, progressUpdater, completionListener) -> {
+ SharedBytes.copyToCacheFileAligned(
+ channel,
+ new java.io.ByteArrayInputStream(testData, relativePos, len),
+ channelPos,
+ relativePos,
+ len,
+ progressUpdater,
+ writeBuffer.clear()
+ );
+ ActionListener.completeWith(completionListener, () -> null);
+ },
+ "test"
+ );
+ assertThat(bytesRead, equalTo((int) fileLength));
+
+ // now withByteBufferSlice should provide a valid slice
+ int sliceOffset = randomIntBetween(0, (int) fileLength / 2);
+ int sliceLength = randomIntBetween(1, (int) fileLength - sliceOffset);
+ boolean available = cacheFile.withByteBufferSlice(sliceOffset, sliceLength, slice -> {
+ assertTrue(slice.isReadOnly());
+ assertEquals(sliceLength, slice.remaining());
+ byte[] sliceData = new byte[sliceLength];
+ slice.get(sliceData);
+ for (int i = 0; i < sliceLength; i++) {
+ assertEquals(testData[sliceOffset + i], sliceData[i]);
+ }
+ });
+ assertTrue(available);
+ }
+ ioExecutor.shutdown();
+ }
+
+ // Verifies that the byte buffer ref held during the callback prevents the region from being
+ // evicted. 2 regions of size(10), file size(8); eviction pressure is applied inside the callback.
+ public void testWithByteBufferSlicePreventsEviction() throws Exception {
+ final int regionSize = (int) size(10);
+ Settings settings = Settings.builder()
+ .put(NODE_NAME_SETTING.getKey(), "node")
+ .put(SharedBlobCacheService.SHARED_CACHE_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(size(20)).getStringRep())
+ .put(SharedBlobCacheService.SHARED_CACHE_REGION_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(regionSize).getStringRep())
+ .put(SharedBlobCacheService.SHARED_CACHE_MMAP.getKey(), true)
+ .put("path.home", createTempDir())
+ .build();
+ final DeterministicTaskQueue taskQueue = new DeterministicTaskQueue();
+ ExecutorService ioExecutor = Executors.newCachedThreadPool();
+ try (
+ NodeEnvironment environment = new NodeEnvironment(settings, TestEnvironment.newEnvironment(settings));
+ var cacheService = new SharedBlobCacheService(
+ environment,
+ settings,
+ taskQueue.getThreadPool(),
+ ioExecutor,
+ BlobCacheMetrics.NOOP
+ )
+ ) {
+ assertEquals(2, cacheService.freeRegionCount());
+
+ // populate region 0 with known data for cacheKey1
+ final long fileLength = size(8); // fits in one region
+ final var cacheKey1 = generateCacheKey();
+ SharedBlobCacheService.CacheFile cacheFile1 = cacheService.getCacheFile(
+ cacheKey1,
+ fileLength,
+ SharedBlobCacheService.CacheMissHandler.NOOP
+ );
+ byte[] testData = randomByteArrayOfLength((int) fileLength);
+ ByteBuffer writeBuffer = ByteBuffer.allocate(SharedBytes.PAGE_SIZE);
+ cacheFile1.populateAndRead(
+ ByteRange.of(0L, fileLength),
+ ByteRange.of(0L, fileLength),
+ (channel, pos, relativePos, len) -> len,
+ (channel, channelPos, streamFactory, relativePos, len, progressUpdater, completionListener) -> {
+ SharedBytes.copyToCacheFileAligned(
+ channel,
+ new java.io.ByteArrayInputStream(testData, relativePos, len),
+ channelPos,
+ relativePos,
+ len,
+ progressUpdater,
+ writeBuffer.clear()
+ );
+ ActionListener.completeWith(completionListener, () -> null);
+ },
+ "test"
+ );
+
+ // inside the callback, the ref is held — eviction should not reclaim the region
+ boolean available = cacheFile1.withByteBufferSlice(0, (int) fileLength, slice -> {
+ // fill the remaining region with a different key, using up all free regions
+ final var cacheKey2 = generateCacheKey();
+ cacheService.get(cacheKey2, fileLength, 0);
+
+ // now all regions are used; requesting yet another key triggers eviction pressure
+ final var cacheKey3 = generateCacheKey();
+ cacheService.get(cacheKey3, fileLength, 0);
+ // drain any queued eviction/maintenance work before re-reading the slice
+ taskQueue.runAllRunnableTasks();
+
+ // the buffer should still contain the original data (region not evicted while ref held)
+ byte[] readBack = new byte[(int) fileLength];
+ slice.get(readBack);
+ assertArrayEquals(testData, readBack);
+ });
+ assertTrue(available);
+
+ }
+ ioExecutor.shutdown();
+ }
+
+ // Verifies that withByteBufferSlice returns false and the callback is not invoked after a
+ // region has been evicted. 2 regions of size(10), file size(8); eviction forced by cache pressure.
+ public void testWithByteBufferSliceReturnsFalseAfterEviction() throws Exception {
+ final int regionSize = (int) size(10);
+ Settings settings = Settings.builder()
+ .put(NODE_NAME_SETTING.getKey(), "node")
+ .put(SharedBlobCacheService.SHARED_CACHE_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(size(20)).getStringRep())
+ .put(SharedBlobCacheService.SHARED_CACHE_REGION_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(regionSize).getStringRep())
+ .put(SharedBlobCacheService.SHARED_CACHE_MMAP.getKey(), true)
+ .put("path.home", createTempDir())
+ .build();
+ final DeterministicTaskQueue taskQueue = new DeterministicTaskQueue();
+ try (
+ NodeEnvironment environment = new NodeEnvironment(settings, TestEnvironment.newEnvironment(settings));
+ var cacheService = new SharedBlobCacheService(
+ environment,
+ settings,
+ taskQueue.getThreadPool(),
+ // direct executor: cache I/O runs inline on the test thread, so no pool teardown needed
+ EsExecutors.DIRECT_EXECUTOR_SERVICE,
+ BlobCacheMetrics.NOOP
+ )
+ ) {
+ assertEquals(2, cacheService.freeRegionCount());
+
+ final long fileLength = size(8); // fits in one region
+ final var cacheKey1 = generateCacheKey();
+ SharedBlobCacheService.CacheFile cacheFile1 = cacheService.getCacheFile(
+ cacheKey1,
+ fileLength,
+ SharedBlobCacheService.CacheMissHandler.NOOP
+ );
+
+ // populate the region
+ byte[] testData = randomByteArrayOfLength((int) fileLength);
+ ByteBuffer writeBuffer = ByteBuffer.allocate(SharedBytes.PAGE_SIZE);
+ cacheFile1.populateAndRead(
+ ByteRange.of(0L, fileLength),
+ ByteRange.of(0L, fileLength),
+ (channel, pos, relativePos, len) -> len,
+ (channel, channelPos, streamFactory, relativePos, len, progressUpdater, completionListener) -> {
+ SharedBytes.copyToCacheFileAligned(
+ channel,
+ new java.io.ByteArrayInputStream(testData, relativePos, len),
+ channelPos,
+ relativePos,
+ len,
+ progressUpdater,
+ writeBuffer.clear()
+ );
+ ActionListener.completeWith(completionListener, () -> null);
+ },
+ "test"
+ );
+
+ // confirm the slice is accessible before eviction
+ assertTrue(cacheFile1.withByteBufferSlice(0, (int) fileLength, slice -> {}));
+
+ // fill the second region, then request a third key to force eviction of cacheKey1's region
+ cacheService.get(generateCacheKey(), fileLength, 0);
+ cacheService.get(generateCacheKey(), fileLength, 0);
+ taskQueue.runAllRunnableTasks();
+
+ // after eviction the action must not be invoked and the method must return false
+ boolean available = cacheFile1.withByteBufferSlice(
+ 0,
+ (int) fileLength,
+ slice -> { fail("action should not be invoked after eviction"); }
+ );
+ assertFalse(available);
+ }
+ }
+
+ // Verifies that withByteBufferSlice returns false when the requested range spans multiple
+ // regions. Regions of size(10), file size(25) spanning 3 regions; slice straddles the boundary.
+ public void testWithByteBufferSliceCrossRegionReturnsFalse() throws Exception {
+ final int regionSize = (int) size(10);
+ final long fileLength = size(25); // spans 3 regions
+ Settings settings = Settings.builder()
+ .put(NODE_NAME_SETTING.getKey(), "node")
+ .put(SharedBlobCacheService.SHARED_CACHE_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(size(50)).getStringRep())
+ .put(SharedBlobCacheService.SHARED_CACHE_REGION_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(regionSize).getStringRep())
+ .put(SharedBlobCacheService.SHARED_CACHE_MMAP.getKey(), true)
+ .put("path.home", createTempDir())
+ .build();
+ final DeterministicTaskQueue taskQueue = new DeterministicTaskQueue();
+ try (
+ NodeEnvironment environment = new NodeEnvironment(settings, TestEnvironment.newEnvironment(settings));
+ var cacheService = new SharedBlobCacheService(
+ environment,
+ settings,
+ taskQueue.getThreadPool(),
+ taskQueue.getThreadPool().executor(ThreadPool.Names.GENERIC),
+ BlobCacheMetrics.NOOP
+ )
+ ) {
+ final var cacheKey = generateCacheKey();
+ SharedBlobCacheService.CacheFile cacheFile = cacheService.getCacheFile(
+ cacheKey,
+ fileLength,
+ SharedBlobCacheService.CacheMissHandler.NOOP
+ );
+
+ // request a slice that spans the region boundary (region 0 -> region 1)
+ // region 0 covers [0, regionSize), region 1 covers [regionSize, 2*regionSize)
+ // NOTE(review): offsets assume size(10) >= 100 bytes — confirm the unit of size()
+ int crossBoundaryOffset = regionSize - 100;
+ int crossBoundaryLength = 200; // crosses into region 1
+ boolean available = cacheFile.withByteBufferSlice(crossBoundaryOffset, crossBoundaryLength, slice -> {
+ fail("action should not be invoked for cross-region slice");
+ });
+ assertFalse(available);
+ }
+ }
+
+ // Verifies that withByteBufferSlice returns false when mmap is disabled, even after the
+ // region has been fully populated. Single region of size(10), file size(8), mmap=false.
+ public void testWithByteBufferSliceNoMmapReturnsFalse() throws Exception {
+ final int regionSize = (int) size(10);
+ final long fileLength = size(8);
+ Settings settings = Settings.builder()
+ .put(NODE_NAME_SETTING.getKey(), "node")
+ .put(SharedBlobCacheService.SHARED_CACHE_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(size(50)).getStringRep())
+ .put(SharedBlobCacheService.SHARED_CACHE_REGION_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(regionSize).getStringRep())
+ .put(SharedBlobCacheService.SHARED_CACHE_MMAP.getKey(), false)
+ .put("path.home", createTempDir())
+ .build();
+ final DeterministicTaskQueue taskQueue = new DeterministicTaskQueue();
+ ExecutorService ioExecutor = Executors.newCachedThreadPool();
+ try (
+ NodeEnvironment environment = new NodeEnvironment(settings, TestEnvironment.newEnvironment(settings));
+ var cacheService = new SharedBlobCacheService(
+ environment,
+ settings,
+ taskQueue.getThreadPool(),
+ ioExecutor,
+ BlobCacheMetrics.NOOP
+ )
+ ) {
+ final var cacheKey = generateCacheKey();
+ SharedBlobCacheService.CacheFile cacheFile = cacheService.getCacheFile(
+ cacheKey,
+ fileLength,
+ SharedBlobCacheService.CacheMissHandler.NOOP
+ );
+
+ // populate the cache
+ byte[] testData = randomByteArrayOfLength((int) fileLength);
+ ByteBuffer writeBuffer = ByteBuffer.allocate(SharedBytes.PAGE_SIZE);
+ cacheFile.populateAndRead(
+ ByteRange.of(0L, fileLength),
+ ByteRange.of(0L, fileLength),
+ // range reader: reports the full requested length available
+ (channel, pos, relativePos, len) -> len,
+ (channel, channelPos, streamFactory, relativePos, len, progressUpdater, completionListener) -> {
+ SharedBytes.copyToCacheFileAligned(
+ channel,
+ new java.io.ByteArrayInputStream(testData, relativePos, len),
+ channelPos,
+ relativePos,
+ len,
+ progressUpdater,
+ writeBuffer.clear()
+ );
+ ActionListener.completeWith(completionListener, () -> null);
+ },
+ "test"
+ );
+
+ // without mmap, withByteBufferSlice should return false even with data populated
+ boolean available = cacheFile.withByteBufferSlice(
+ 0,
+ 100,
+ slice -> { fail("action should not be invoked when mmap is not enabled"); }
+ );
+ assertFalse(available);
+ }
+ ioExecutor.shutdown();
+ }
+
private record TestCacheKey(ShardId shardId, String file) implements SharedBlobCacheService.KeyBase {}
private static TestCacheKey randomTestCacheKey(ShardId shardId) {
diff --git a/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBytesTests.java b/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBytesTests.java
index 1f26935e24c83..e4bfef7336cbf 100644
--- a/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBytesTests.java
+++ b/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBytesTests.java
@@ -8,6 +8,7 @@
package org.elasticsearch.blobcache.shared;
import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Assertions;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.PathUtils;
import org.elasticsearch.env.Environment;
@@ -22,6 +23,7 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.nullValue;
public class SharedBytesTests extends ESTestCase {
@@ -122,6 +124,105 @@ public void testCopyAllWith0Padding() throws Exception {
}
}
+ // Verifies that byteBufferSlice returns a read-only buffer with correct content when mmap
+ // is enabled. Randomized region count (1-4) and region size (1-16 pages).
+ public void testByteBufferSliceMmap() throws Exception {
+ int regions = randomIntBetween(1, 4);
+ int regionSize = randomIntBetween(1, 16) * SharedBytes.PAGE_SIZE;
+ var nodeSettings = Settings.builder()
+ .put(Node.NODE_NAME_SETTING.getKey(), "node")
+ .put("path.home", createTempDir())
+ .putList(Environment.PATH_DATA_SETTING.getKey(), createTempDir().toString())
+ .build();
+ SharedBytes sharedBytes = null;
+ try (var nodeEnv = new NodeEnvironment(nodeSettings, TestEnvironment.newEnvironment(nodeSettings))) {
+ // mmap=true
+ sharedBytes = new SharedBytes(regions, regionSize, nodeEnv, ignored -> {}, ignored -> {}, true);
+ int region = randomIntBetween(0, regions - 1);
+ byte[] randomData = randomByteArrayOfLength(regionSize);
+ ByteBuffer tempBuffer = ByteBuffer.allocate(regionSize);
+ // write known data through the shared channel so the mmap view can be validated
+ SharedBytes.copyToCacheFileAligned(
+ sharedBytes.getFileChannel(region),
+ new ByteArrayInputStream(randomData),
+ 0,
+ writtenBytesCount -> {},
+ tempBuffer
+ );
+
+ SharedBytes.IO io = sharedBytes.getFileChannel(region);
+
+ // byteBufferSlice returns a non-null read-only buffer with correct data
+ int sliceOffset = randomIntBetween(0, regionSize / 2);
+ int sliceLength = randomIntBetween(1, regionSize - sliceOffset);
+ ByteBuffer slice = io.byteBufferSlice(sliceOffset, sliceLength);
+ assertNotNull(slice);
+ assertTrue(slice.isReadOnly());
+ assertEquals(sliceLength, slice.remaining());
+ byte[] sliceData = new byte[sliceLength];
+ slice.get(sliceData);
+ for (int i = 0; i < sliceLength; i++) {
+ assertEquals(randomData[sliceOffset + i], sliceData[i]);
+ }
+ } finally {
+ if (sharedBytes != null) {
+ sharedBytes.decRef();
+ }
+ }
+ }
+
+ // Verifies that byteBufferSlice returns null when mmap is disabled.
+ // Randomized region count (1-4) and region size (1-16 pages), mmap=false.
+ public void testByteBufferSliceNoMmap() throws Exception {
+ int regions = randomIntBetween(1, 4);
+ int regionSize = randomIntBetween(1, 16) * SharedBytes.PAGE_SIZE;
+ var nodeSettings = Settings.builder()
+ .put(Node.NODE_NAME_SETTING.getKey(), "node")
+ .put("path.home", createTempDir())
+ .putList(Environment.PATH_DATA_SETTING.getKey(), createTempDir().toString())
+ .build();
+ SharedBytes sharedBytes = null;
+ try (var nodeEnv = new NodeEnvironment(nodeSettings, TestEnvironment.newEnvironment(nodeSettings))) {
+ // mmap=false
+ sharedBytes = new SharedBytes(regions, regionSize, nodeEnv, ignored -> {}, ignored -> {}, false);
+ int region = randomIntBetween(0, regions - 1);
+ SharedBytes.IO io = sharedBytes.getFileChannel(region);
+
+ // no data written on purpose: availability depends only on the mmap flag
+ // byteBufferSlice returns null when not mmap'd
+ assertThat(io.byteBufferSlice(0, regionSize), nullValue());
+ } finally {
+ if (sharedBytes != null) {
+ sharedBytes.decRef();
+ }
+ }
+ }
+
+ // Verifies that byteBufferSlice rejects out-of-bounds requests: offset+length exceeding
+ // region size, and negative offset. Single region of 4 pages, mmap=true.
+ public void testByteBufferSliceBoundsCheck() throws Exception {
+ int regions = 1;
+ int regionSize = 4 * SharedBytes.PAGE_SIZE;
+ var nodeSettings = Settings.builder()
+ .put(Node.NODE_NAME_SETTING.getKey(), "node")
+ .put("path.home", createTempDir())
+ .putList(Environment.PATH_DATA_SETTING.getKey(), createTempDir().toString())
+ .build();
+ SharedBytes sharedBytes = null;
+ try (var nodeEnv = new NodeEnvironment(nodeSettings, TestEnvironment.newEnvironment(nodeSettings))) {
+ sharedBytes = new SharedBytes(regions, regionSize, nodeEnv, ignored -> {}, ignored -> {}, true);
+ SharedBytes.IO io = sharedBytes.getFileChannel(0);
+
+ // checkOffsets fails via assert when -ea is enabled, else throws IllegalArgumentException
+ var expectedType = Assertions.ENABLED ? AssertionError.class : IllegalArgumentException.class;
+ // position + length exceeds region size
+ expectThrows(expectedType, () -> io.byteBufferSlice(regionSize - 10, 20));
+ // negative position
+ expectThrows(expectedType, () -> io.byteBufferSlice(-1, 10));
+ } finally {
+ if (sharedBytes != null) {
+ sharedBytes.decRef();
+ }
+ }
+ }
+
/**
* Test that mmap'd SharedBytes instances release their mapped memory on close, so that the OS
* can reclaim disk space immediately. Without proper unmapping, each iteration leaks the cache
diff --git a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/input/FrozenIndexInput.java b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/input/FrozenIndexInput.java
index ab8ae11a56ca8..04620a753c4fc 100644
--- a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/input/FrozenIndexInput.java
+++ b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/input/FrozenIndexInput.java
@@ -17,6 +17,8 @@
import org.elasticsearch.blobcache.common.ByteRange;
import org.elasticsearch.blobcache.shared.SharedBlobCacheService;
import org.elasticsearch.blobcache.shared.SharedBytes;
+import org.elasticsearch.core.CheckedConsumer;
+import org.elasticsearch.core.DirectAccessInput;
import org.elasticsearch.index.snapshots.blobstore.BlobStoreIndexShardSnapshot.FileInfo;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshots;
@@ -28,7 +30,7 @@
import java.io.InputStream;
import java.nio.ByteBuffer;
-public final class FrozenIndexInput extends MetadataCachingIndexInput {
+public final class FrozenIndexInput extends MetadataCachingIndexInput implements DirectAccessInput {
private static final Logger logger = LogManager.getLogger(FrozenIndexInput.class);
@@ -105,6 +107,11 @@ private FrozenIndexInput(FrozenIndexInput input) {
this.cacheFile = input.cacheFile.copy();
}
+ /**
+ * Provides zero-copy access to the cached data for SIMD vector scoring
+ * (see {@code DirectAccessInput}). Translates the caller-relative offset by
+ * this input's own slice offset and delegates to the frozen cache file.
+ *
+ * <p>NOTE(review): there is no visible check that {@code offset + length} stays
+ * within this input's length — presumably the cache file rejects out-of-range
+ * requests by returning {@code false}; confirm.
+ */
+ @Override
+ public boolean withByteBufferSlice(long offset, long length, CheckedConsumer action) throws IOException {
+ return cacheFile.withByteBufferSlice(offset + this.offset, Math.toIntExact(length), action);
+ }
+
@Override
protected void readWithoutBlobCache(ByteBuffer b) throws Exception {
final long position = getAbsolutePosition();
diff --git a/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/xpack/searchablesnapshots/store/SearchableSnapshotDirectoryFactory.java b/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/xpack/searchablesnapshots/store/SearchableSnapshotDirectoryFactory.java
index a3fba29958313..e13ac1639f2b1 100644
--- a/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/xpack/searchablesnapshots/store/SearchableSnapshotDirectoryFactory.java
+++ b/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/xpack/searchablesnapshots/store/SearchableSnapshotDirectoryFactory.java
@@ -85,11 +85,13 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
+import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
/**
@@ -119,6 +121,196 @@ public static Directory newDirectory(Path path) {
return new WriteOnceSnapshotDirectory(path);
}
+ /**
+ * Returns a read-only {@link Directory} backed by searchable snapshot infrastructure,
+ * wrapping an existing Lucene index on disk. The original index files are read directly
+ * via {@link FsBlobContainer} (zero-copy, no heap buffering).
+ *
+ * @param indexPath path to the existing Lucene index directory
+ * @param workPath scratch directory for cache and node environment files
+ * @return a read-only directory; closing it also tears down the cache services,
+ * node environment, and thread pool created for it
+ */
+ public static Directory newDirectoryFromIndex(Path indexPath, Path workPath) throws IOException {
+ return new DiskBackedSnapshotDirectory(indexPath, workPath);
+ }
+
+ // -- DiskBackedSnapshotDirectory --
+
+ // Read-only Directory over an existing on-disk Lucene index, served through the
+ // searchable-snapshot read path. Owns (and closes) all backing infrastructure.
+ static class DiskBackedSnapshotDirectory extends Directory {
+ private final SearchableSnapshotDirectory delegate;
+
+ // infrastructure
+ private final ThreadPool threadPool;
+ private final CacheService cacheService;
+ private final ClusterService clusterService;
+ private final SharedBlobCacheService sharedBlobCacheService;
+ private final NodeEnvironment nodeEnvironment;
+
+ DiskBackedSnapshotDirectory(Path indexPath, Path workPath) throws IOException {
+ boolean success = false;
+ ThreadPool tp = null;
+ NodeEnvironment ne = null;
+ ClusterService cs = null;
+ CacheService cas = null;
+ SharedBlobCacheService sbcs = null;
+ try {
+ // Enumerate index files, build snapshot metadata
+ List fileInfos = new ArrayList<>();
+ long totalSize = 0;
+ try (FSDirectory fsDir = FSDirectory.open(indexPath)) {
+ for (String file : fsDir.listAll()) {
+ if (file.equals("write.lock")) continue;
+ long fileLength = fsDir.fileLength(file);
+ totalSize += fileLength;
+ StoreFileMetadata metadata = new StoreFileMetadata(
+ file,
+ fileLength,
+ "0",
+ IndexVersion.current().luceneVersion().toString()
+ );
+ // part size > file length — presumably keeps each file a single blob part; confirm FileInfo semantics
+ fileInfos.add(new BlobStoreIndexShardSnapshot.FileInfo(file, metadata, ByteSizeValue.ofBytes(fileLength + 1)));
+ }
+ }
+
+ // FsBlobContainer reads directly from the index directory (blob name = file name)
+ FsBlobStore blobStore = new FsBlobStore(8192, indexPath, true);
+ BlobContainer blobContainer = new FsBlobContainer(blobStore, BlobPath.EMPTY, indexPath);
+
+ BlobStoreIndexShardSnapshot snapshot = new BlobStoreIndexShardSnapshot("snapshotId", fileInfos, 0L, 0L, 0, 0L);
+
+ // size the cache to hold the whole index
+ long cacheSize = totalSize;
+
+ SnapshotId snapshotId = new SnapshotId("_name", "_uuid");
+ IndexId indexId = new IndexId("_name", "_uuid");
+ ShardId shardId = new ShardId("_name", "_uuid", 0);
+ Path topDir = workPath.resolve(shardId.getIndex().getUUID());
+ Path shardDir = topDir.resolve(Integer.toString(shardId.getId()));
+ ShardPath shardPath = new ShardPath(false, shardDir, shardDir, shardId);
+ Path cacheDir = Files.createDirectories(CacheService.resolveSnapshotCache(shardDir).resolve(snapshotId.getUUID()));
+
+ tp = new SimpleThreadPool("tp", SearchableSnapshots.executorBuilders(Settings.EMPTY));
+ ne = newNodeEnvironment(Settings.EMPTY, workPath, cacheSize);
+ cs = createClusterService(tp, clusterSettings());
+
+ cas = new CacheService(Settings.EMPTY, cs, tp, new PersistentCache(ne));
+ cas.start();
+ sbcs = defaultFrozenCacheService(tp, ne, workPath, cacheSize);
+
+ SearchableSnapshotDirectory dir = new SearchableSnapshotDirectory(
+ () -> blobContainer,
+ () -> snapshot,
+ new NoopBlobStoreCacheService(tp),
+ "_repo",
+ snapshotId,
+ indexId,
+ shardId,
+ buildIndexSettings(),
+ () -> 0L,
+ cas,
+ cacheDir,
+ shardPath,
+ tp,
+ sbcs
+ );
+
+ // loadSnapshot asserts ThreadPool.assertCurrentThreadPool(GENERIC)
+ SearchableSnapshotRecoveryState recoveryState = createRecoveryState(true);
+ PlainActionFuture f = new PlainActionFuture<>();
+ CountDownLatch latch = new CountDownLatch(1);
+ tp.generic().execute(() -> {
+ try {
+ dir.loadSnapshot(recoveryState, () -> false, f);
+ } finally {
+ latch.countDown();
+ }
+ });
+ latch.await();
+ f.get(); // propagates any loadSnapshot failure
+
+ this.delegate = dir;
+ this.threadPool = tp;
+ this.nodeEnvironment = ne;
+ this.clusterService = cs;
+ this.cacheService = cas;
+ this.sharedBlobCacheService = sbcs;
+ success = true;
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new IOException("Interrupted while loading snapshot", e);
+ } catch (Exception e) {
+ throw new IOException("Failed to create searchable snapshot directory", e);
+ } finally {
+ // on any failure, tear down whatever infrastructure was already created
+ if (success == false) {
+ IOUtils.closeWhileHandlingException(sbcs, cs, cas, ne);
+ if (tp != null) {
+ ThreadPool.terminate(tp, 10, TimeUnit.SECONDS);
+ }
+ }
+ }
+ }
+
+ @Override
+ public String[] listAll() throws IOException {
+ return delegate.listAll();
+ }
+
+ @Override
+ public void deleteFile(String name) {
+ throw new UnsupportedOperationException("read-only directory");
+ }
+
+ @Override
+ public long fileLength(String name) throws IOException {
+ return delegate.fileLength(name);
+ }
+
+ @Override
+ public IndexOutput createOutput(String name, IOContext context) {
+ throw new UnsupportedOperationException("read-only directory");
+ }
+
+ @Override
+ public IndexOutput createTempOutput(String prefix, String suffix, IOContext context) {
+ throw new UnsupportedOperationException("read-only directory");
+ }
+
+ @Override
+ public void sync(Collection names) {
+ throw new UnsupportedOperationException("read-only directory");
+ }
+
+ @Override
+ public void syncMetaData() {
+ throw new UnsupportedOperationException("read-only directory");
+ }
+
+ @Override
+ public void rename(String source, String dest) {
+ throw new UnsupportedOperationException("read-only directory");
+ }
+
+ @Override
+ public IndexInput openInput(String name, IOContext context) throws IOException {
+ return delegate.openInput(name, context);
+ }
+
+ @Override
+ public Lock obtainLock(String name) {
+ throw new UnsupportedOperationException("read-only directory");
+ }
+
+ @Override
+ public Set getPendingDeletions() {
+ return Set.of();
+ }
+
+ @Override
+ public void close() throws IOException {
+ // NOTE(review): the cache services are closed before delegate.close() —
+ // confirm the delegate does not need the cache on close, or close it first
+ IOUtils.closeWhileHandlingException(sharedBlobCacheService, clusterService, cacheService, nodeEnvironment);
+ delegate.close();
+ ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS);
+ }
+ }
+
// -- WriteOnceSnapshotDirectory --
static class WriteOnceSnapshotDirectory extends Directory {
@@ -286,7 +478,7 @@ private void materializeSnapshotImpl() throws IOException {
);
// load the snapshot so it's ready for reads
- SearchableSnapshotRecoveryState recoveryState = createRecoveryState(false);
+ SearchableSnapshotRecoveryState recoveryState = createRecoveryState(true);
final PlainActionFuture f = new PlainActionFuture<>();
delegate.loadSnapshot(recoveryState, () -> false, f);
try {