diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java
index 83beae607dc5..8cef8d29ce74 100644
--- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java
@@ -31,7 +31,7 @@
import org.apache.lucene.index.SegmentWriteState;
/**
- * A Faiss-based format to create and search vector indexes, using {@link LibFaissC} to interact
+ * A Faiss-based format to create and search vector indexes, using {@link FaissLibrary} to interact
* with the native library.
*
*
The Faiss index is configured using its flexible indexMap;
- private final Arena arena;
+ private final Map indexMap;
private boolean closed;
public FaissKnnVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader)
throws IOException {
this.rawVectorsReader = rawVectorsReader;
this.indexMap = new HashMap<>();
- this.arena = Arena.ofShared();
- this.closed = false;
List fieldMetaList = new ArrayList<>();
String metaFileName =
@@ -125,9 +115,11 @@ public FaissKnnVectorsReader(SegmentReadState state, FlatVectorsReader rawVector
CodecUtil.retrieveChecksum(data);
for (FieldMeta fieldMeta : fieldMetaList) {
- if (indexMap.put(fieldMeta.fieldInfo.name, loadField(data, arena, fieldMeta)) != null) {
- throw new CorruptIndexException("Duplicate field: " + fieldMeta.fieldInfo.name, meta);
+ if (indexMap.containsKey(fieldMeta.name)) {
+ throw new CorruptIndexException("Duplicate field: " + fieldMeta.name, meta);
}
+ IndexInput indexInput = data.slice(fieldMeta.name, fieldMeta.offset, fieldMeta.length);
+ indexMap.put(fieldMeta.name, FaissLibrary.INSTANCE.readIndex(indexInput));
}
} catch (Throwable t) {
IOUtils.closeWhileSuppressingExceptions(t, this);
@@ -150,21 +142,7 @@ private static FieldMeta parseNextField(IndexInput meta, SegmentReadState state)
long dataOffset = meta.readLong();
long dataLength = meta.readLong();
- return new FieldMeta(fieldInfo, dataOffset, dataLength);
- }
-
- @SuppressWarnings("restricted") // TODO: encapsulate the unsafeness into the LibFaissC
- private static IndexEntry loadField(IndexInput data, Arena arena, FieldMeta fieldMeta)
- throws IOException {
- int ioFlags = FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY;
-
- // Read index into memory
- MemorySegment indexPointer =
- indexRead(data.slice(fieldMeta.fieldInfo.name, fieldMeta.offset, fieldMeta.length), ioFlags)
- // Ensure timely cleanup
- .reinterpret(arena, LibFaissC::freeIndex);
-
- return new IndexEntry(indexPointer, fieldMeta.fieldInfo.getVectorSimilarityFunction());
+ return new FieldMeta(fieldInfo.name, dataOffset, dataLength);
}
@Override
@@ -188,9 +166,9 @@ public ByteVectorValues getByteVectorValues(String field) {
@Override
public void search(String field, float[] vector, KnnCollector knnCollector, Bits acceptDocs) {
- IndexEntry entry = indexMap.get(field);
- if (entry != null) {
- indexSearch(entry.indexPointer, entry.function, vector, knnCollector, acceptDocs);
+ FaissLibrary.Index index = indexMap.get(field);
+ if (index != null) {
+ index.search(vector, knnCollector, acceptDocs);
}
}
@@ -210,12 +188,16 @@ public Map getOffHeapByteSize(FieldInfo fieldInfo) {
@Override
public void close() throws IOException {
if (closed == false) {
+ // Close all indexes
+ for (FaissLibrary.Index index : indexMap.values()) {
+ index.close();
+ }
+ indexMap.clear();
+
+ IOUtils.close(rawVectorsReader, data);
closed = true;
- IOUtils.close(rawVectorsReader, arena::close, data, indexMap::clear);
}
}
- private record FieldMeta(FieldInfo fieldInfo, long offset, long length) {}
-
- private record IndexEntry(MemorySegment indexPointer, VectorSimilarityFunction function) {}
+ private record FieldMeta(String name, long offset, long length) {}
}
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java
index a81864a9c25e..f41986d8e8ce 100644
--- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java
@@ -21,14 +21,8 @@
import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_CODEC_NAME;
import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_EXTENSION;
import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_CURRENT;
-import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.FAISS_IO_FLAG_MMAP;
-import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.FAISS_IO_FLAG_READ_ONLY;
-import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.createIndex;
-import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexWrite;
import java.io.IOException;
-import java.lang.foreign.Arena;
-import java.lang.foreign.MemorySegment;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -43,7 +37,6 @@
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
-import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IOUtils;
@@ -154,26 +147,23 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
}
}
- @SuppressWarnings("restricted") // TODO: encapsulate the unsafeness into the LibFaissC
private void writeFloatField(
FieldInfo fieldInfo, FloatVectorValues floatVectorValues, IntToIntFunction oldToNewDocId)
throws IOException {
int number = fieldInfo.number;
meta.writeInt(number);
- // Write index to temp file and deallocate from memory
- try (Arena temp = Arena.ofConfined()) {
- VectorSimilarityFunction function = fieldInfo.getVectorSimilarityFunction();
- MemorySegment indexPointer =
- createIndex(description, indexParams, function, floatVectorValues, oldToNewDocId)
- // Ensure timely cleanup
- .reinterpret(temp, LibFaissC::freeIndex);
-
- int ioFlags = FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY;
+ try (FaissLibrary.Index index =
+ FaissLibrary.INSTANCE.createIndex(
+ description,
+ indexParams,
+ fieldInfo.getVectorSimilarityFunction(),
+ floatVectorValues,
+ oldToNewDocId)) {
// Write index
long dataOffset = data.getFilePointer();
- indexWrite(indexPointer, data, ioFlags);
+ index.write(data);
long dataLength = data.getFilePointer() - dataOffset;
meta.writeLong(dataOffset);
@@ -233,7 +223,7 @@ public int size() {
@Override
public FloatVectorValues copy() {
- return new BufferedFloatVectorValues(floats, dimension, docIdSet);
+ throw new AssertionError("Should not be called");
}
@Override
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java
new file mode 100644
index 000000000000..e7837692222d
--- /dev/null
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.sandbox.codecs.faiss;
+
+import java.io.Closeable;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.hnsw.IntToIntFunction;
+
+/**
+ * Minimal interface to create and query Faiss indexes.
+ *
+ * @lucene.experimental
+ */
+interface FaissLibrary {
+ FaissLibrary INSTANCE = new FaissLibraryNativeImpl();
+
+ // TODO: Use SIMD version at runtime. The "faiss_c" library is linked to the main "faiss" library,
+ // which does not use SIMD instructions. However, there are SIMD versions of "faiss" (like
+ // "faiss_avx2", "faiss_avx512", "faiss_sve", etc.) available, which can be used by changing the
+ // dependencies of "faiss_c" using the "patchelf" utility. Figure out how to do this dynamically,
+ // or via modifications to upstream Faiss.
+ String NAME = "faiss_c";
+ String VERSION = "1.11.0";
+
+ interface Index extends Closeable {
+ void search(float[] query, KnnCollector knnCollector, Bits acceptDocs);
+
+ void write(IndexOutput output);
+ }
+
+ Index createIndex(
+ String description,
+ String indexParams,
+ VectorSimilarityFunction function,
+ FloatVectorValues floatVectorValues,
+ IntToIntFunction oldToNewDocId);
+
+ Index readIndex(IndexInput input);
+}
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java
new file mode 100644
index 000000000000..d72e65eca860
--- /dev/null
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.sandbox.codecs.faiss;
+
+import static java.lang.foreign.ValueLayout.ADDRESS;
+import static java.lang.foreign.ValueLayout.JAVA_BYTE;
+import static java.lang.foreign.ValueLayout.JAVA_FLOAT;
+import static java.lang.foreign.ValueLayout.JAVA_LONG;
+import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.sandbox.codecs.faiss.FaissNativeWrapper.Exception.handleException;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.lang.foreign.Arena;
+import java.lang.foreign.FunctionDescriptor;
+import java.lang.foreign.Linker;
+import java.lang.foreign.MemorySegment;
+import java.lang.invoke.MethodHandle;
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.MethodType;
+import java.nio.ByteOrder;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.FixedBitSet;
+import org.apache.lucene.util.hnsw.IntToIntFunction;
+
+/**
+ * A native implementation of {@link FaissLibrary} using {@link FaissNativeWrapper}.
+ *
+ * @lucene.experimental
+ */
+@SuppressWarnings("restricted") // uses unsafe calls
+final class FaissLibraryNativeImpl implements FaissLibrary {
+ private final FaissNativeWrapper wrapper;
+
+ FaissLibraryNativeImpl() {
+ this.wrapper = new FaissNativeWrapper();
+ }
+
+ private static MemorySegment getStub(
+ Arena arena, MethodHandle target, FunctionDescriptor descriptor) {
+ return Linker.nativeLinker().upcallStub(target, descriptor, arena);
+ }
+
+ private static final int BUFFER_SIZE = 256 * 1024 * 1024; // 256 MB
+
+ @SuppressWarnings("unused") // called using a MethodHandle
+ private static long writeBytes(
+ IndexOutput output, MemorySegment inputPointer, long itemSize, long numItems)
+ throws IOException {
+ long size = itemSize * numItems;
+ inputPointer = inputPointer.reinterpret(size);
+
+ if (size <= BUFFER_SIZE) { // simple case, avoid buffering
+ output.writeBytes(inputPointer.toArray(JAVA_BYTE), (int) size);
+ } else { // copy buffered number of bytes repeatedly
+ byte[] bytes = new byte[BUFFER_SIZE];
+ for (long offset = 0; offset < size; offset += BUFFER_SIZE) {
+ int length = (int) Math.min(size - offset, BUFFER_SIZE);
+ MemorySegment.copy(inputPointer, JAVA_BYTE, offset, bytes, 0, length);
+ output.writeBytes(bytes, length);
+ }
+ }
+ return numItems;
+ }
+
+ @SuppressWarnings("unused") // called using a MethodHandle
+ private static long readBytes(
+ IndexInput input, MemorySegment outputPointer, long itemSize, long numItems)
+ throws IOException {
+ long size = itemSize * numItems;
+ outputPointer = outputPointer.reinterpret(size);
+
+ if (size <= BUFFER_SIZE) { // simple case, avoid buffering
+ byte[] bytes = new byte[(int) size];
+ input.readBytes(bytes, 0, bytes.length);
+ MemorySegment.copy(bytes, 0, outputPointer, JAVA_BYTE, 0, bytes.length);
+ } else { // copy buffered number of bytes repeatedly
+ byte[] bytes = new byte[BUFFER_SIZE];
+ for (long offset = 0; offset < size; offset += BUFFER_SIZE) {
+ int length = (int) Math.min(size - offset, BUFFER_SIZE);
+ input.readBytes(bytes, 0, length);
+ MemorySegment.copy(bytes, 0, outputPointer, JAVA_BYTE, offset, length);
+ }
+ }
+ return numItems;
+ }
+
+ private static final MethodHandle WRITE_BYTES_HANDLE;
+ private static final MethodHandle READ_BYTES_HANDLE;
+
+ static {
+ try {
+ MethodHandles.Lookup lookup = MethodHandles.lookup();
+
+ WRITE_BYTES_HANDLE =
+ lookup.findStatic(
+ FaissLibraryNativeImpl.class,
+ "writeBytes",
+ MethodType.methodType(
+ long.class, IndexOutput.class, MemorySegment.class, long.class, long.class));
+
+ READ_BYTES_HANDLE =
+ lookup.findStatic(
+ FaissLibraryNativeImpl.class,
+ "readBytes",
+ MethodType.methodType(
+ long.class, IndexInput.class, MemorySegment.class, long.class, long.class));
+
+ } catch (NoSuchMethodException | IllegalAccessException e) {
+ throw new LinkageError(
+ "FaissLibraryNativeImpl reader / writer functions are missing or inaccessible", e);
+ }
+ }
+
+ private static final Map FUNCTION_TO_METRIC =
+ Map.of(
+ // Mapped from faiss/MetricType.h
+ DOT_PRODUCT, 0,
+ EUCLIDEAN, 1);
+
+ private static int functionToMetric(VectorSimilarityFunction function) {
+ Integer metric = FUNCTION_TO_METRIC.get(function);
+ if (metric == null) {
+ throw new UnsupportedOperationException("Similarity function not supported: " + function);
+ }
+ return metric;
+ }
+
+ // Invert FUNCTION_TO_METRIC
+ private static final Map METRIC_TO_FUNCTION =
+ FUNCTION_TO_METRIC.entrySet().stream()
+ .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
+
+ private static VectorSimilarityFunction metricToFunction(int metric) {
+ VectorSimilarityFunction function = METRIC_TO_FUNCTION.get(metric);
+ if (function == null) {
+ throw new UnsupportedOperationException("Metric not supported: " + metric);
+ }
+ return function;
+ }
+
+ @Override
+ public FaissLibrary.Index createIndex(
+ String description,
+ String indexParams,
+ VectorSimilarityFunction function,
+ FloatVectorValues floatVectorValues,
+ IntToIntFunction oldToNewDocId) {
+
+ try (Arena temp = Arena.ofConfined()) {
+ int size = floatVectorValues.size();
+ int dimension = floatVectorValues.dimension();
+ int metric = functionToMetric(function);
+
+ // Create an index
+ MemorySegment pointer = temp.allocate(ADDRESS);
+ handleException(
+ wrapper.faiss_index_factory(pointer, dimension, temp.allocateFrom(description), metric));
+
+ MemorySegment indexPointer = pointer.get(ADDRESS, 0);
+
+ // Set index params
+ handleException(wrapper.faiss_ParameterSpace_new(pointer));
+ MemorySegment parameterSpacePointer =
+ pointer
+ .get(ADDRESS, 0)
+ // Ensure timely cleanup
+ .reinterpret(temp, wrapper::faiss_ParameterSpace_free);
+
+ handleException(
+ wrapper.faiss_ParameterSpace_set_index_parameters(
+ parameterSpacePointer, indexPointer, temp.allocateFrom(indexParams)));
+
+ // TODO: Improve memory usage (with a tradeoff in performance) by batched indexing, see:
+ // - https://github.com/opensearch-project/k-NN/issues/1506
+ // - https://github.com/opensearch-project/k-NN/issues/1938
+
+ // Allocate docs in native memory
+ MemorySegment docs = temp.allocate(JAVA_FLOAT, (long) size * dimension);
+ long docsOffset = 0;
+ long perDocByteSize = dimension * JAVA_FLOAT.byteSize();
+
+ // Allocate ids in native memory
+ MemorySegment ids = temp.allocate(JAVA_LONG, size);
+ int idsIndex = 0;
+
+ KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
+ for (int i = iterator.nextDoc(); i != NO_MORE_DOCS; i = iterator.nextDoc()) {
+ int id = oldToNewDocId.apply(i);
+ ids.setAtIndex(JAVA_LONG, idsIndex, id);
+ idsIndex++;
+
+ float[] vector = floatVectorValues.vectorValue(iterator.index());
+ MemorySegment.copy(vector, 0, docs, JAVA_FLOAT, docsOffset, vector.length);
+ docsOffset += perDocByteSize;
+ }
+
+ // Train index
+ int isTrained = wrapper.faiss_Index_is_trained(indexPointer);
+ if (isTrained == 0) {
+ handleException(wrapper.faiss_Index_train(indexPointer, size, docs));
+ }
+
+ // Add docs to index
+ handleException(wrapper.faiss_Index_add_with_ids(indexPointer, size, docs, ids));
+
+ return new Index(indexPointer);
+
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ // See flags defined in c_api/index_io_c.h
+ private static final int FAISS_IO_FLAG_MMAP = 1;
+ private static final int FAISS_IO_FLAG_READ_ONLY = 2;
+
+ @Override
+ public FaissLibrary.Index readIndex(IndexInput input) {
+ try (Arena temp = Arena.ofConfined()) {
+ MethodHandle readerHandle = READ_BYTES_HANDLE.bindTo(input);
+ MemorySegment readerStub =
+ getStub(
+ temp, readerHandle, FunctionDescriptor.of(JAVA_LONG, ADDRESS, JAVA_LONG, JAVA_LONG));
+
+ MemorySegment pointer = temp.allocate(ADDRESS);
+ handleException(wrapper.faiss_CustomIOReader_new(pointer, readerStub));
+ MemorySegment customIOReaderPointer =
+ pointer
+ .get(ADDRESS, 0)
+ // Ensure timely cleanup
+ .reinterpret(temp, wrapper::faiss_CustomIOReader_free);
+
+ // Read index
+ handleException(
+ wrapper.faiss_read_index_custom(
+ customIOReaderPointer, FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY, pointer));
+ MemorySegment indexPointer = pointer.get(ADDRESS, 0);
+
+ return new Index(indexPointer);
+ }
+ }
+
+ private class Index implements FaissLibrary.Index {
+ @FunctionalInterface
+ private interface FloatToFloatFunction {
+ float scale(float score);
+ }
+
+ private final Arena arena;
+ private final MemorySegment indexPointer;
+ private final FloatToFloatFunction scaler;
+ private boolean closed;
+
+ private Index(MemorySegment indexPointer) {
+ this.arena = Arena.ofShared();
+ this.indexPointer =
+ indexPointer
+ // Ensure timely cleanup
+ .reinterpret(arena, wrapper::faiss_Index_free);
+
+ // Get underlying function
+ int metricType = wrapper.faiss_Index_metric_type(indexPointer);
+ VectorSimilarityFunction function = metricToFunction(metricType);
+
+ // Scale Faiss distances to Lucene scores, see VectorSimilarityFunction.java
+ this.scaler =
+ switch (function) {
+ case DOT_PRODUCT ->
+ // distance in Faiss === dotProduct in Lucene
+ distance -> Math.max((1 + distance) / 2, 0);
+
+ case EUCLIDEAN ->
+ // distance in Faiss === squareDistance in Lucene
+ distance -> 1 / (1 + distance);
+
+ case COSINE, MAXIMUM_INNER_PRODUCT -> throw new AssertionError("Should not reach here");
+ };
+
+ this.closed = false;
+ }
+
+ @Override
+ public void close() {
+ if (closed == false) {
+ arena.close();
+ closed = true;
+ }
+ }
+
+ @Override
+ public void search(float[] query, KnnCollector knnCollector, Bits acceptDocs) {
+ try (Arena temp = Arena.ofConfined()) {
+ FixedBitSet fixedBitSet =
+ switch (acceptDocs) {
+ case null -> null;
+ case FixedBitSet bitSet -> bitSet;
+ // TODO: Add optimized case for SparseFixedBitSet
+ case Bits bits -> FixedBitSet.copyOf(bits);
+ };
+
+ // Allocate queries in native memory
+ MemorySegment queries = temp.allocateFrom(JAVA_FLOAT, query);
+
+ // Faiss knn search
+ int k = knnCollector.k();
+ MemorySegment distancesPointer = temp.allocate(JAVA_FLOAT, k);
+ MemorySegment idsPointer = temp.allocate(JAVA_LONG, k);
+
+ MemorySegment localIndex = indexPointer.reinterpret(temp, null);
+ if (fixedBitSet == null) {
+ // Search without runtime filters
+ handleException(
+ wrapper.faiss_Index_search(localIndex, 1, queries, k, distancesPointer, idsPointer));
+ } else {
+ MemorySegment pointer = temp.allocate(ADDRESS);
+
+ long[] bits = fixedBitSet.getBits();
+ MemorySegment nativeBits =
+ // Use LITTLE_ENDIAN to convert long[] -> uint8_t*
+ temp.allocateFrom(JAVA_LONG.withOrder(ByteOrder.LITTLE_ENDIAN), bits);
+
+ handleException(
+ wrapper.faiss_IDSelectorBitmap_new(pointer, fixedBitSet.length(), nativeBits));
+ MemorySegment idSelectorBitmapPointer =
+ pointer
+ .get(ADDRESS, 0)
+ // Ensure timely cleanup
+ .reinterpret(temp, wrapper::faiss_IDSelectorBitmap_free);
+
+ handleException(wrapper.faiss_SearchParameters_new(pointer, idSelectorBitmapPointer));
+ MemorySegment searchParametersPointer =
+ pointer
+ .get(ADDRESS, 0)
+ // Ensure timely cleanup
+ .reinterpret(temp, wrapper::faiss_SearchParameters_free);
+
+ // Search with runtime filters
+ handleException(
+ wrapper.faiss_Index_search_with_params(
+ localIndex,
+ 1,
+ queries,
+ k,
+ searchParametersPointer,
+ distancesPointer,
+ idsPointer));
+ }
+
+ // Record hits
+ for (int i = 0; i < k; i++) {
+ int id = (int) idsPointer.getAtIndex(JAVA_LONG, i);
+
+ // Not enough results
+ if (id == -1) {
+ break;
+ }
+
+ // Collect result
+ float distance = distancesPointer.getAtIndex(JAVA_FLOAT, i);
+ knnCollector.collect(id, scaler.scale(distance));
+ }
+ }
+ }
+
+ @Override
+ public void write(IndexOutput output) {
+ try (Arena temp = Arena.ofConfined()) {
+ MethodHandle writerHandle = WRITE_BYTES_HANDLE.bindTo(output);
+ MemorySegment writerStub =
+ getStub(
+ temp,
+ writerHandle,
+ FunctionDescriptor.of(JAVA_LONG, ADDRESS, JAVA_LONG, JAVA_LONG));
+
+ MemorySegment pointer = temp.allocate(ADDRESS);
+ handleException(wrapper.faiss_CustomIOWriter_new(pointer, writerStub));
+ MemorySegment customIOWriterPointer =
+ pointer
+ .get(ADDRESS, 0)
+ // Ensure timely cleanup
+ .reinterpret(temp, wrapper::faiss_CustomIOWriter_free);
+
+ // Write index
+ handleException(
+ wrapper.faiss_write_index_custom(
+ indexPointer, customIOWriterPointer, FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY));
+ }
+ }
+ }
+}
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java
new file mode 100644
index 000000000000..575ebf953224
--- /dev/null
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java
@@ -0,0 +1,447 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.sandbox.codecs.faiss;
+
+import static java.lang.foreign.ValueLayout.ADDRESS;
+import static java.lang.foreign.ValueLayout.JAVA_INT;
+import static java.lang.foreign.ValueLayout.JAVA_LONG;
+
+import java.lang.foreign.FunctionDescriptor;
+import java.lang.foreign.Linker;
+import java.lang.foreign.MemorySegment;
+import java.lang.foreign.SymbolLookup;
+import java.lang.invoke.MethodHandle;
+import java.util.Arrays;
+import java.util.Locale;
+
+/**
+ * Utility class to wrap necessary functions of the native C API of Faiss
+ * using Project Panama.
+ *
+ * @lucene.experimental
+ */
+@SuppressWarnings("restricted") // uses unsafe calls
+final class FaissNativeWrapper {
+ static {
+ System.loadLibrary(FaissLibrary.NAME);
+ }
+
+ private static MethodHandle getHandle(String functionName, FunctionDescriptor descriptor) {
+ return Linker.nativeLinker()
+ .downcallHandle(SymbolLookup.loaderLookup().findOrThrow(functionName), descriptor);
+ }
+
+ FaissNativeWrapper() {
+ // Check Faiss version
+ String expectedVersion = FaissLibrary.VERSION;
+ String actualVersion = faiss_get_version().reinterpret(Long.MAX_VALUE).getString(0);
+
+ if (expectedVersion.equals(actualVersion) == false) {
+ throw new LinkageError(
+ String.format(
+ Locale.ROOT,
+ "Expected Faiss library version %s, found %s",
+ expectedVersion,
+ actualVersion));
+ }
+ }
+
+ private final MethodHandle faiss_get_version$MH =
+ getHandle("faiss_get_version", FunctionDescriptor.of(ADDRESS));
+
+ MemorySegment faiss_get_version() {
+ try {
+ return (MemorySegment) faiss_get_version$MH.invokeExact();
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_CustomIOReader_free$MH =
+ getHandle("faiss_CustomIOReader_free", FunctionDescriptor.ofVoid(ADDRESS));
+
+ void faiss_CustomIOReader_free(MemorySegment customIOReaderPointer) {
+ try {
+ faiss_CustomIOReader_free$MH.invokeExact(customIOReaderPointer);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_CustomIOReader_new$MH =
+ getHandle("faiss_CustomIOReader_new", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS));
+
+ int faiss_CustomIOReader_new(MemorySegment pointer, MemorySegment readerStub) {
+ try {
+ return (int) faiss_CustomIOReader_new$MH.invokeExact(pointer, readerStub);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_CustomIOWriter_free$MH =
+ getHandle("faiss_CustomIOWriter_free", FunctionDescriptor.ofVoid(ADDRESS));
+
+ void faiss_CustomIOWriter_free(MemorySegment customIOWriterPointer) {
+ try {
+ faiss_CustomIOWriter_free$MH.invokeExact(customIOWriterPointer);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_CustomIOWriter_new$MH =
+ getHandle("faiss_CustomIOWriter_new", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS));
+
+ int faiss_CustomIOWriter_new(MemorySegment pointer, MemorySegment writerStub) {
+ try {
+ return (int) faiss_CustomIOWriter_new$MH.invokeExact(pointer, writerStub);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_IDSelectorBitmap_free$MH =
+ getHandle("faiss_IDSelectorBitmap_free", FunctionDescriptor.ofVoid(ADDRESS));
+
+ void faiss_IDSelectorBitmap_free(MemorySegment idSelectorBitmapPointer) {
+ try {
+ faiss_IDSelectorBitmap_free$MH.invokeExact(idSelectorBitmapPointer);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_IDSelectorBitmap_new$MH =
+ getHandle(
+ "faiss_IDSelectorBitmap_new",
+ FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS));
+
+ int faiss_IDSelectorBitmap_new(MemorySegment pointer, long length, MemorySegment bitmapPointer) {
+ try {
+ return (int) faiss_IDSelectorBitmap_new$MH.invokeExact(pointer, length, bitmapPointer);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_Index_add_with_ids$MH =
+ getHandle(
+ "faiss_Index_add_with_ids",
+ FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS));
+
+ int faiss_Index_add_with_ids(
+ MemorySegment indexPointer, long size, MemorySegment docsPointer, MemorySegment idsPointer) {
+ try {
+ return (int)
+ faiss_Index_add_with_ids$MH.invokeExact(indexPointer, size, docsPointer, idsPointer);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_Index_free$MH =
+ getHandle("faiss_Index_free", FunctionDescriptor.ofVoid(ADDRESS));
+
+ void faiss_Index_free(MemorySegment indexPointer) {
+ try {
+ faiss_Index_free$MH.invokeExact(indexPointer);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_Index_is_trained$MH =
+ getHandle("faiss_Index_is_trained", FunctionDescriptor.of(JAVA_INT, ADDRESS));
+
+ int faiss_Index_is_trained(MemorySegment indexPointer) {
+ try {
+ return (int) faiss_Index_is_trained$MH.invokeExact(indexPointer);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_Index_metric_type$MH =
+ getHandle("faiss_Index_metric_type", FunctionDescriptor.of(JAVA_INT, ADDRESS));
+
+ int faiss_Index_metric_type(MemorySegment indexPointer) {
+ try {
+ return (int) faiss_Index_metric_type$MH.invokeExact(indexPointer);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_Index_search$MH =
+ getHandle(
+ "faiss_Index_search",
+ FunctionDescriptor.of(
+ JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS));
+
+ int faiss_Index_search(
+ MemorySegment indexPointer,
+ long numQueries,
+ MemorySegment queriesPointer,
+ long k,
+ MemorySegment distancesPointer,
+ MemorySegment idsPointer) {
+ try {
+ return (int)
+ faiss_Index_search$MH.invokeExact(
+ indexPointer, numQueries, queriesPointer, k, distancesPointer, idsPointer);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_Index_search_with_params$MH =
+ getHandle(
+ "faiss_Index_search_with_params",
+ FunctionDescriptor.of(
+ JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS, ADDRESS));
+
+ int faiss_Index_search_with_params(
+ MemorySegment indexPointer,
+ long numQueries,
+ MemorySegment queriesPointer,
+ long k,
+ MemorySegment searchParametersPointer,
+ MemorySegment distancesPointer,
+ MemorySegment idsPointer) {
+ try {
+ return (int)
+ faiss_Index_search_with_params$MH.invokeExact(
+ indexPointer,
+ numQueries,
+ queriesPointer,
+ k,
+ searchParametersPointer,
+ distancesPointer,
+ idsPointer);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_Index_train$MH =
+ getHandle("faiss_Index_train", FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS));
+
+ int faiss_Index_train(MemorySegment indexPointer, long size, MemorySegment docsPointer) {
+ try {
+ return (int) faiss_Index_train$MH.invokeExact(indexPointer, size, docsPointer);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_ParameterSpace_free$MH =
+ getHandle("faiss_ParameterSpace_free", FunctionDescriptor.ofVoid(ADDRESS));
+
+ void faiss_ParameterSpace_free(MemorySegment parameterSpacePointer) {
+ try {
+ faiss_ParameterSpace_free$MH.invokeExact(parameterSpacePointer);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_ParameterSpace_new$MH =
+ getHandle("faiss_ParameterSpace_new", FunctionDescriptor.of(JAVA_INT, ADDRESS));
+
+ int faiss_ParameterSpace_new(MemorySegment pointer) {
+ try {
+ return (int) faiss_ParameterSpace_new$MH.invokeExact(pointer);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_ParameterSpace_set_index_parameters$MH =
+ getHandle(
+ "faiss_ParameterSpace_set_index_parameters",
+ FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, ADDRESS));
+
+ int faiss_ParameterSpace_set_index_parameters(
+ MemorySegment parameterSpacePointer,
+ MemorySegment indexPointer,
+ MemorySegment descriptionPointer) {
+ try {
+ return (int)
+ faiss_ParameterSpace_set_index_parameters$MH.invokeExact(
+ parameterSpacePointer, indexPointer, descriptionPointer);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_SearchParameters_free$MH =
+ getHandle("faiss_SearchParameters_free", FunctionDescriptor.ofVoid(ADDRESS));
+
+ void faiss_SearchParameters_free(MemorySegment searchParametersPointer) {
+ try {
+ faiss_SearchParameters_free$MH.invokeExact(searchParametersPointer);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_SearchParameters_new$MH =
+ getHandle("faiss_SearchParameters_new", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS));
+
+ int faiss_SearchParameters_new(MemorySegment pointer, MemorySegment idSelectorBitmapPointer) {
+ try {
+ return (int) faiss_SearchParameters_new$MH.invokeExact(pointer, idSelectorBitmapPointer);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_index_factory$MH =
+ getHandle(
+ "faiss_index_factory",
+ FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT));
+
+ int faiss_index_factory(
+ MemorySegment pointer, int dimension, MemorySegment description, int metric) {
+ try {
+ return (int) faiss_index_factory$MH.invokeExact(pointer, dimension, description, metric);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_read_index_custom$MH =
+ getHandle(
+ "faiss_read_index_custom", FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS));
+
+ int faiss_read_index_custom(
+ MemorySegment customIOReaderPointer, int ioFlags, MemorySegment pointer) {
+ try {
+ return (int) faiss_read_index_custom$MH.invokeExact(customIOReaderPointer, ioFlags, pointer);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private final MethodHandle faiss_write_index_custom$MH =
+ getHandle(
+ "faiss_write_index_custom", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
+
+ int faiss_write_index_custom(
+ MemorySegment indexPointer, MemorySegment customIOWriterPointer, int ioFlags) {
+ try {
+ return (int)
+ faiss_write_index_custom$MH.invokeExact(indexPointer, customIOWriterPointer, ioFlags);
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ /**
+ * Exception used to rethrow handled Faiss errors in native code.
+ *
+ * @lucene.experimental
+ */
+ static class Exception extends RuntimeException {
+ // See error codes defined in c_api/error_c.h
+ enum ErrorCode {
+ /// No error
+ OK(0),
+ /// Any exception other than Faiss or standard C++ library exceptions
+ UNKNOWN_EXCEPT(-1),
+ /// Faiss library exception
+ FAISS_EXCEPT(-2),
+ /// Standard C++ library exception
+ STD_EXCEPT(-4);
+
+ private final int code;
+
+ ErrorCode(int code) {
+ this.code = code;
+ }
+
+ static ErrorCode fromCode(int code) {
+ return Arrays.stream(ErrorCode.values())
+ .filter(errorCode -> errorCode.code == code)
+ .findFirst()
+ .orElseThrow();
+ }
+ }
+
+ private Exception(int code) {
+ super(
+ String.format(
+ Locale.ROOT,
+ "%s[%s(%d)]",
+ Exception.class.getName(),
+ ErrorCode.fromCode(code),
+ code));
+ }
+
+ static void handleException(int code) {
+ if (code < 0) {
+ throw new Exception(code);
+ }
+ }
+ }
+}
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/LibFaissC.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/LibFaissC.java
deleted file mode 100644
index d9a7baa8ef3c..000000000000
--- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/LibFaissC.java
+++ /dev/null
@@ -1,534 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.lucene.sandbox.codecs.faiss;
-
-import static java.lang.foreign.ValueLayout.ADDRESS;
-import static java.lang.foreign.ValueLayout.JAVA_BYTE;
-import static java.lang.foreign.ValueLayout.JAVA_FLOAT;
-import static java.lang.foreign.ValueLayout.JAVA_INT;
-import static java.lang.foreign.ValueLayout.JAVA_LONG;
-import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
-
-import java.io.IOException;
-import java.lang.foreign.Arena;
-import java.lang.foreign.FunctionDescriptor;
-import java.lang.foreign.Linker;
-import java.lang.foreign.MemorySegment;
-import java.lang.foreign.SymbolLookup;
-import java.lang.invoke.MethodHandle;
-import java.lang.invoke.MethodHandles;
-import java.lang.invoke.MethodType;
-import java.nio.ByteOrder;
-import java.util.Arrays;
-import java.util.Locale;
-import org.apache.lucene.index.FloatVectorValues;
-import org.apache.lucene.index.KnnVectorValues;
-import org.apache.lucene.index.VectorSimilarityFunction;
-import org.apache.lucene.search.KnnCollector;
-import org.apache.lucene.store.IndexInput;
-import org.apache.lucene.store.IndexOutput;
-import org.apache.lucene.util.Bits;
-import org.apache.lucene.util.FixedBitSet;
-import org.apache.lucene.util.hnsw.IntToIntFunction;
-
-/**
- * Utility class to wrap necessary functions of the native C API of Faiss
- * using Project Panama.
- *
- * @lucene.experimental
- */
-@SuppressWarnings("restricted") // uses unsafe calls
-final class LibFaissC {
- // TODO: Use vectorized version where available
- public static final String LIBRARY_NAME = "faiss_c";
- public static final String LIBRARY_VERSION = "1.11.0";
-
- // See flags defined in c_api/index_io_c.h
- static final int FAISS_IO_FLAG_MMAP = 1;
- static final int FAISS_IO_FLAG_READ_ONLY = 2;
-
- private static final int BUFFER_SIZE = 256 * 1024 * 1024; // 256 MB
-
- static {
- System.loadLibrary(LIBRARY_NAME);
- checkLibraryVersion();
- }
-
- private LibFaissC() {}
-
- private static MemorySegment getUpcallStub(
- Arena arena, MethodHandle target, FunctionDescriptor descriptor) {
- return Linker.nativeLinker().upcallStub(target, descriptor, arena);
- }
-
- private static MethodHandle getDowncallHandle(
- String functionName, FunctionDescriptor descriptor) {
- return Linker.nativeLinker()
- .downcallHandle(SymbolLookup.loaderLookup().findOrThrow(functionName), descriptor);
- }
-
- private static void checkLibraryVersion() {
- MethodHandle getVersion =
- getDowncallHandle("faiss_get_version", FunctionDescriptor.of(ADDRESS));
-
- MemorySegment nativeString = call(getVersion);
- String actualVersion = nativeString.reinterpret(Long.MAX_VALUE).getString(0);
-
- if (LIBRARY_VERSION.equals(actualVersion) == false) {
- throw new UnsupportedOperationException(
- String.format(
- Locale.ROOT,
- "Expected Faiss library version %s, found %s",
- LIBRARY_VERSION,
- actualVersion));
- }
- }
-
- private static final MethodHandle FREE_INDEX =
- getDowncallHandle("faiss_Index_free", FunctionDescriptor.ofVoid(ADDRESS));
-
- public static void freeIndex(MemorySegment indexPointer) {
- call(FREE_INDEX, indexPointer);
- }
-
- private static final MethodHandle FREE_CUSTOM_IO_WRITER =
- getDowncallHandle("faiss_CustomIOWriter_free", FunctionDescriptor.ofVoid(ADDRESS));
-
- public static void freeCustomIOWriter(MemorySegment customIOWriterPointer) {
- call(FREE_CUSTOM_IO_WRITER, customIOWriterPointer);
- }
-
- private static final MethodHandle FREE_CUSTOM_IO_READER =
- getDowncallHandle("faiss_CustomIOReader_free", FunctionDescriptor.ofVoid(ADDRESS));
-
- public static void freeCustomIOReader(MemorySegment customIOReaderPointer) {
- call(FREE_CUSTOM_IO_READER, customIOReaderPointer);
- }
-
- private static final MethodHandle FREE_PARAMETER_SPACE =
- getDowncallHandle("faiss_ParameterSpace_free", FunctionDescriptor.ofVoid(ADDRESS));
-
- private static void freeParameterSpace(MemorySegment parameterSpacePointer) {
- call(FREE_PARAMETER_SPACE, parameterSpacePointer);
- }
-
- private static final MethodHandle FREE_ID_SELECTOR_BITMAP =
- getDowncallHandle("faiss_IDSelectorBitmap_free", FunctionDescriptor.ofVoid(ADDRESS));
-
- private static void freeIDSelectorBitmap(MemorySegment idSelectorBitmapPointer) {
- call(FREE_ID_SELECTOR_BITMAP, idSelectorBitmapPointer);
- }
-
- private static final MethodHandle FREE_SEARCH_PARAMETERS =
- getDowncallHandle("faiss_SearchParameters_free", FunctionDescriptor.ofVoid(ADDRESS));
-
- private static void freeSearchParameters(MemorySegment searchParametersPointer) {
- call(FREE_SEARCH_PARAMETERS, searchParametersPointer);
- }
-
- private static final MethodHandle INDEX_FACTORY =
- getDowncallHandle(
- "faiss_index_factory",
- FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT));
-
- private static final MethodHandle PARAMETER_SPACE_NEW =
- getDowncallHandle("faiss_ParameterSpace_new", FunctionDescriptor.of(JAVA_INT, ADDRESS));
-
- private static final MethodHandle SET_INDEX_PARAMETERS =
- getDowncallHandle(
- "faiss_ParameterSpace_set_index_parameters",
- FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, ADDRESS));
-
- private static final MethodHandle ID_SELECTOR_BITMAP_NEW =
- getDowncallHandle(
- "faiss_IDSelectorBitmap_new",
- FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS));
-
- private static final MethodHandle SEARCH_PARAMETERS_NEW =
- getDowncallHandle(
- "faiss_SearchParameters_new", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS));
-
- private static final MethodHandle INDEX_IS_TRAINED =
- getDowncallHandle("faiss_Index_is_trained", FunctionDescriptor.of(JAVA_INT, ADDRESS));
-
- private static final MethodHandle INDEX_TRAIN =
- getDowncallHandle(
- "faiss_Index_train", FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS));
-
- private static final MethodHandle INDEX_ADD_WITH_IDS =
- getDowncallHandle(
- "faiss_Index_add_with_ids",
- FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS));
-
- public static MemorySegment createIndex(
- String description,
- String indexParams,
- VectorSimilarityFunction function,
- FloatVectorValues floatVectorValues,
- IntToIntFunction oldToNewDocId)
- throws IOException {
-
- try (Arena temp = Arena.ofConfined()) {
- int size = floatVectorValues.size();
- int dimension = floatVectorValues.dimension();
-
- // Mapped from faiss/MetricType.h
- int metric =
- switch (function) {
- case DOT_PRODUCT -> 0;
- case EUCLIDEAN -> 1;
- case COSINE, MAXIMUM_INNER_PRODUCT ->
- throw new UnsupportedOperationException("Metric type not supported");
- };
-
- // Create an index
- MemorySegment pointer = temp.allocate(ADDRESS);
- callAndHandleError(INDEX_FACTORY, pointer, dimension, temp.allocateFrom(description), metric);
- MemorySegment indexPointer = pointer.get(ADDRESS, 0);
-
- // Set index params
- callAndHandleError(PARAMETER_SPACE_NEW, pointer);
- MemorySegment parameterSpacePointer =
- pointer
- .get(ADDRESS, 0)
- // Ensure timely cleanup
- .reinterpret(temp, LibFaissC::freeParameterSpace);
-
- callAndHandleError(
- SET_INDEX_PARAMETERS,
- parameterSpacePointer,
- indexPointer,
- temp.allocateFrom(indexParams));
-
- // TODO: Improve memory usage (with a tradeoff in performance) by batched indexing, see:
- // - https://github.com/opensearch-project/k-NN/issues/1506
- // - https://github.com/opensearch-project/k-NN/issues/1938
-
- // Allocate docs in native memory
- MemorySegment docs = temp.allocate(JAVA_FLOAT, (long) size * dimension);
- long docsOffset = 0;
- long perDocByteSize = dimension * JAVA_FLOAT.byteSize();
-
- // Allocate ids in native memory
- MemorySegment ids = temp.allocate(JAVA_LONG, size);
- int idsIndex = 0;
-
- KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
- for (int i = iterator.nextDoc(); i != NO_MORE_DOCS; i = iterator.nextDoc()) {
- int id = oldToNewDocId.apply(i);
- ids.setAtIndex(JAVA_LONG, idsIndex, id);
- idsIndex++;
-
- float[] vector = floatVectorValues.vectorValue(iterator.index());
- MemorySegment.copy(vector, 0, docs, JAVA_FLOAT, docsOffset, vector.length);
- docsOffset += perDocByteSize;
- }
-
- // Train index
- int isTrained = call(INDEX_IS_TRAINED, indexPointer);
- if (isTrained == 0) {
- callAndHandleError(INDEX_TRAIN, indexPointer, size, docs);
- }
-
- // Add docs to index
- callAndHandleError(INDEX_ADD_WITH_IDS, indexPointer, size, docs, ids);
-
- return indexPointer;
- }
- }
-
- @SuppressWarnings("unused") // called using a MethodHandle
- private static long writeBytes(
- IndexOutput output, MemorySegment inputPointer, long itemSize, long numItems)
- throws IOException {
- long size = itemSize * numItems;
- inputPointer = inputPointer.reinterpret(size);
-
- if (size <= BUFFER_SIZE) { // simple case, avoid buffering
- output.writeBytes(inputPointer.toArray(JAVA_BYTE), (int) size);
- } else { // copy buffered number of bytes repeatedly
- byte[] bytes = new byte[BUFFER_SIZE];
- for (long offset = 0; offset < size; offset += BUFFER_SIZE) {
- int length = (int) Math.min(size - offset, BUFFER_SIZE);
- MemorySegment.copy(inputPointer, JAVA_BYTE, offset, bytes, 0, length);
- output.writeBytes(bytes, length);
- }
- }
- return numItems;
- }
-
- @SuppressWarnings("unused") // called using a MethodHandle
- private static long readBytes(
- IndexInput input, MemorySegment outputPointer, long itemSize, long numItems)
- throws IOException {
- long size = itemSize * numItems;
- outputPointer = outputPointer.reinterpret(size);
-
- if (size <= BUFFER_SIZE) { // simple case, avoid buffering
- byte[] bytes = new byte[(int) size];
- input.readBytes(bytes, 0, bytes.length);
- MemorySegment.copy(bytes, 0, outputPointer, JAVA_BYTE, 0, bytes.length);
- } else { // copy buffered number of bytes repeatedly
- byte[] bytes = new byte[BUFFER_SIZE];
- for (long offset = 0; offset < size; offset += BUFFER_SIZE) {
- int length = (int) Math.min(size - offset, BUFFER_SIZE);
- input.readBytes(bytes, 0, length);
- MemorySegment.copy(bytes, 0, outputPointer, JAVA_BYTE, offset, length);
- }
- }
- return numItems;
- }
-
- private static final MethodHandle WRITE_BYTES_HANDLE;
- private static final MethodHandle READ_BYTES_HANDLE;
-
- static {
- try {
- WRITE_BYTES_HANDLE =
- MethodHandles.lookup()
- .findStatic(
- LibFaissC.class,
- "writeBytes",
- MethodType.methodType(
- long.class, IndexOutput.class, MemorySegment.class, long.class, long.class));
-
- READ_BYTES_HANDLE =
- MethodHandles.lookup()
- .findStatic(
- LibFaissC.class,
- "readBytes",
- MethodType.methodType(
- long.class, IndexInput.class, MemorySegment.class, long.class, long.class));
- } catch (NoSuchMethodException | IllegalAccessException e) {
- throw new RuntimeException(e);
- }
- }
-
- private static final MethodHandle CUSTOM_IO_WRITER_NEW =
- getDowncallHandle(
- "faiss_CustomIOWriter_new", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS));
-
- private static final MethodHandle WRITE_INDEX_CUSTOM =
- getDowncallHandle(
- "faiss_write_index_custom", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
-
- public static void indexWrite(MemorySegment indexPointer, IndexOutput output, int ioFlags) {
- try (Arena temp = Arena.ofConfined()) {
- MethodHandle writerHandle = WRITE_BYTES_HANDLE.bindTo(output);
- MemorySegment writerStub =
- getUpcallStub(
- temp, writerHandle, FunctionDescriptor.of(JAVA_LONG, ADDRESS, JAVA_LONG, JAVA_LONG));
-
- MemorySegment pointer = temp.allocate(ADDRESS);
- callAndHandleError(CUSTOM_IO_WRITER_NEW, pointer, writerStub);
- MemorySegment customIOWriterPointer =
- pointer
- .get(ADDRESS, 0)
- // Ensure timely cleanup
- .reinterpret(temp, LibFaissC::freeCustomIOWriter);
-
- callAndHandleError(WRITE_INDEX_CUSTOM, indexPointer, customIOWriterPointer, ioFlags);
- }
- }
-
- private static final MethodHandle CUSTOM_IO_READER_NEW =
- getDowncallHandle(
- "faiss_CustomIOReader_new", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS));
-
- private static final MethodHandle READ_INDEX_CUSTOM =
- getDowncallHandle(
- "faiss_read_index_custom", FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS));
-
- public static MemorySegment indexRead(IndexInput input, int ioFlags) {
- try (Arena temp = Arena.ofConfined()) {
- MethodHandle readerHandle = READ_BYTES_HANDLE.bindTo(input);
- MemorySegment readerStub =
- getUpcallStub(
- temp, readerHandle, FunctionDescriptor.of(JAVA_LONG, ADDRESS, JAVA_LONG, JAVA_LONG));
-
- MemorySegment pointer = temp.allocate(ADDRESS);
- callAndHandleError(CUSTOM_IO_READER_NEW, pointer, readerStub);
- MemorySegment customIOReaderPointer =
- pointer
- .get(ADDRESS, 0)
- // Ensure timely cleanup
- .reinterpret(temp, LibFaissC::freeCustomIOReader);
-
- callAndHandleError(READ_INDEX_CUSTOM, customIOReaderPointer, ioFlags, pointer);
- return pointer.get(ADDRESS, 0);
- }
- }
-
- private static final MethodHandle INDEX_SEARCH =
- getDowncallHandle(
- "faiss_Index_search",
- FunctionDescriptor.of(
- JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS));
-
- private static final MethodHandle INDEX_SEARCH_WITH_PARAMS =
- getDowncallHandle(
- "faiss_Index_search_with_params",
- FunctionDescriptor.of(
- JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS, ADDRESS));
-
- public static void indexSearch(
- MemorySegment indexPointer,
- VectorSimilarityFunction function,
- float[] query,
- KnnCollector knnCollector,
- Bits acceptDocs) {
-
- try (Arena temp = Arena.ofConfined()) {
- FixedBitSet fixedBitSet =
- switch (acceptDocs) {
- case null -> null;
- case FixedBitSet bitSet -> bitSet;
- // TODO: Add optimized case for SparseFixedBitSet
- case Bits bits -> FixedBitSet.copyOf(bits);
- };
-
- // Allocate queries in native memory
- MemorySegment queries = temp.allocateFrom(JAVA_FLOAT, query);
-
- // Faiss knn search
- int k = knnCollector.k();
- MemorySegment distancesPointer = temp.allocate(JAVA_FLOAT, k);
- MemorySegment idsPointer = temp.allocate(JAVA_LONG, k);
-
- MemorySegment localIndex = indexPointer.reinterpret(temp, null);
- if (fixedBitSet == null) {
- // Search without runtime filters
- callAndHandleError(INDEX_SEARCH, localIndex, 1, queries, k, distancesPointer, idsPointer);
- } else {
- MemorySegment pointer = temp.allocate(ADDRESS);
-
- long[] bits = fixedBitSet.getBits();
- MemorySegment nativeBits =
- // Use LITTLE_ENDIAN to convert long[] -> uint8_t*
- temp.allocateFrom(JAVA_LONG.withOrder(ByteOrder.LITTLE_ENDIAN), bits);
-
- callAndHandleError(ID_SELECTOR_BITMAP_NEW, pointer, fixedBitSet.length(), nativeBits);
- MemorySegment idSelectorBitmapPointer =
- pointer
- .get(ADDRESS, 0)
- // Ensure timely cleanup
- .reinterpret(temp, LibFaissC::freeIDSelectorBitmap);
-
- callAndHandleError(SEARCH_PARAMETERS_NEW, pointer, idSelectorBitmapPointer);
- MemorySegment searchParametersPointer =
- pointer
- .get(ADDRESS, 0)
- // Ensure timely cleanup
- .reinterpret(temp, LibFaissC::freeSearchParameters);
-
- // Search with runtime filters
- callAndHandleError(
- INDEX_SEARCH_WITH_PARAMS,
- localIndex,
- 1,
- queries,
- k,
- searchParametersPointer,
- distancesPointer,
- idsPointer);
- }
-
- // Retrieve scores and ids
- float[] distances = distancesPointer.toArray(JAVA_FLOAT);
- long[] ids = idsPointer.toArray(JAVA_LONG);
-
- // Record hits
- for (int i = 0; i < k; i++) {
- // Not enough results
- if (ids[i] == -1) {
- break;
- }
-
- // Scale Faiss distances to Lucene scores, see VectorSimilarityFunction.java
- float score =
- switch (function) {
- case DOT_PRODUCT ->
- // distance in Faiss === dotProduct in Lucene
- Math.max((1 + distances[i]) / 2, 0);
-
- case EUCLIDEAN ->
- // distance in Faiss === squareDistance in Lucene
- 1 / (1 + distances[i]);
-
- case COSINE, MAXIMUM_INNER_PRODUCT ->
- throw new UnsupportedOperationException("Metric type not supported");
- };
-
- knnCollector.collect((int) ids[i], score);
- }
- }
- }
-
- @SuppressWarnings("unchecked")
- private static T call(MethodHandle handle, Object... args) {
- try {
- return (T) handle.invokeWithArguments(args);
- } catch (Throwable e) {
- throw new RuntimeException(e);
- }
- }
-
- private static void callAndHandleError(MethodHandle handle, Object... args) {
- int returnCode = call(handle, args);
- if (returnCode < 0) {
- // TODO: Surface actual exception in a thread-safe manner?
- throw new FaissException(returnCode);
- }
- }
-
- /**
- * Exception used to rethrow handled Faiss errors in native code.
- *
- * @lucene.experimental
- */
- public static class FaissException extends RuntimeException {
- // See error codes defined in c_api/error_c.h
- enum ErrorCode {
- /// No error
- OK(0),
- /// Any exception other than Faiss or standard C++ library exceptions
- UNKNOWN_EXCEPT(-1),
- /// Faiss library exception
- FAISS_EXCEPT(-2),
- /// Standard C++ library exception
- STD_EXCEPT(-4);
-
- private final int code;
-
- ErrorCode(int code) {
- this.code = code;
- }
-
- static ErrorCode fromCode(int code) {
- return Arrays.stream(ErrorCode.values())
- .filter(errorCode -> errorCode.code == code)
- .findFirst()
- .orElseThrow();
- }
- }
-
- public FaissException(int code) {
- super(String.format(Locale.ROOT, "Faiss library ran into %s", ErrorCode.fromCode(code)));
- }
- }
-}
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/package-info.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/package-info.java
index e63fa3070f96..bd4a5b88d2aa 100644
--- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/package-info.java
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/package-info.java
@@ -22,8 +22,7 @@
* org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat}.
*
* To use this format: Install pytorch/faiss-cpu v{@value
- * org.apache.lucene.sandbox.codecs.faiss.LibFaissC#LIBRARY_VERSION} from pytorch/faiss-cpu from Conda and place shared libraries (including
* dependencies) on the {@code $LD_LIBRARY_PATH} environment variable or {@code -Djava.library.path}
* JVM argument.
@@ -39,9 +38,11 @@
*
Install micromamba (an open-source Conda
* package manager) or similar
* Install dependencies using {@code micromamba create -n faiss-env -c pytorch -c conda-forge
- * -y faiss-cpu=}{@value org.apache.lucene.sandbox.codecs.faiss.LibFaissC#LIBRARY_VERSION}
+ * -y faiss-cpu=}{@value org.apache.lucene.sandbox.codecs.faiss.FaissLibrary#VERSION}
* Activate environment using {@code micromamba activate faiss-env}
* Add shared libraries to runtime using {@code export LD_LIBRARY_PATH=$CONDA_PREFIX/lib}
+ * (verify that the {@value org.apache.lucene.sandbox.codecs.faiss.FaissLibrary#NAME} library
+ * is present here)
* And you're good to go! (add the {@code -Dtests.faiss.run=true} JVM argument to ensure Faiss
* tests are run)
*
diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java
index 4a3fb3661e6d..f66eae680db1 100644
--- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java
+++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java
@@ -48,21 +48,21 @@ public class TestFaissKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
private static final VectorSimilarityFunction[] SUPPORTED_FUNCTIONS = {DOT_PRODUCT, EUCLIDEAN};
@BeforeClass
- public static void maybeSuppress() throws ClassNotFoundException {
+ public static void maybeSuppress() {
// Explicitly run tests
if (Boolean.getBoolean(FAISS_RUN_TESTS)) {
return;
}
// Otherwise check if dependencies are present
- boolean faissLibraryPresent;
+ boolean dependenciesPresent;
try {
- Class.forName("org.apache.lucene.sandbox.codecs.faiss.LibFaissC");
- faissLibraryPresent = true;
- } catch (UnsatisfiedLinkError _) {
- faissLibraryPresent = false;
+ FaissLibrary _ = FaissLibrary.INSTANCE;
+ dependenciesPresent = true;
+ } catch (LinkageError _) {
+ dependenciesPresent = false;
}
- assumeTrue("Native libraries present", faissLibraryPresent);
+ assumeTrue("Dependencies present", dependenciesPresent);
}
@Override