diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java index 83beae607dc5..8cef8d29ce74 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java @@ -31,7 +31,7 @@ import org.apache.lucene.index.SegmentWriteState; /** - * A Faiss-based format to create and search vector indexes, using {@link LibFaissC} to interact + * A Faiss-based format to create and search vector indexes, using {@link FaissLibrary} to interact * with the native library. * *

The Faiss index is configured using its flexible indexMap; - private final Arena arena; + private final Map indexMap; private boolean closed; public FaissKnnVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException { this.rawVectorsReader = rawVectorsReader; this.indexMap = new HashMap<>(); - this.arena = Arena.ofShared(); - this.closed = false; List fieldMetaList = new ArrayList<>(); String metaFileName = @@ -125,9 +115,11 @@ public FaissKnnVectorsReader(SegmentReadState state, FlatVectorsReader rawVector CodecUtil.retrieveChecksum(data); for (FieldMeta fieldMeta : fieldMetaList) { - if (indexMap.put(fieldMeta.fieldInfo.name, loadField(data, arena, fieldMeta)) != null) { - throw new CorruptIndexException("Duplicate field: " + fieldMeta.fieldInfo.name, meta); + if (indexMap.containsKey(fieldMeta.name)) { + throw new CorruptIndexException("Duplicate field: " + fieldMeta.name, meta); } + IndexInput indexInput = data.slice(fieldMeta.name, fieldMeta.offset, fieldMeta.length); + indexMap.put(fieldMeta.name, FaissLibrary.INSTANCE.readIndex(indexInput)); } } catch (Throwable t) { IOUtils.closeWhileSuppressingExceptions(t, this); @@ -150,21 +142,7 @@ private static FieldMeta parseNextField(IndexInput meta, SegmentReadState state) long dataOffset = meta.readLong(); long dataLength = meta.readLong(); - return new FieldMeta(fieldInfo, dataOffset, dataLength); - } - - @SuppressWarnings("restricted") // TODO: encapsulate the unsafeness into the LibFaissC - private static IndexEntry loadField(IndexInput data, Arena arena, FieldMeta fieldMeta) - throws IOException { - int ioFlags = FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY; - - // Read index into memory - MemorySegment indexPointer = - indexRead(data.slice(fieldMeta.fieldInfo.name, fieldMeta.offset, fieldMeta.length), ioFlags) - // Ensure timely cleanup - .reinterpret(arena, LibFaissC::freeIndex); - - return new IndexEntry(indexPointer, fieldMeta.fieldInfo.getVectorSimilarityFunction()); + return new FieldMeta(fieldInfo.name, dataOffset, dataLength); } @Override @@ -188,9 +166,9 @@ public ByteVectorValues getByteVectorValues(String field) { @Override public void search(String field, float[] vector, KnnCollector knnCollector, Bits acceptDocs) { - IndexEntry entry = indexMap.get(field); - if (entry != null) { - indexSearch(entry.indexPointer, entry.function, vector, knnCollector, acceptDocs); + FaissLibrary.Index index = indexMap.get(field); + if (index != null) { + index.search(vector, knnCollector, acceptDocs); } } @@ -210,12 +188,16 @@ public Map getOffHeapByteSize(FieldInfo fieldInfo) { @Override public void close() throws IOException { if (closed == false) { + // Close all indexes + for (FaissLibrary.Index index : indexMap.values()) { + index.close(); + } + indexMap.clear(); + + IOUtils.close(rawVectorsReader, data); closed = true; - IOUtils.close(rawVectorsReader, arena::close, data, indexMap::clear); } } - private record FieldMeta(FieldInfo fieldInfo, long offset, long length) {} - - private record IndexEntry(MemorySegment indexPointer, VectorSimilarityFunction function) {} + private record FieldMeta(String name, long offset, long length) {} } diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java index a81864a9c25e..f41986d8e8ce 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java @@ -21,14 +21,8 @@ import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_CODEC_NAME; import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_EXTENSION; import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_CURRENT; -import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.FAISS_IO_FLAG_MMAP; -import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.FAISS_IO_FLAG_READ_ONLY; -import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.createIndex; -import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexWrite; import java.io.IOException; -import java.lang.foreign.Arena; -import java.lang.foreign.MemorySegment; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -43,7 +37,6 @@ import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; -import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSet; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.IOUtils; @@ -154,26 +147,23 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { } } - @SuppressWarnings("restricted") // TODO: encapsulate the unsafeness into the LibFaissC private void writeFloatField( FieldInfo fieldInfo, FloatVectorValues floatVectorValues, IntToIntFunction oldToNewDocId) throws IOException { int number = fieldInfo.number; meta.writeInt(number); - // Write index to temp file and deallocate from memory - try (Arena temp = Arena.ofConfined()) { - VectorSimilarityFunction function = fieldInfo.getVectorSimilarityFunction(); - MemorySegment indexPointer = - createIndex(description, indexParams, function, floatVectorValues, oldToNewDocId) - // Ensure timely cleanup - .reinterpret(temp, LibFaissC::freeIndex); - - int ioFlags = FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY; + try (FaissLibrary.Index index = + FaissLibrary.INSTANCE.createIndex( + description, + indexParams, + fieldInfo.getVectorSimilarityFunction(), + floatVectorValues, + oldToNewDocId)) { // Write index long dataOffset = data.getFilePointer(); - indexWrite(indexPointer, data, ioFlags); + index.write(data); long dataLength = data.getFilePointer() - dataOffset; meta.writeLong(dataOffset); @@ -233,7 +223,7 @@ public int size() { @Override public FloatVectorValues copy() { - return new BufferedFloatVectorValues(floats, dimension, docIdSet); + throw new AssertionError("Should not be called"); } @Override diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java new file mode 100644 index 000000000000..e7837692222d --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import java.io.Closeable; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.IntToIntFunction; + +/** + * Minimal interface to create and query Faiss indexes. + * + * @lucene.experimental + */ +interface FaissLibrary { + FaissLibrary INSTANCE = new FaissLibraryNativeImpl(); + + // TODO: Use SIMD version at runtime. The "faiss_c" library is linked to the main "faiss" library, + // which does not use SIMD instructions. However, there are SIMD versions of "faiss" (like + // "faiss_avx2", "faiss_avx512", "faiss_sve", etc.) available, which can be used by changing the + // dependencies of "faiss_c" using the "patchelf" utility. Figure out how to do this dynamically, + // or via modifications to upstream Faiss. + String NAME = "faiss_c"; + String VERSION = "1.11.0"; + + interface Index extends Closeable { + void search(float[] query, KnnCollector knnCollector, Bits acceptDocs); + + void write(IndexOutput output); + } + + Index createIndex( + String description, + String indexParams, + VectorSimilarityFunction function, + FloatVectorValues floatVectorValues, + IntToIntFunction oldToNewDocId); + + Index readIndex(IndexInput input); +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java new file mode 100644 index 000000000000..d72e65eca860 --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static java.lang.foreign.ValueLayout.ADDRESS; +import static java.lang.foreign.ValueLayout.JAVA_BYTE; +import static java.lang.foreign.ValueLayout.JAVA_FLOAT; +import static java.lang.foreign.ValueLayout.JAVA_LONG; +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.sandbox.codecs.faiss.FaissNativeWrapper.Exception.handleException; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.lang.foreign.Arena; +import java.lang.foreign.FunctionDescriptor; +import java.lang.foreign.Linker; +import java.lang.foreign.MemorySegment; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.nio.ByteOrder; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.hnsw.IntToIntFunction; + +/** + * A native implementation of {@link FaissLibrary} using {@link FaissNativeWrapper}. + * + * @lucene.experimental + */ +@SuppressWarnings("restricted") // uses unsafe calls +final class FaissLibraryNativeImpl implements FaissLibrary { + private final FaissNativeWrapper wrapper; + + FaissLibraryNativeImpl() { + this.wrapper = new FaissNativeWrapper(); + } + + private static MemorySegment getStub( + Arena arena, MethodHandle target, FunctionDescriptor descriptor) { + return Linker.nativeLinker().upcallStub(target, descriptor, arena); + } + + private static final int BUFFER_SIZE = 256 * 1024 * 1024; // 256 MB + + @SuppressWarnings("unused") // called using a MethodHandle + private static long writeBytes( + IndexOutput output, MemorySegment inputPointer, long itemSize, long numItems) + throws IOException { + long size = itemSize * numItems; + inputPointer = inputPointer.reinterpret(size); + + if (size <= BUFFER_SIZE) { // simple case, avoid buffering + output.writeBytes(inputPointer.toArray(JAVA_BYTE), (int) size); + } else { // copy buffered number of bytes repeatedly + byte[] bytes = new byte[BUFFER_SIZE]; + for (long offset = 0; offset < size; offset += BUFFER_SIZE) { + int length = (int) Math.min(size - offset, BUFFER_SIZE); + MemorySegment.copy(inputPointer, JAVA_BYTE, offset, bytes, 0, length); + output.writeBytes(bytes, length); + } + } + return numItems; + } + + @SuppressWarnings("unused") // called using a MethodHandle + private static long readBytes( + IndexInput input, MemorySegment outputPointer, long itemSize, long numItems) + throws IOException { + long size = itemSize * numItems; + outputPointer = outputPointer.reinterpret(size); + + if (size <= BUFFER_SIZE) { // simple case, avoid buffering + byte[] bytes = new byte[(int) size]; + input.readBytes(bytes, 0, bytes.length); + MemorySegment.copy(bytes, 0, outputPointer, JAVA_BYTE, 0, bytes.length); + } else { // copy buffered number of bytes repeatedly + byte[] bytes = new byte[BUFFER_SIZE]; + for (long offset = 0; offset < size; offset += BUFFER_SIZE) { + int length = (int) Math.min(size - offset, BUFFER_SIZE); + input.readBytes(bytes, 0, length); + MemorySegment.copy(bytes, 0, outputPointer, JAVA_BYTE, offset, length); + } + } + return numItems; + } + + private static final MethodHandle WRITE_BYTES_HANDLE; + private static final MethodHandle READ_BYTES_HANDLE; + + static { + try { + MethodHandles.Lookup lookup = MethodHandles.lookup(); + + WRITE_BYTES_HANDLE = + lookup.findStatic( + FaissLibraryNativeImpl.class, + "writeBytes", + MethodType.methodType( + long.class, IndexOutput.class, MemorySegment.class, long.class, long.class)); + + READ_BYTES_HANDLE = + lookup.findStatic( + FaissLibraryNativeImpl.class, + "readBytes", + MethodType.methodType( + long.class, IndexInput.class, MemorySegment.class, long.class, long.class)); + + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new LinkageError( + "FaissLibraryNativeImpl reader / writer functions are missing or inaccessible", e); + } + } + + private static final Map FUNCTION_TO_METRIC = + Map.of( + // Mapped from faiss/MetricType.h + DOT_PRODUCT, 0, + EUCLIDEAN, 1); + + private static int functionToMetric(VectorSimilarityFunction function) { + Integer metric = FUNCTION_TO_METRIC.get(function); + if (metric == null) { + throw new UnsupportedOperationException("Similarity function not supported: " + function); + } + return metric; + } + + // Invert FUNCTION_TO_METRIC + private static final Map METRIC_TO_FUNCTION = + FUNCTION_TO_METRIC.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); + + private static VectorSimilarityFunction metricToFunction(int metric) { + VectorSimilarityFunction function = METRIC_TO_FUNCTION.get(metric); + if (function == null) { + throw new UnsupportedOperationException("Metric not supported: " + metric); + } + return function; + } + + @Override + public FaissLibrary.Index createIndex( + String description, + String indexParams, + VectorSimilarityFunction function, + FloatVectorValues floatVectorValues, + IntToIntFunction oldToNewDocId) { + + try (Arena temp = Arena.ofConfined()) { + int size = floatVectorValues.size(); + int dimension = floatVectorValues.dimension(); + int metric = functionToMetric(function); + + // Create an index + MemorySegment pointer = temp.allocate(ADDRESS); + handleException( + wrapper.faiss_index_factory(pointer, dimension, temp.allocateFrom(description), metric)); + + MemorySegment indexPointer = pointer.get(ADDRESS, 0); + + // Set index params + handleException(wrapper.faiss_ParameterSpace_new(pointer)); + MemorySegment parameterSpacePointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, wrapper::faiss_ParameterSpace_free); + + handleException( + wrapper.faiss_ParameterSpace_set_index_parameters( + parameterSpacePointer, indexPointer, temp.allocateFrom(indexParams))); + + // TODO: Improve memory usage (with a tradeoff in performance) by batched indexing, see: + // - https://github.com/opensearch-project/k-NN/issues/1506 + // - https://github.com/opensearch-project/k-NN/issues/1938 + + // Allocate docs in native memory + MemorySegment docs = temp.allocate(JAVA_FLOAT, (long) size * dimension); + long docsOffset = 0; + long perDocByteSize = dimension * JAVA_FLOAT.byteSize(); + + // Allocate ids in native memory + MemorySegment ids = temp.allocate(JAVA_LONG, size); + int idsIndex = 0; + + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); + for (int i = iterator.nextDoc(); i != NO_MORE_DOCS; i = iterator.nextDoc()) { + int id = oldToNewDocId.apply(i); + ids.setAtIndex(JAVA_LONG, idsIndex, id); + idsIndex++; + + float[] vector = floatVectorValues.vectorValue(iterator.index()); + MemorySegment.copy(vector, 0, docs, JAVA_FLOAT, docsOffset, vector.length); + docsOffset += perDocByteSize; + } + + // Train index + int isTrained = wrapper.faiss_Index_is_trained(indexPointer); + if (isTrained == 0) { + handleException(wrapper.faiss_Index_train(indexPointer, size, docs)); + } + + // Add docs to index + handleException(wrapper.faiss_Index_add_with_ids(indexPointer, size, docs, ids)); + + return new Index(indexPointer); + + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + // See flags defined in c_api/index_io_c.h + private static final int FAISS_IO_FLAG_MMAP = 1; + private static final int FAISS_IO_FLAG_READ_ONLY = 2; + + @Override + public FaissLibrary.Index readIndex(IndexInput input) { + try (Arena temp = Arena.ofConfined()) { + MethodHandle readerHandle = READ_BYTES_HANDLE.bindTo(input); + MemorySegment readerStub = + getStub( + temp, readerHandle, FunctionDescriptor.of(JAVA_LONG, ADDRESS, JAVA_LONG, JAVA_LONG)); + + MemorySegment pointer = temp.allocate(ADDRESS); + handleException(wrapper.faiss_CustomIOReader_new(pointer, readerStub)); + MemorySegment customIOReaderPointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, wrapper::faiss_CustomIOReader_free); + + // Read index + handleException( + wrapper.faiss_read_index_custom( + customIOReaderPointer, FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY, pointer)); + MemorySegment indexPointer = pointer.get(ADDRESS, 0); + + return new Index(indexPointer); + } + } + + private class Index implements FaissLibrary.Index { + @FunctionalInterface + private interface FloatToFloatFunction { + float scale(float score); + } + + private final Arena arena; + private final MemorySegment indexPointer; + private final FloatToFloatFunction scaler; + private boolean closed; + + private Index(MemorySegment indexPointer) { + this.arena = Arena.ofShared(); + this.indexPointer = + indexPointer + // Ensure timely cleanup + .reinterpret(arena, wrapper::faiss_Index_free); + + // Get underlying function + int metricType = wrapper.faiss_Index_metric_type(indexPointer); + VectorSimilarityFunction function = metricToFunction(metricType); + + // Scale Faiss distances to Lucene scores, see VectorSimilarityFunction.java + this.scaler = + switch (function) { + case DOT_PRODUCT -> + // distance in Faiss === dotProduct in Lucene + distance -> Math.max((1 + distance) / 2, 0); + + case EUCLIDEAN -> + // distance in Faiss === squareDistance in Lucene + distance -> 1 / (1 + distance); + + case COSINE, MAXIMUM_INNER_PRODUCT -> throw new AssertionError("Should not reach here"); + }; + + this.closed = false; + } + + @Override + public void close() { + if (closed == false) { + arena.close(); + closed = true; + } + } + + @Override + public void search(float[] query, KnnCollector knnCollector, Bits acceptDocs) { + try (Arena temp = Arena.ofConfined()) { + FixedBitSet fixedBitSet = + switch (acceptDocs) { + case null -> null; + case FixedBitSet bitSet -> bitSet; + // TODO: Add optimized case for SparseFixedBitSet + case Bits bits -> FixedBitSet.copyOf(bits); + }; + + // Allocate queries in native memory + MemorySegment queries = temp.allocateFrom(JAVA_FLOAT, query); + + // Faiss knn search + int k = knnCollector.k(); + MemorySegment distancesPointer = temp.allocate(JAVA_FLOAT, k); + MemorySegment idsPointer = temp.allocate(JAVA_LONG, k); + + MemorySegment localIndex = indexPointer.reinterpret(temp, null); + if (fixedBitSet == null) { + // Search without runtime filters + handleException( + wrapper.faiss_Index_search(localIndex, 1, queries, k, distancesPointer, idsPointer)); + } else { + MemorySegment pointer = temp.allocate(ADDRESS); + + long[] bits = fixedBitSet.getBits(); + MemorySegment nativeBits = + // Use LITTLE_ENDIAN to convert long[] -> uint8_t* + temp.allocateFrom(JAVA_LONG.withOrder(ByteOrder.LITTLE_ENDIAN), bits); + + handleException( + wrapper.faiss_IDSelectorBitmap_new(pointer, fixedBitSet.length(), nativeBits)); + MemorySegment idSelectorBitmapPointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, wrapper::faiss_IDSelectorBitmap_free); + + handleException(wrapper.faiss_SearchParameters_new(pointer, idSelectorBitmapPointer)); + MemorySegment searchParametersPointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, wrapper::faiss_SearchParameters_free); + + // Search with runtime filters + handleException( + wrapper.faiss_Index_search_with_params( + localIndex, + 1, + queries, + k, + searchParametersPointer, + distancesPointer, + idsPointer)); + } + + // Record hits + for (int i = 0; i < k; i++) { + int id = (int) idsPointer.getAtIndex(JAVA_LONG, i); + + // Not enough results + if (id == -1) { + break; + } + + // Collect result + float distance = distancesPointer.getAtIndex(JAVA_FLOAT, i); + knnCollector.collect(id, scaler.scale(distance)); + } + } + } + + @Override + public void write(IndexOutput output) { + try (Arena temp = Arena.ofConfined()) { + MethodHandle writerHandle = WRITE_BYTES_HANDLE.bindTo(output); + MemorySegment writerStub = + getStub( + temp, + writerHandle, + FunctionDescriptor.of(JAVA_LONG, ADDRESS, JAVA_LONG, JAVA_LONG)); + + MemorySegment pointer = temp.allocate(ADDRESS); + handleException(wrapper.faiss_CustomIOWriter_new(pointer, writerStub)); + MemorySegment customIOWriterPointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, wrapper::faiss_CustomIOWriter_free); + + // Write index + handleException( + wrapper.faiss_write_index_custom( + indexPointer, customIOWriterPointer, FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY)); + } + } + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java new file mode 100644 index 000000000000..575ebf953224 --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java @@ -0,0 +1,447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static java.lang.foreign.ValueLayout.ADDRESS; +import static java.lang.foreign.ValueLayout.JAVA_INT; +import static java.lang.foreign.ValueLayout.JAVA_LONG; + +import java.lang.foreign.FunctionDescriptor; +import java.lang.foreign.Linker; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SymbolLookup; +import java.lang.invoke.MethodHandle; +import java.util.Arrays; +import java.util.Locale; + +/** + * Utility class to wrap necessary functions of the native C API of Faiss + * using Project Panama. + * + * @lucene.experimental + */ +@SuppressWarnings("restricted") // uses unsafe calls +final class FaissNativeWrapper { + static { + System.loadLibrary(FaissLibrary.NAME); + } + + private static MethodHandle getHandle(String functionName, FunctionDescriptor descriptor) { + return Linker.nativeLinker() + .downcallHandle(SymbolLookup.loaderLookup().findOrThrow(functionName), descriptor); + } + + FaissNativeWrapper() { + // Check Faiss version + String expectedVersion = FaissLibrary.VERSION; + String actualVersion = faiss_get_version().reinterpret(Long.MAX_VALUE).getString(0); + + if (expectedVersion.equals(actualVersion) == false) { + throw new LinkageError( + String.format( + Locale.ROOT, + "Expected Faiss library version %s, found %s", + expectedVersion, + actualVersion)); + } + } + + private final MethodHandle faiss_get_version$MH = + getHandle("faiss_get_version", FunctionDescriptor.of(ADDRESS)); + + MemorySegment faiss_get_version() { + try { + return (MemorySegment) faiss_get_version$MH.invokeExact(); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_CustomIOReader_free$MH = + getHandle("faiss_CustomIOReader_free", FunctionDescriptor.ofVoid(ADDRESS)); + + void faiss_CustomIOReader_free(MemorySegment customIOReaderPointer) { + try { + faiss_CustomIOReader_free$MH.invokeExact(customIOReaderPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_CustomIOReader_new$MH = + getHandle("faiss_CustomIOReader_new", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); + + int faiss_CustomIOReader_new(MemorySegment pointer, MemorySegment readerStub) { + try { + return (int) faiss_CustomIOReader_new$MH.invokeExact(pointer, readerStub); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_CustomIOWriter_free$MH = + getHandle("faiss_CustomIOWriter_free", FunctionDescriptor.ofVoid(ADDRESS)); + + void faiss_CustomIOWriter_free(MemorySegment customIOWriterPointer) { + try { + faiss_CustomIOWriter_free$MH.invokeExact(customIOWriterPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_CustomIOWriter_new$MH = + getHandle("faiss_CustomIOWriter_new", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); + + int faiss_CustomIOWriter_new(MemorySegment pointer, MemorySegment writerStub) { + try { + return (int) faiss_CustomIOWriter_new$MH.invokeExact(pointer, writerStub); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_IDSelectorBitmap_free$MH = + getHandle("faiss_IDSelectorBitmap_free", FunctionDescriptor.ofVoid(ADDRESS)); + + void faiss_IDSelectorBitmap_free(MemorySegment idSelectorBitmapPointer) { + try { + faiss_IDSelectorBitmap_free$MH.invokeExact(idSelectorBitmapPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_IDSelectorBitmap_new$MH = + getHandle( + "faiss_IDSelectorBitmap_new", + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS)); + + int faiss_IDSelectorBitmap_new(MemorySegment pointer, long length, MemorySegment bitmapPointer) { + try { + return (int) faiss_IDSelectorBitmap_new$MH.invokeExact(pointer, length, bitmapPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_Index_add_with_ids$MH = + getHandle( + "faiss_Index_add_with_ids", + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS)); + + int faiss_Index_add_with_ids( + MemorySegment indexPointer, long size, MemorySegment docsPointer, MemorySegment idsPointer) { + try { + return (int) + faiss_Index_add_with_ids$MH.invokeExact(indexPointer, size, docsPointer, idsPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_Index_free$MH = + getHandle("faiss_Index_free", FunctionDescriptor.ofVoid(ADDRESS)); + + void faiss_Index_free(MemorySegment indexPointer) { + try { + faiss_Index_free$MH.invokeExact(indexPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_Index_is_trained$MH = + getHandle("faiss_Index_is_trained", FunctionDescriptor.of(JAVA_INT, ADDRESS)); + + int faiss_Index_is_trained(MemorySegment indexPointer) { + try { + return (int) faiss_Index_is_trained$MH.invokeExact(indexPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_Index_metric_type$MH = + getHandle("faiss_Index_metric_type", FunctionDescriptor.of(JAVA_INT, ADDRESS)); + + int faiss_Index_metric_type(MemorySegment indexPointer) { + try { + return (int) faiss_Index_metric_type$MH.invokeExact(indexPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_Index_search$MH = + getHandle( + "faiss_Index_search", + FunctionDescriptor.of( + JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS)); + + int faiss_Index_search( + MemorySegment indexPointer, + long numQueries, + MemorySegment queriesPointer, + long k, + MemorySegment distancesPointer, + MemorySegment idsPointer) { + try { + return (int) + faiss_Index_search$MH.invokeExact( + indexPointer, numQueries, queriesPointer, k, distancesPointer, idsPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_Index_search_with_params$MH = + getHandle( + "faiss_Index_search_with_params", + FunctionDescriptor.of( + JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS, ADDRESS)); + + int faiss_Index_search_with_params( + MemorySegment indexPointer, + long numQueries, + MemorySegment queriesPointer, + long k, + MemorySegment searchParametersPointer, + MemorySegment distancesPointer, + MemorySegment idsPointer) { + try { + return (int) + faiss_Index_search_with_params$MH.invokeExact( + indexPointer, + numQueries, + queriesPointer, + k, + searchParametersPointer, + distancesPointer, + idsPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_Index_train$MH = + getHandle("faiss_Index_train", FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS)); + + int faiss_Index_train(MemorySegment indexPointer, long size, MemorySegment docsPointer) { + try { + return (int) faiss_Index_train$MH.invokeExact(indexPointer, size, docsPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_ParameterSpace_free$MH = + getHandle("faiss_ParameterSpace_free", FunctionDescriptor.ofVoid(ADDRESS)); + + void faiss_ParameterSpace_free(MemorySegment parameterSpacePointer) { + try { + faiss_ParameterSpace_free$MH.invokeExact(parameterSpacePointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_ParameterSpace_new$MH = + getHandle("faiss_ParameterSpace_new", FunctionDescriptor.of(JAVA_INT, ADDRESS)); + + int faiss_ParameterSpace_new(MemorySegment pointer) { + try { + return (int) faiss_ParameterSpace_new$MH.invokeExact(pointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_ParameterSpace_set_index_parameters$MH = + getHandle( + "faiss_ParameterSpace_set_index_parameters", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, ADDRESS)); + + int faiss_ParameterSpace_set_index_parameters( + MemorySegment parameterSpacePointer, + MemorySegment indexPointer, + MemorySegment descriptionPointer) { + try { + return (int) + faiss_ParameterSpace_set_index_parameters$MH.invokeExact( + parameterSpacePointer, indexPointer, descriptionPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_SearchParameters_free$MH = + getHandle("faiss_SearchParameters_free", FunctionDescriptor.ofVoid(ADDRESS)); + + void faiss_SearchParameters_free(MemorySegment searchParametersPointer) { + try { + faiss_SearchParameters_free$MH.invokeExact(searchParametersPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_SearchParameters_new$MH = + getHandle("faiss_SearchParameters_new", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); + + int faiss_SearchParameters_new(MemorySegment pointer, MemorySegment idSelectorBitmapPointer) { + try { + return (int) faiss_SearchParameters_new$MH.invokeExact(pointer, idSelectorBitmapPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_index_factory$MH = + getHandle( + "faiss_index_factory", + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT)); + + int faiss_index_factory( + MemorySegment pointer, int dimension, MemorySegment description, int metric) { + try { + return (int) faiss_index_factory$MH.invokeExact(pointer, dimension, description, metric); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_read_index_custom$MH = + getHandle( + "faiss_read_index_custom", FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS)); + + int faiss_read_index_custom( + MemorySegment customIOReaderPointer, int ioFlags, MemorySegment pointer) { + try { + return (int) faiss_read_index_custom$MH.invokeExact(customIOReaderPointer, ioFlags, pointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_write_index_custom$MH = + getHandle( + "faiss_write_index_custom", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); + + int faiss_write_index_custom( + MemorySegment indexPointer, MemorySegment customIOWriterPointer, int ioFlags) { + try { + return (int) + faiss_write_index_custom$MH.invokeExact(indexPointer, customIOWriterPointer, ioFlags); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + /** + * Exception used to rethrow handled Faiss errors in native code. + * + * @lucene.experimental + */ + static class Exception extends RuntimeException { + // See error codes defined in c_api/error_c.h + enum ErrorCode { + /// No error + OK(0), + /// Any exception other than Faiss or standard C++ library exceptions + UNKNOWN_EXCEPT(-1), + /// Faiss library exception + FAISS_EXCEPT(-2), + /// Standard C++ library exception + STD_EXCEPT(-4); + + private final int code; + + ErrorCode(int code) { + this.code = code; + } + + static ErrorCode fromCode(int code) { + return Arrays.stream(ErrorCode.values()) + .filter(errorCode -> errorCode.code == code) + .findFirst() + .orElseThrow(); + } + } + + private Exception(int code) { + super( + String.format( + Locale.ROOT, + "%s[%s(%d)]", + Exception.class.getName(), + ErrorCode.fromCode(code), + code)); + } + + static void handleException(int code) { + if (code < 0) { + throw new Exception(code); + } + } + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/LibFaissC.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/LibFaissC.java deleted file mode 100644 index d9a7baa8ef3c..000000000000 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/LibFaissC.java +++ /dev/null @@ -1,534 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.lucene.sandbox.codecs.faiss; - -import static java.lang.foreign.ValueLayout.ADDRESS; -import static java.lang.foreign.ValueLayout.JAVA_BYTE; -import static java.lang.foreign.ValueLayout.JAVA_FLOAT; -import static java.lang.foreign.ValueLayout.JAVA_INT; -import static java.lang.foreign.ValueLayout.JAVA_LONG; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - -import java.io.IOException; -import java.lang.foreign.Arena; -import java.lang.foreign.FunctionDescriptor; -import java.lang.foreign.Linker; -import java.lang.foreign.MemorySegment; -import java.lang.foreign.SymbolLookup; -import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; -import java.lang.invoke.MethodType; -import java.nio.ByteOrder; -import java.util.Arrays; -import java.util.Locale; -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.KnnVectorValues; -import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.KnnCollector; -import org.apache.lucene.store.IndexInput; -import org.apache.lucene.store.IndexOutput; -import org.apache.lucene.util.Bits; -import org.apache.lucene.util.FixedBitSet; -import org.apache.lucene.util.hnsw.IntToIntFunction; - -/** - * Utility class to wrap necessary functions of the native C API of Faiss - * using Project Panama. - * - * @lucene.experimental - */ -@SuppressWarnings("restricted") // uses unsafe calls -final class LibFaissC { - // TODO: Use vectorized version where available - public static final String LIBRARY_NAME = "faiss_c"; - public static final String LIBRARY_VERSION = "1.11.0"; - - // See flags defined in c_api/index_io_c.h - static final int FAISS_IO_FLAG_MMAP = 1; - static final int FAISS_IO_FLAG_READ_ONLY = 2; - - private static final int BUFFER_SIZE = 256 * 1024 * 1024; // 256 MB - - static { - System.loadLibrary(LIBRARY_NAME); - checkLibraryVersion(); - } - - private LibFaissC() {} - - private static MemorySegment getUpcallStub( - Arena arena, MethodHandle target, FunctionDescriptor descriptor) { - return Linker.nativeLinker().upcallStub(target, descriptor, arena); - } - - private static MethodHandle getDowncallHandle( - String functionName, FunctionDescriptor descriptor) { - return Linker.nativeLinker() - .downcallHandle(SymbolLookup.loaderLookup().findOrThrow(functionName), descriptor); - } - - private static void checkLibraryVersion() { - MethodHandle getVersion = - getDowncallHandle("faiss_get_version", FunctionDescriptor.of(ADDRESS)); - - MemorySegment nativeString = call(getVersion); - String actualVersion = nativeString.reinterpret(Long.MAX_VALUE).getString(0); - - if (LIBRARY_VERSION.equals(actualVersion) == false) { - throw new UnsupportedOperationException( - String.format( - Locale.ROOT, - "Expected Faiss library version %s, found %s", - LIBRARY_VERSION, - actualVersion)); - } - } - - private static final MethodHandle FREE_INDEX = - getDowncallHandle("faiss_Index_free", FunctionDescriptor.ofVoid(ADDRESS)); - - public static void freeIndex(MemorySegment indexPointer) { - call(FREE_INDEX, indexPointer); - } - - private static final MethodHandle FREE_CUSTOM_IO_WRITER = - getDowncallHandle("faiss_CustomIOWriter_free", FunctionDescriptor.ofVoid(ADDRESS)); - - public static void freeCustomIOWriter(MemorySegment customIOWriterPointer) { - call(FREE_CUSTOM_IO_WRITER, customIOWriterPointer); - } - - private static final MethodHandle FREE_CUSTOM_IO_READER = - getDowncallHandle("faiss_CustomIOReader_free", FunctionDescriptor.ofVoid(ADDRESS)); - - public static void freeCustomIOReader(MemorySegment customIOReaderPointer) { - call(FREE_CUSTOM_IO_READER, customIOReaderPointer); - } - - private static final MethodHandle FREE_PARAMETER_SPACE = - getDowncallHandle("faiss_ParameterSpace_free", FunctionDescriptor.ofVoid(ADDRESS)); - - private static void freeParameterSpace(MemorySegment parameterSpacePointer) { - call(FREE_PARAMETER_SPACE, parameterSpacePointer); - } - - private static final MethodHandle FREE_ID_SELECTOR_BITMAP = - getDowncallHandle("faiss_IDSelectorBitmap_free", FunctionDescriptor.ofVoid(ADDRESS)); - - private static void freeIDSelectorBitmap(MemorySegment idSelectorBitmapPointer) { - call(FREE_ID_SELECTOR_BITMAP, idSelectorBitmapPointer); - } - - private static final MethodHandle FREE_SEARCH_PARAMETERS = - getDowncallHandle("faiss_SearchParameters_free", FunctionDescriptor.ofVoid(ADDRESS)); - - private static void freeSearchParameters(MemorySegment searchParametersPointer) { - call(FREE_SEARCH_PARAMETERS, searchParametersPointer); - } - - private static final MethodHandle INDEX_FACTORY = - getDowncallHandle( - "faiss_index_factory", - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT)); - - private static final MethodHandle PARAMETER_SPACE_NEW = - getDowncallHandle("faiss_ParameterSpace_new", FunctionDescriptor.of(JAVA_INT, ADDRESS)); - - private static final MethodHandle SET_INDEX_PARAMETERS = - getDowncallHandle( - "faiss_ParameterSpace_set_index_parameters", - FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, ADDRESS)); - - private static final MethodHandle ID_SELECTOR_BITMAP_NEW = - getDowncallHandle( - "faiss_IDSelectorBitmap_new", - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS)); - - private static final MethodHandle SEARCH_PARAMETERS_NEW = - getDowncallHandle( - "faiss_SearchParameters_new", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); - - private static final MethodHandle INDEX_IS_TRAINED = - getDowncallHandle("faiss_Index_is_trained", FunctionDescriptor.of(JAVA_INT, ADDRESS)); - - private static final MethodHandle INDEX_TRAIN = - getDowncallHandle( - "faiss_Index_train", FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS)); - - private static final MethodHandle INDEX_ADD_WITH_IDS = - getDowncallHandle( - "faiss_Index_add_with_ids", - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS)); - - public static MemorySegment createIndex( - String description, - String indexParams, - VectorSimilarityFunction function, - FloatVectorValues floatVectorValues, - IntToIntFunction oldToNewDocId) - throws IOException { - - try (Arena temp = Arena.ofConfined()) { - int size = floatVectorValues.size(); - int dimension = floatVectorValues.dimension(); - - // Mapped from faiss/MetricType.h - int metric = - switch (function) { - case DOT_PRODUCT -> 0; - case EUCLIDEAN -> 1; - case COSINE, MAXIMUM_INNER_PRODUCT -> - throw new UnsupportedOperationException("Metric type not supported"); - }; - - // Create an index - MemorySegment pointer = temp.allocate(ADDRESS); - callAndHandleError(INDEX_FACTORY, pointer, dimension, temp.allocateFrom(description), metric); - MemorySegment indexPointer = pointer.get(ADDRESS, 0); - - // Set index params - callAndHandleError(PARAMETER_SPACE_NEW, pointer); - MemorySegment parameterSpacePointer = - pointer - .get(ADDRESS, 0) - // Ensure timely cleanup - .reinterpret(temp, LibFaissC::freeParameterSpace); - - callAndHandleError( - SET_INDEX_PARAMETERS, - parameterSpacePointer, - indexPointer, - temp.allocateFrom(indexParams)); - - // TODO: Improve memory usage (with a tradeoff in performance) by batched indexing, see: - // - https://github.com/opensearch-project/k-NN/issues/1506 - // - https://github.com/opensearch-project/k-NN/issues/1938 - - // Allocate docs in native memory - MemorySegment docs = temp.allocate(JAVA_FLOAT, (long) size * dimension); - long docsOffset = 0; - long perDocByteSize = dimension * JAVA_FLOAT.byteSize(); - - // Allocate ids in native memory - MemorySegment ids = temp.allocate(JAVA_LONG, size); - int idsIndex = 0; - - KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); - for (int i = iterator.nextDoc(); i != NO_MORE_DOCS; i = iterator.nextDoc()) { - int id = oldToNewDocId.apply(i); - ids.setAtIndex(JAVA_LONG, idsIndex, id); - idsIndex++; - - float[] vector = floatVectorValues.vectorValue(iterator.index()); - MemorySegment.copy(vector, 0, docs, JAVA_FLOAT, docsOffset, vector.length); - docsOffset += perDocByteSize; - } - - // Train index - int isTrained = call(INDEX_IS_TRAINED, indexPointer); - if (isTrained == 0) { - callAndHandleError(INDEX_TRAIN, indexPointer, size, docs); - } - - // Add docs to index - callAndHandleError(INDEX_ADD_WITH_IDS, indexPointer, size, docs, ids); - - return indexPointer; - } - } - - @SuppressWarnings("unused") // called using a MethodHandle - private static long writeBytes( - IndexOutput output, MemorySegment inputPointer, long itemSize, long numItems) - throws IOException { - long size = itemSize * numItems; - inputPointer = inputPointer.reinterpret(size); - - if (size <= BUFFER_SIZE) { // simple case, avoid buffering - output.writeBytes(inputPointer.toArray(JAVA_BYTE), (int) size); - } else { // copy buffered number of bytes repeatedly - byte[] bytes = new byte[BUFFER_SIZE]; - for (long offset = 0; offset < size; offset += BUFFER_SIZE) { - int length = (int) Math.min(size - offset, BUFFER_SIZE); - MemorySegment.copy(inputPointer, JAVA_BYTE, offset, bytes, 0, length); - output.writeBytes(bytes, length); - } - } - return numItems; - } - - @SuppressWarnings("unused") // called using a MethodHandle - private static long readBytes( - IndexInput input, MemorySegment outputPointer, long itemSize, long numItems) - throws IOException { - long size = itemSize * numItems; - outputPointer = outputPointer.reinterpret(size); - - if (size <= BUFFER_SIZE) { // simple case, avoid buffering - byte[] bytes = new byte[(int) size]; - input.readBytes(bytes, 0, bytes.length); - MemorySegment.copy(bytes, 0, outputPointer, JAVA_BYTE, 0, bytes.length); - } else { // copy buffered number of bytes repeatedly - byte[] bytes = new byte[BUFFER_SIZE]; - for (long offset = 0; offset < size; offset += BUFFER_SIZE) { - int length = (int) Math.min(size - offset, BUFFER_SIZE); - input.readBytes(bytes, 0, length); - MemorySegment.copy(bytes, 0, outputPointer, JAVA_BYTE, offset, length); - } - } - return numItems; - } - - private static final MethodHandle WRITE_BYTES_HANDLE; - private static final MethodHandle READ_BYTES_HANDLE; - - static { - try { - WRITE_BYTES_HANDLE = - MethodHandles.lookup() - .findStatic( - LibFaissC.class, - "writeBytes", - MethodType.methodType( - long.class, IndexOutput.class, MemorySegment.class, long.class, long.class)); - - READ_BYTES_HANDLE = - MethodHandles.lookup() - .findStatic( - LibFaissC.class, - "readBytes", - MethodType.methodType( - long.class, IndexInput.class, MemorySegment.class, long.class, long.class)); - } catch (NoSuchMethodException | IllegalAccessException e) { - throw new RuntimeException(e); - } - } - - private static final MethodHandle CUSTOM_IO_WRITER_NEW = - getDowncallHandle( - "faiss_CustomIOWriter_new", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); - - private static final MethodHandle WRITE_INDEX_CUSTOM = - getDowncallHandle( - "faiss_write_index_custom", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); - - public static void indexWrite(MemorySegment indexPointer, IndexOutput output, int ioFlags) { - try (Arena temp = Arena.ofConfined()) { - MethodHandle writerHandle = WRITE_BYTES_HANDLE.bindTo(output); - MemorySegment writerStub = - getUpcallStub( - temp, writerHandle, FunctionDescriptor.of(JAVA_LONG, ADDRESS, JAVA_LONG, JAVA_LONG)); - - MemorySegment pointer = temp.allocate(ADDRESS); - callAndHandleError(CUSTOM_IO_WRITER_NEW, pointer, writerStub); - MemorySegment customIOWriterPointer = - pointer - .get(ADDRESS, 0) - // Ensure timely cleanup - .reinterpret(temp, LibFaissC::freeCustomIOWriter); - - callAndHandleError(WRITE_INDEX_CUSTOM, indexPointer, customIOWriterPointer, ioFlags); - } - } - - private static final MethodHandle CUSTOM_IO_READER_NEW = - getDowncallHandle( - "faiss_CustomIOReader_new", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); - - private static final MethodHandle READ_INDEX_CUSTOM = - getDowncallHandle( - "faiss_read_index_custom", FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS)); - - public static MemorySegment indexRead(IndexInput input, int ioFlags) { - try (Arena temp = Arena.ofConfined()) { - MethodHandle readerHandle = READ_BYTES_HANDLE.bindTo(input); - MemorySegment readerStub = - getUpcallStub( - temp, readerHandle, FunctionDescriptor.of(JAVA_LONG, ADDRESS, JAVA_LONG, JAVA_LONG)); - - MemorySegment pointer = temp.allocate(ADDRESS); - callAndHandleError(CUSTOM_IO_READER_NEW, pointer, readerStub); - MemorySegment customIOReaderPointer = - pointer - .get(ADDRESS, 0) - // Ensure timely cleanup - .reinterpret(temp, LibFaissC::freeCustomIOReader); - - callAndHandleError(READ_INDEX_CUSTOM, customIOReaderPointer, ioFlags, pointer); - return pointer.get(ADDRESS, 0); - } - } - - private static final MethodHandle INDEX_SEARCH = - getDowncallHandle( - "faiss_Index_search", - FunctionDescriptor.of( - JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS)); - - private static final MethodHandle INDEX_SEARCH_WITH_PARAMS = - getDowncallHandle( - "faiss_Index_search_with_params", - FunctionDescriptor.of( - JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS, ADDRESS)); - - public static void indexSearch( - MemorySegment indexPointer, - VectorSimilarityFunction function, - float[] query, - KnnCollector knnCollector, - Bits acceptDocs) { - - try (Arena temp = Arena.ofConfined()) { - FixedBitSet fixedBitSet = - switch (acceptDocs) { - case null -> null; - case FixedBitSet bitSet -> bitSet; - // TODO: Add optimized case for SparseFixedBitSet - case Bits bits -> FixedBitSet.copyOf(bits); - }; - - // Allocate queries in native memory - MemorySegment queries = temp.allocateFrom(JAVA_FLOAT, query); - - // Faiss knn search - int k = knnCollector.k(); - MemorySegment distancesPointer = temp.allocate(JAVA_FLOAT, k); - MemorySegment idsPointer = temp.allocate(JAVA_LONG, k); - - MemorySegment localIndex = indexPointer.reinterpret(temp, null); - if (fixedBitSet == null) { - // Search without runtime filters - callAndHandleError(INDEX_SEARCH, localIndex, 1, queries, k, distancesPointer, idsPointer); - } else { - MemorySegment pointer = temp.allocate(ADDRESS); - - long[] bits = fixedBitSet.getBits(); - MemorySegment nativeBits = - // Use LITTLE_ENDIAN to convert long[] -> uint8_t* - temp.allocateFrom(JAVA_LONG.withOrder(ByteOrder.LITTLE_ENDIAN), bits); - - callAndHandleError(ID_SELECTOR_BITMAP_NEW, pointer, fixedBitSet.length(), nativeBits); - MemorySegment idSelectorBitmapPointer = - pointer - .get(ADDRESS, 0) - // Ensure timely cleanup - .reinterpret(temp, LibFaissC::freeIDSelectorBitmap); - - callAndHandleError(SEARCH_PARAMETERS_NEW, pointer, idSelectorBitmapPointer); - MemorySegment searchParametersPointer = - pointer - .get(ADDRESS, 0) - // Ensure timely cleanup - .reinterpret(temp, LibFaissC::freeSearchParameters); - - // Search with runtime filters - callAndHandleError( - INDEX_SEARCH_WITH_PARAMS, - localIndex, - 1, - queries, - k, - searchParametersPointer, - distancesPointer, - idsPointer); - } - - // Retrieve scores and ids - float[] distances = distancesPointer.toArray(JAVA_FLOAT); - long[] ids = idsPointer.toArray(JAVA_LONG); - - // Record hits - for (int i = 0; i < k; i++) { - // Not enough results - if (ids[i] == -1) { - break; - } - - // Scale Faiss distances to Lucene scores, see VectorSimilarityFunction.java - float score = - switch (function) { - case DOT_PRODUCT -> - // distance in Faiss === dotProduct in Lucene - Math.max((1 + distances[i]) / 2, 0); - - case EUCLIDEAN -> - // distance in Faiss === squareDistance in Lucene - 1 / (1 + distances[i]); - - case COSINE, MAXIMUM_INNER_PRODUCT -> - throw new UnsupportedOperationException("Metric type not supported"); - }; - - knnCollector.collect((int) ids[i], score); - } - } - } - - @SuppressWarnings("unchecked") - private static T call(MethodHandle handle, Object... args) { - try { - return (T) handle.invokeWithArguments(args); - } catch (Throwable e) { - throw new RuntimeException(e); - } - } - - private static void callAndHandleError(MethodHandle handle, Object... args) { - int returnCode = call(handle, args); - if (returnCode < 0) { - // TODO: Surface actual exception in a thread-safe manner? - throw new FaissException(returnCode); - } - } - - /** - * Exception used to rethrow handled Faiss errors in native code. - * - * @lucene.experimental - */ - public static class FaissException extends RuntimeException { - // See error codes defined in c_api/error_c.h - enum ErrorCode { - /// No error - OK(0), - /// Any exception other than Faiss or standard C++ library exceptions - UNKNOWN_EXCEPT(-1), - /// Faiss library exception - FAISS_EXCEPT(-2), - /// Standard C++ library exception - STD_EXCEPT(-4); - - private final int code; - - ErrorCode(int code) { - this.code = code; - } - - static ErrorCode fromCode(int code) { - return Arrays.stream(ErrorCode.values()) - .filter(errorCode -> errorCode.code == code) - .findFirst() - .orElseThrow(); - } - } - - public FaissException(int code) { - super(String.format(Locale.ROOT, "Faiss library ran into %s", ErrorCode.fromCode(code))); - } - } -} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/package-info.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/package-info.java index e63fa3070f96..bd4a5b88d2aa 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/package-info.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/package-info.java @@ -22,8 +22,7 @@ * org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat}. * *

To use this format: Install pytorch/faiss-cpu v{@value - * org.apache.lucene.sandbox.codecs.faiss.LibFaissC#LIBRARY_VERSION} from pytorch/faiss-cpu from Conda and place shared libraries (including * dependencies) on the {@code $LD_LIBRARY_PATH} environment variable or {@code -Djava.library.path} * JVM argument. @@ -39,9 +38,11 @@ *

  • Install micromamba (an open-source Conda * package manager) or similar *
  • Install dependencies using {@code micromamba create -n faiss-env -c pytorch -c conda-forge - * -y faiss-cpu=}{@value org.apache.lucene.sandbox.codecs.faiss.LibFaissC#LIBRARY_VERSION} + * -y faiss-cpu=}{@value org.apache.lucene.sandbox.codecs.faiss.FaissLibrary#VERSION} *
  • Activate environment using {@code micromamba activate faiss-env} *
  • Add shared libraries to runtime using {@code export LD_LIBRARY_PATH=$CONDA_PREFIX/lib} + * (verify that the {@value org.apache.lucene.sandbox.codecs.faiss.FaissLibrary#NAME} library + * is present here) *
  • And you're good to go! (add the {@code -Dtests.faiss.run=true} JVM argument to ensure Faiss * tests are run) * diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java index 4a3fb3661e6d..f66eae680db1 100644 --- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java @@ -48,21 +48,21 @@ public class TestFaissKnnVectorsFormat extends BaseKnnVectorsFormatTestCase { private static final VectorSimilarityFunction[] SUPPORTED_FUNCTIONS = {DOT_PRODUCT, EUCLIDEAN}; @BeforeClass - public static void maybeSuppress() throws ClassNotFoundException { + public static void maybeSuppress() { // Explicitly run tests if (Boolean.getBoolean(FAISS_RUN_TESTS)) { return; } // Otherwise check if dependencies are present - boolean faissLibraryPresent; + boolean dependenciesPresent; try { - Class.forName("org.apache.lucene.sandbox.codecs.faiss.LibFaissC"); - faissLibraryPresent = true; - } catch (UnsatisfiedLinkError _) { - faissLibraryPresent = false; + FaissLibrary _ = FaissLibrary.INSTANCE; + dependenciesPresent = true; + } catch (LinkageError _) { + dependenciesPresent = false; } - assumeTrue("Native libraries present", faissLibraryPresent); + assumeTrue("Dependencies present", dependenciesPresent); } @Override