diff --git a/.github/workflows/run-special-checks-sandbox.yml b/.github/workflows/run-special-checks-sandbox.yml new file mode 100644 index 000000000000..1d18b3a060fe --- /dev/null +++ b/.github/workflows/run-special-checks-sandbox.yml @@ -0,0 +1,58 @@ +name: "Run special checks: module lucene/sandbox" + +on: + workflow_dispatch: + + pull_request: + branches: + - '*' + + push: + branches: + - 'main' + - 'branch_10x' + +jobs: + faiss-tests: + name: tests for the Faiss codec (v${{ matrix.faiss-version }} with JDK ${{ matrix.java }} on ${{ matrix.os }}) + timeout-minutes: 15 + + strategy: + matrix: + os: [ ubuntu-latest ] + java: [ '24' ] + faiss-version: [ '1.11.0' ] + + runs-on: ${{ matrix.os }} + + steps: + - name: Install Mamba + uses: conda-incubator/setup-miniconda@835234971496cad1653abb28a638a281cf32541f #v3.2.0 + with: + miniforge-version: 'latest' + auto-activate-base: 'false' + activate-environment: 'faiss-env' + # TODO: Use only conda-forge if possible, see https://github.com/conda-forge/faiss-split-feedstock/pull/88 + channels: 'pytorch,conda-forge' + conda-remove-defaults: 'true' + + - name: Install Faiss + run: mamba install faiss-cpu=${{ matrix.faiss-version }} + + - name: Checkout Lucene + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Prepare Lucene workspace + uses: ./.github/actions/prepare-for-build + + - name: Run tests for Faiss codec + run: > + LD_LIBRARY_PATH=$CONDA_PREFIX/lib + ./gradlew -p lucene/sandbox + -Dtests.faiss.run=true + test + --tests "org.apache.lucene.sandbox.codecs.faiss.*" + + defaults: + run: + shell: bash -leo pipefail {0} diff --git a/build-tools/build-infra/src/main/groovy/lucene.java.tests-and-randomization.gradle b/build-tools/build-infra/src/main/groovy/lucene.java.tests-and-randomization.gradle index 82517cd6464a..a9aef687bb33 100644 --- a/build-tools/build-infra/src/main/groovy/lucene.java.tests-and-randomization.gradle +++ 
b/build-tools/build-infra/src/main/groovy/lucene.java.tests-and-randomization.gradle @@ -147,6 +147,9 @@ buildOptions.addOption("tests.file.encoding", "Sets the default file.encoding on ]) }) +buildOptions.addBooleanOption("tests.faiss.run", "Explicitly run tests for the Faiss codec.", false) +optionsInheritedAsProperties += ["tests.faiss.run"] + // TODO: do we still use these? // Test data file used. // [propName: 'tests.linedocsfile', value: 'europarl.lines.txt.gz', description: "Test data file path."], diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index db4e934cfcf5..e276655cff2b 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -30,6 +30,8 @@ New Features --------------------- * GITHUB#14097: Binary partitioning merge policy over float-valued vector field. (Mike Sokolov) +* GITHUB#14178: Add a Faiss-based vector format in the sandbox module. (Kaival Parikh) + Improvements --------------------- diff --git a/lucene/sandbox/src/java/module-info.java b/lucene/sandbox/src/java/module-info.java index d79a150fea3e..ee9be3227de2 100644 --- a/lucene/sandbox/src/java/module-info.java +++ b/lucene/sandbox/src/java/module-info.java @@ -22,6 +22,7 @@ requires org.apache.lucene.facet; exports org.apache.lucene.payloads; + exports org.apache.lucene.sandbox.codecs.faiss; exports org.apache.lucene.sandbox.codecs.idversion; exports org.apache.lucene.sandbox.codecs.quantization; exports org.apache.lucene.sandbox.document; @@ -39,4 +40,6 @@ provides org.apache.lucene.codecs.PostingsFormat with org.apache.lucene.sandbox.codecs.idversion.IDVersionPostingsFormat; + provides org.apache.lucene.codecs.KnnVectorsFormat with + org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat; } diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java new file mode 100644 index 000000000000..83beae607dc5 --- /dev/null +++ 
b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_MAX_CONN; + +import java.io.IOException; +import java.util.Locale; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + +/** + * A Faiss-based format to create and search vector indexes, using {@link LibFaissC} to interact + * with the native library. + * + *

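<p>The Faiss index is configured using its flexible index factory, which + * allows creating arbitrary indexes by "describing" them (for example, the no-argument constructor + * uses a description of the form {@code IDMap,HNSW<maxConn>}). These indexes can be tuned by setting + * relevant parameters (for example, {@code efConstruction=<beamWidth>}). + * + *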

<p>A separate Faiss index is created per segment, and uses the following files: + * + *

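<ul> + * <li>{@code .faissm} (metadata file): for each field, stores the field number followed by the + * offset and length of its Faiss index within the data file; terminated by {@code -1}. + * <li>{@code .faissd} (data file): stores the serialized Faiss indexes of all fields, one after + * another. + * <li>All files written by the {@link Lucene99FlatVectorsFormat} that stores the raw vectors. + * </ul> + * + *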

<p>Note: Set the {@code $OMP_NUM_THREADS} environment variable to control internal + * threading. + * + *

TODO: There is no guarantee of backwards compatibility! + * + * @lucene.experimental + */ +public final class FaissKnnVectorsFormat extends KnnVectorsFormat { + public static final String NAME = FaissKnnVectorsFormat.class.getSimpleName(); + static final int VERSION_START = 0; + static final int VERSION_CURRENT = VERSION_START; + static final String META_CODEC_NAME = NAME + "Meta"; + static final String DATA_CODEC_NAME = NAME + "Data"; + static final String META_EXTENSION = "faissm"; + static final String DATA_EXTENSION = "faissd"; + + private final String description; + private final String indexParams; + private final FlatVectorsFormat rawVectorsFormat; + + /** + * Constructs an HNSW-based format using default {@code maxConn}={@value + * org.apache.lucene.util.hnsw.HnswGraphBuilder#DEFAULT_MAX_CONN} and {@code beamWidth}={@value + * org.apache.lucene.util.hnsw.HnswGraphBuilder#DEFAULT_BEAM_WIDTH}. + */ + public FaissKnnVectorsFormat() { + this( + String.format(Locale.ROOT, "IDMap,HNSW%d", DEFAULT_MAX_CONN), + String.format(Locale.ROOT, "efConstruction=%d", DEFAULT_BEAM_WIDTH)); + } + + /** + * Constructs a format using the specified index factory string and index parameters (see class + * docs for more information). + * + * @param description the index factory string to initialize Faiss indexes. + * @param indexParams the index params to set on Faiss indexes. 
+ */ + public FaissKnnVectorsFormat(String description, String indexParams) { + super(NAME); + this.description = description; + this.indexParams = indexParams; + this.rawVectorsFormat = + new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new FaissKnnVectorsWriter( + description, indexParams, state, rawVectorsFormat.fieldsWriter(state)); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new FaissKnnVectorsReader(state, rawVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return DEFAULT_MAX_DIMENSIONS; + } + + @Override + public String toString() { + return String.format( + Locale.ROOT, "%s(description=%s indexParams=%s)", NAME, description, indexParams); + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java new file mode 100644 index 000000000000..42a95145fbd4 --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_CURRENT; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_START; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.FAISS_IO_FLAG_MMAP; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.FAISS_IO_FLAG_READ_ONLY; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexRead; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexSearch; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import 
org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IOUtils; + +/** + * Read per-segment Faiss indexes and associated metadata. + * + * @lucene.experimental + */ +final class FaissKnnVectorsReader extends KnnVectorsReader { + private final FlatVectorsReader rawVectorsReader; + private final IndexInput data; + private final Map indexMap; + private final Arena arena; + private boolean closed; + + public FaissKnnVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) + throws IOException { + this.rawVectorsReader = rawVectorsReader; + this.indexMap = new HashMap<>(); + this.arena = Arena.ofShared(); + this.closed = false; + + List fieldMetaList = new ArrayList<>(); + String metaFileName = + IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, META_EXTENSION); + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + int versionMeta = -1; + try { + versionMeta = + CodecUtil.checkIndexHeader( + meta, + META_CODEC_NAME, + VERSION_START, + VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + + FieldMeta fieldMeta; + while ((fieldMeta = parseNextField(meta, state)) != null) { + fieldMetaList.add(fieldMeta); + } + } catch (Throwable t) { + priorE = t; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + + String dataFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, state.segmentSuffix, DATA_EXTENSION); + this.data = + state.directory.openInput( + dataFileName, state.context.withHints(FileTypeHint.DATA, DataAccessHint.RANDOM)); + + int versionData = + CodecUtil.checkIndexHeader( + 
this.data, + DATA_CODEC_NAME, + VERSION_START, + VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + if (versionMeta != versionData) { + throw new CorruptIndexException( + String.format( + Locale.ROOT, + "Format versions mismatch (meta=%d, data=%d)", + versionMeta, + versionData), + data); + } + CodecUtil.retrieveChecksum(data); + + for (FieldMeta fieldMeta : fieldMetaList) { + if (indexMap.put(fieldMeta.fieldInfo.name, loadField(data, arena, fieldMeta)) != null) { + throw new CorruptIndexException("Duplicate field: " + fieldMeta.fieldInfo.name, meta); + } + } + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, this); + throw t; + } + } + + private static FieldMeta parseNextField(IndexInput meta, SegmentReadState state) + throws IOException { + int fieldNumber = meta.readInt(); + if (fieldNumber == -1) { + return null; + } + + FieldInfo fieldInfo = state.fieldInfos.fieldInfo(fieldNumber); + if (fieldInfo == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + + long dataOffset = meta.readLong(); + long dataLength = meta.readLong(); + + return new FieldMeta(fieldInfo, dataOffset, dataLength); + } + + private static IndexEntry loadField(IndexInput data, Arena arena, FieldMeta fieldMeta) + throws IOException { + int ioFlags = FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY; + + // Read index into memory + MemorySegment indexPointer = + indexRead(data.slice(fieldMeta.fieldInfo.name, fieldMeta.offset, fieldMeta.length), ioFlags) + // Ensure timely cleanup + .reinterpret(arena, LibFaissC::freeIndex); + + return new IndexEntry(indexPointer, fieldMeta.fieldInfo.getVectorSimilarityFunction()); + } + + @Override + public void checkIntegrity() throws IOException { + rawVectorsReader.checkIntegrity(); + // TODO: Evaluate if we need an explicit check for validity of Faiss indexes + CodecUtil.checksumEntireFile(data); + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) 
throws IOException { + return rawVectorsReader.getFloatVectorValues(field); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) { + // TODO: Support using SQ8 quantization, see: + // - https://github.com/opensearch-project/k-NN/pull/2425 + throw new UnsupportedOperationException("Byte vectors not supported"); + } + + @Override + public void search(String field, float[] vector, KnnCollector knnCollector, Bits acceptDocs) { + IndexEntry entry = indexMap.get(field); + if (entry != null) { + indexSearch(entry.indexPointer, entry.function, vector, knnCollector, acceptDocs); + } + } + + @Override + public void search(String field, byte[] vector, KnnCollector knnCollector, Bits acceptDocs) { + // TODO: Support using SQ8 quantization, see: + // - https://github.com/opensearch-project/k-NN/pull/2425 + throw new UnsupportedOperationException("Byte vectors not supported"); + } + + @Override + public Map getOffHeapByteSize(FieldInfo fieldInfo) { + // TODO: How to estimate Faiss usage? + return rawVectorsReader.getOffHeapByteSize(fieldInfo); + } + + @Override + public void close() throws IOException { + if (closed == false) { + closed = true; + IOUtils.close(rawVectorsReader, arena::close, data, indexMap::clear); + } + } + + private record FieldMeta(FieldInfo fieldInfo, long offset, long length) {} + + private record IndexEntry(MemorySegment indexPointer, VectorSimilarityFunction function) {} +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java new file mode 100644 index 000000000000..0336e85e607e --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_CURRENT; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.FAISS_IO_FLAG_MMAP; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.FAISS_IO_FLAG_READ_ONLY; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.createIndex; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexWrite; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import 
org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSet; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.hnsw.IntToIntFunction; + +/** + * Write per-segment Faiss indexes and associated metadata. + * + * @lucene.experimental + */ +final class FaissKnnVectorsWriter extends KnnVectorsWriter { + private final String description, indexParams; + private final FlatVectorsWriter rawVectorsWriter; + private final IndexOutput meta, data; + private final Map> rawFields; + private boolean finished; + + public FaissKnnVectorsWriter( + String description, + String indexParams, + SegmentWriteState state, + FlatVectorsWriter rawVectorsWriter) + throws IOException { + + this.description = description; + this.indexParams = indexParams; + this.rawVectorsWriter = rawVectorsWriter; + this.rawFields = new HashMap<>(); + this.finished = false; + + try { + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, state.segmentSuffix, META_EXTENSION); + this.meta = state.directory.createOutput(metaFileName, state.context); + CodecUtil.writeIndexHeader( + this.meta, + META_CODEC_NAME, + VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + + String dataFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, state.segmentSuffix, DATA_EXTENSION); + this.data = state.directory.createOutput(dataFileName, state.context); + CodecUtil.writeIndexHeader( + this.data, + DATA_CODEC_NAME, + VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, this); + throw t; + } + } + + @Override + 
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + rawVectorsWriter.mergeOneField(fieldInfo, mergeState); + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> + // TODO: Support using SQ8 quantization, see: + // - https://github.com/opensearch-project/k-NN/pull/2425 + throw new UnsupportedOperationException("Byte vectors not supported"); + case FLOAT32 -> { + FloatVectorValues merged = + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + writeFloatField(fieldInfo, merged, doc -> doc); + } + } + } + + @Override + public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FlatFieldVectorsWriter rawFieldVectorsWriter = rawVectorsWriter.addField(fieldInfo); + rawFields.put(fieldInfo, rawFieldVectorsWriter); + return rawFieldVectorsWriter; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorsWriter.flush(maxDoc, sortMap); + for (Map.Entry> entry : rawFields.entrySet()) { + FieldInfo fieldInfo = entry.getKey(); + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> + // TODO: Support using SQ8 quantization, see: + // - https://github.com/opensearch-project/k-NN/pull/2425 + throw new UnsupportedOperationException("Byte vectors not supported"); + + case FLOAT32 -> { + @SuppressWarnings("unchecked") + FlatFieldVectorsWriter rawWriter = + (FlatFieldVectorsWriter) entry.getValue(); + + List vectors = rawWriter.getVectors(); + int dimension = fieldInfo.getVectorDimension(); + DocIdSet docIdSet = rawWriter.getDocsWithFieldSet(); + + writeFloatField( + fieldInfo, + new BufferedFloatVectorValues(vectors, dimension, docIdSet), + (sortMap != null) ? 
sortMap::oldToNew : doc -> doc); + } + } + } + } + + private void writeFloatField( + FieldInfo fieldInfo, FloatVectorValues floatVectorValues, IntToIntFunction oldToNewDocId) + throws IOException { + int number = fieldInfo.number; + meta.writeInt(number); + + // Write index to temp file and deallocate from memory + try (Arena temp = Arena.ofConfined()) { + VectorSimilarityFunction function = fieldInfo.getVectorSimilarityFunction(); + MemorySegment indexPointer = + createIndex(description, indexParams, function, floatVectorValues, oldToNewDocId) + // Ensure timely cleanup + .reinterpret(temp, LibFaissC::freeIndex); + + int ioFlags = FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY; + + // Write index + long dataOffset = data.getFilePointer(); + indexWrite(indexPointer, data, ioFlags); + long dataLength = data.getFilePointer() - dataOffset; + + meta.writeLong(dataOffset); + meta.writeLong(dataLength); + } + } + + @Override + public void finish() throws IOException { + if (finished) { + throw new IllegalStateException("Already finished"); + } + finished = true; + + rawVectorsWriter.finish(); + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + CodecUtil.writeFooter(data); + } + + @Override + public void close() throws IOException { + IOUtils.close(rawVectorsWriter, meta, data); + } + + @Override + public long ramBytesUsed() { + // TODO: How to estimate Faiss usage? 
+ return rawVectorsWriter.ramBytesUsed(); + } + + private static class BufferedFloatVectorValues extends FloatVectorValues { + private final List floats; + private final int dimension; + private final DocIdSet docIdSet; + + public BufferedFloatVectorValues(List floats, int dimension, DocIdSet docIdSet) { + this.floats = floats; + this.dimension = dimension; + this.docIdSet = docIdSet; + } + + @Override + public float[] vectorValue(int ord) { + return floats.get(ord); + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return floats.size(); + } + + @Override + public FloatVectorValues copy() { + return new BufferedFloatVectorValues(floats, dimension, docIdSet); + } + + @Override + public DocIndexIterator iterator() { + return fromDISI(docIdSet.iterator()); + } + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/LibFaissC.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/LibFaissC.java new file mode 100644 index 000000000000..c521c4c20108 --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/LibFaissC.java @@ -0,0 +1,548 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static java.lang.foreign.ValueLayout.ADDRESS; +import static java.lang.foreign.ValueLayout.JAVA_FLOAT; +import static java.lang.foreign.ValueLayout.JAVA_INT; +import static java.lang.foreign.ValueLayout.JAVA_LONG; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.FunctionDescriptor; +import java.lang.foreign.Linker; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SymbolLookup; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.nio.LongBuffer; +import java.util.Arrays; +import java.util.Locale; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.hnsw.IntToIntFunction; + +/** + * Utility class to wrap necessary functions of the native C API of Faiss + * using Project Panama. 
+ * + * @lucene.experimental + */ +final class LibFaissC { + // TODO: Use vectorized version where available + public static final String LIBRARY_NAME = "faiss_c"; + public static final String LIBRARY_VERSION = "1.11.0"; + + // See flags defined in c_api/index_io_c.h + static final int FAISS_IO_FLAG_MMAP = 1; + static final int FAISS_IO_FLAG_READ_ONLY = 2; + + private static final int BUFFER_SIZE = 256 * 1024 * 1024; // 256 MB + + static { + System.loadLibrary(LIBRARY_NAME); + checkLibraryVersion(); + } + + private LibFaissC() {} + + private static MemorySegment getUpcallStub( + Arena arena, MethodHandle target, FunctionDescriptor descriptor) { + return Linker.nativeLinker().upcallStub(target, descriptor, arena); + } + + private static MethodHandle getDowncallHandle( + String functionName, FunctionDescriptor descriptor) { + return Linker.nativeLinker() + .downcallHandle(SymbolLookup.loaderLookup().findOrThrow(functionName), descriptor); + } + + private static void checkLibraryVersion() { + MethodHandle getVersion = + getDowncallHandle("faiss_get_version", FunctionDescriptor.of(ADDRESS)); + + MemorySegment nativeString = call(getVersion); + String actualVersion = nativeString.reinterpret(Long.MAX_VALUE).getString(0); + + if (LIBRARY_VERSION.equals(actualVersion) == false) { + throw new UnsupportedOperationException( + String.format( + Locale.ROOT, + "Expected Faiss library version %s, found %s", + LIBRARY_VERSION, + actualVersion)); + } + } + + private static final MethodHandle FREE_INDEX = + getDowncallHandle("faiss_Index_free", FunctionDescriptor.ofVoid(ADDRESS)); + + public static void freeIndex(MemorySegment indexPointer) { + call(FREE_INDEX, indexPointer); + } + + private static final MethodHandle FREE_CUSTOM_IO_WRITER = + getDowncallHandle("faiss_CustomIOWriter_free", FunctionDescriptor.ofVoid(ADDRESS)); + + public static void freeCustomIOWriter(MemorySegment customIOWriterPointer) { + call(FREE_CUSTOM_IO_WRITER, customIOWriterPointer); + } + + private static 
final MethodHandle FREE_CUSTOM_IO_READER = + getDowncallHandle("faiss_CustomIOReader_free", FunctionDescriptor.ofVoid(ADDRESS)); + + public static void freeCustomIOReader(MemorySegment customIOReaderPointer) { + call(FREE_CUSTOM_IO_READER, customIOReaderPointer); + } + + private static final MethodHandle FREE_PARAMETER_SPACE = + getDowncallHandle("faiss_ParameterSpace_free", FunctionDescriptor.ofVoid(ADDRESS)); + + private static void freeParameterSpace(MemorySegment parameterSpacePointer) { + call(FREE_PARAMETER_SPACE, parameterSpacePointer); + } + + private static final MethodHandle FREE_ID_SELECTOR_BITMAP = + getDowncallHandle("faiss_IDSelectorBitmap_free", FunctionDescriptor.ofVoid(ADDRESS)); + + private static void freeIDSelectorBitmap(MemorySegment idSelectorBitmapPointer) { + call(FREE_ID_SELECTOR_BITMAP, idSelectorBitmapPointer); + } + + private static final MethodHandle FREE_SEARCH_PARAMETERS = + getDowncallHandle("faiss_SearchParameters_free", FunctionDescriptor.ofVoid(ADDRESS)); + + private static void freeSearchParameters(MemorySegment searchParametersPointer) { + call(FREE_SEARCH_PARAMETERS, searchParametersPointer); + } + + private static final MethodHandle INDEX_FACTORY = + getDowncallHandle( + "faiss_index_factory", + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT)); + + private static final MethodHandle PARAMETER_SPACE_NEW = + getDowncallHandle("faiss_ParameterSpace_new", FunctionDescriptor.of(JAVA_INT, ADDRESS)); + + private static final MethodHandle SET_INDEX_PARAMETERS = + getDowncallHandle( + "faiss_ParameterSpace_set_index_parameters", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, ADDRESS)); + + private static final MethodHandle ID_SELECTOR_BITMAP_NEW = + getDowncallHandle( + "faiss_IDSelectorBitmap_new", + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS)); + + private static final MethodHandle SEARCH_PARAMETERS_NEW = + getDowncallHandle( + "faiss_SearchParameters_new", FunctionDescriptor.of(JAVA_INT, 
ADDRESS, ADDRESS)); + + private static final MethodHandle INDEX_IS_TRAINED = + getDowncallHandle("faiss_Index_is_trained", FunctionDescriptor.of(JAVA_INT, ADDRESS)); + + private static final MethodHandle INDEX_TRAIN = + getDowncallHandle( + "faiss_Index_train", FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS)); + + private static final MethodHandle INDEX_ADD_WITH_IDS = + getDowncallHandle( + "faiss_Index_add_with_ids", + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS)); + + public static MemorySegment createIndex( + String description, + String indexParams, + VectorSimilarityFunction function, + FloatVectorValues floatVectorValues, + IntToIntFunction oldToNewDocId) + throws IOException { + + try (Arena temp = Arena.ofConfined()) { + int size = floatVectorValues.size(); + int dimension = floatVectorValues.dimension(); + + // Mapped from faiss/MetricType.h + int metric = + switch (function) { + case DOT_PRODUCT -> 0; + case EUCLIDEAN -> 1; + case COSINE, MAXIMUM_INNER_PRODUCT -> + throw new UnsupportedOperationException("Metric type not supported"); + }; + + // Create an index + MemorySegment pointer = temp.allocate(ADDRESS); + callAndHandleError(INDEX_FACTORY, pointer, dimension, temp.allocateFrom(description), metric); + MemorySegment indexPointer = pointer.get(ADDRESS, 0); + + // Set index params + callAndHandleError(PARAMETER_SPACE_NEW, pointer); + MemorySegment parameterSpacePointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, LibFaissC::freeParameterSpace); + + callAndHandleError( + SET_INDEX_PARAMETERS, + parameterSpacePointer, + indexPointer, + temp.allocateFrom(indexParams)); + + // TODO: Improve memory usage (with a tradeoff in performance) by batched indexing, see: + // - https://github.com/opensearch-project/k-NN/issues/1506 + // - https://github.com/opensearch-project/k-NN/issues/1938 + + // Allocate docs in native memory + MemorySegment docs = temp.allocate(JAVA_FLOAT, (long) size * 
dimension); + FloatBuffer docsBuffer = docs.asByteBuffer().order(ByteOrder.nativeOrder()).asFloatBuffer(); + + // Allocate ids in native memory + MemorySegment ids = temp.allocate(JAVA_LONG, size); + LongBuffer idsBuffer = ids.asByteBuffer().order(ByteOrder.nativeOrder()).asLongBuffer(); + + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); + for (int i = iterator.nextDoc(); i != NO_MORE_DOCS; i = iterator.nextDoc()) { + idsBuffer.put(oldToNewDocId.apply(i)); + docsBuffer.put(floatVectorValues.vectorValue(iterator.index())); + } + + // Train index + int isTrained = call(INDEX_IS_TRAINED, indexPointer); + if (isTrained == 0) { + callAndHandleError(INDEX_TRAIN, indexPointer, size, docs); + } + + // Add docs to index + callAndHandleError(INDEX_ADD_WITH_IDS, indexPointer, size, docs, ids); + + return indexPointer; + } + } + + @SuppressWarnings("unused") // called using a MethodHandle + private static long writeBytes( + IndexOutput output, MemorySegment inputPointer, long itemSize, long numItems) + throws IOException { + long size = itemSize * numItems; + inputPointer = inputPointer.reinterpret(size); + + if (size <= BUFFER_SIZE) { // simple case, avoid buffering + byte[] bytes = new byte[(int) size]; + inputPointer.asSlice(0, size).asByteBuffer().order(ByteOrder.nativeOrder()).get(bytes); + output.writeBytes(bytes, bytes.length); + } else { // copy buffered number of bytes repeatedly + byte[] bytes = new byte[BUFFER_SIZE]; + for (long offset = 0; offset < size; offset += BUFFER_SIZE) { + int length = (int) Math.min(size - offset, BUFFER_SIZE); + inputPointer + .asSlice(offset, length) + .asByteBuffer() + .order(ByteOrder.nativeOrder()) + .get(bytes, 0, length); + output.writeBytes(bytes, length); + } + } + return numItems; + } + + @SuppressWarnings("unused") // called using a MethodHandle + private static long readBytes( + IndexInput input, MemorySegment outputPointer, long itemSize, long numItems) + throws IOException { + long size = itemSize * 
numItems; + outputPointer = outputPointer.reinterpret(size); + + if (size <= BUFFER_SIZE) { // simple case, avoid buffering + byte[] bytes = new byte[(int) size]; + input.readBytes(bytes, 0, bytes.length); + outputPointer + .asSlice(0, bytes.length) + .asByteBuffer() + .order(ByteOrder.nativeOrder()) + .put(bytes); + } else { // copy buffered number of bytes repeatedly + byte[] bytes = new byte[BUFFER_SIZE]; + for (long offset = 0; offset < size; offset += BUFFER_SIZE) { + int length = (int) Math.min(size - offset, BUFFER_SIZE); + input.readBytes(bytes, 0, length); + outputPointer + .asSlice(offset, length) + .asByteBuffer() + .order(ByteOrder.nativeOrder()) + .put(bytes, 0, length); + } + } + return numItems; + } + + private static final MethodHandle WRITE_BYTES_HANDLE; + private static final MethodHandle READ_BYTES_HANDLE; + + static { + try { + WRITE_BYTES_HANDLE = + MethodHandles.lookup() + .findStatic( + LibFaissC.class, + "writeBytes", + MethodType.methodType( + long.class, IndexOutput.class, MemorySegment.class, long.class, long.class)); + + READ_BYTES_HANDLE = + MethodHandles.lookup() + .findStatic( + LibFaissC.class, + "readBytes", + MethodType.methodType( + long.class, IndexInput.class, MemorySegment.class, long.class, long.class)); + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + private static final MethodHandle CUSTOM_IO_WRITER_NEW = + getDowncallHandle( + "faiss_CustomIOWriter_new", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); + + private static final MethodHandle WRITE_INDEX_CUSTOM = + getDowncallHandle( + "faiss_write_index_custom", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); + + public static void indexWrite(MemorySegment indexPointer, IndexOutput output, int ioFlags) { + try (Arena temp = Arena.ofConfined()) { + MethodHandle writerHandle = WRITE_BYTES_HANDLE.bindTo(output); + MemorySegment writerStub = + getUpcallStub( + temp, writerHandle, 
FunctionDescriptor.of(JAVA_LONG, ADDRESS, JAVA_LONG, JAVA_LONG)); + + MemorySegment pointer = temp.allocate(ADDRESS); + callAndHandleError(CUSTOM_IO_WRITER_NEW, pointer, writerStub); + MemorySegment customIOWriterPointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, LibFaissC::freeCustomIOWriter); + + callAndHandleError(WRITE_INDEX_CUSTOM, indexPointer, customIOWriterPointer, ioFlags); + } + } + + private static final MethodHandle CUSTOM_IO_READER_NEW = + getDowncallHandle( + "faiss_CustomIOReader_new", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); + + private static final MethodHandle READ_INDEX_CUSTOM = + getDowncallHandle( + "faiss_read_index_custom", FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS)); + + public static MemorySegment indexRead(IndexInput input, int ioFlags) { + try (Arena temp = Arena.ofConfined()) { + MethodHandle readerHandle = READ_BYTES_HANDLE.bindTo(input); + MemorySegment readerStub = + getUpcallStub( + temp, readerHandle, FunctionDescriptor.of(JAVA_LONG, ADDRESS, JAVA_LONG, JAVA_LONG)); + + MemorySegment pointer = temp.allocate(ADDRESS); + callAndHandleError(CUSTOM_IO_READER_NEW, pointer, readerStub); + MemorySegment customIOReaderPointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, LibFaissC::freeCustomIOReader); + + callAndHandleError(READ_INDEX_CUSTOM, customIOReaderPointer, ioFlags, pointer); + return pointer.get(ADDRESS, 0); + } + } + + private static final MethodHandle INDEX_SEARCH = + getDowncallHandle( + "faiss_Index_search", + FunctionDescriptor.of( + JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS)); + + private static final MethodHandle INDEX_SEARCH_WITH_PARAMS = + getDowncallHandle( + "faiss_Index_search_with_params", + FunctionDescriptor.of( + JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS, ADDRESS)); + + public static void indexSearch( + MemorySegment indexPointer, + VectorSimilarityFunction function, + 
float[] query, + KnnCollector knnCollector, + Bits acceptDocs) { + + try (Arena temp = Arena.ofConfined()) { + FixedBitSet fixedBitSet = + switch (acceptDocs) { + case null -> null; + case FixedBitSet bitSet -> bitSet; + // TODO: Add optimized case for SparseFixedBitSet + case Bits bits -> FixedBitSet.copyOf(bits); + }; + + // Allocate queries in native memory + MemorySegment queries = temp.allocate(JAVA_FLOAT, query.length); + queries.asByteBuffer().order(ByteOrder.nativeOrder()).asFloatBuffer().put(query); + + // Faiss knn search + int k = knnCollector.k(); + MemorySegment distancesPointer = temp.allocate(JAVA_FLOAT, k); + MemorySegment idsPointer = temp.allocate(JAVA_LONG, k); + + MemorySegment localIndex = indexPointer.reinterpret(temp, null); + if (fixedBitSet == null) { + // Search without runtime filters + callAndHandleError(INDEX_SEARCH, localIndex, 1, queries, k, distancesPointer, idsPointer); + } else { + MemorySegment pointer = temp.allocate(ADDRESS); + + long[] bits = fixedBitSet.getBits(); + MemorySegment nativeBits = temp.allocate(JAVA_LONG, bits.length); + + // Use LITTLE_ENDIAN to convert long[] -> uint8_t* + nativeBits.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asLongBuffer().put(bits); + + callAndHandleError(ID_SELECTOR_BITMAP_NEW, pointer, fixedBitSet.length(), nativeBits); + MemorySegment idSelectorBitmapPointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, LibFaissC::freeIDSelectorBitmap); + + callAndHandleError(SEARCH_PARAMETERS_NEW, pointer, idSelectorBitmapPointer); + MemorySegment searchParametersPointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, LibFaissC::freeSearchParameters); + + // Search with runtime filters + callAndHandleError( + INDEX_SEARCH_WITH_PARAMS, + localIndex, + 1, + queries, + k, + searchParametersPointer, + distancesPointer, + idsPointer); + } + + // Retrieve scores + float[] distances = new float[k]; + 
distancesPointer.asByteBuffer().order(ByteOrder.nativeOrder()).asFloatBuffer().get(distances); + + // Retrieve ids + long[] ids = new long[k]; + idsPointer.asByteBuffer().order(ByteOrder.nativeOrder()).asLongBuffer().get(ids); + + // Record hits + for (int i = 0; i < k; i++) { + // Not enough results + if (ids[i] == -1) { + break; + } + + // Scale Faiss distances to Lucene scores, see VectorSimilarityFunction.java + float score = + switch (function) { + case DOT_PRODUCT -> + // distance in Faiss === dotProduct in Lucene + Math.max((1 + distances[i]) / 2, 0); + + case EUCLIDEAN -> + // distance in Faiss === squareDistance in Lucene + 1 / (1 + distances[i]); + + case COSINE, MAXIMUM_INNER_PRODUCT -> + throw new UnsupportedOperationException("Metric type not supported"); + }; + + knnCollector.collect((int) ids[i], score); + } + } + } + + @SuppressWarnings("unchecked") + private static T call(MethodHandle handle, Object... args) { + try { + return (T) handle.invokeWithArguments(args); + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + + private static void callAndHandleError(MethodHandle handle, Object... args) { + int returnCode = call(handle, args); + if (returnCode < 0) { + // TODO: Surface actual exception in a thread-safe manner? + throw new FaissException(returnCode); + } + } + + /** + * Exception used to rethrow handled Faiss errors in native code. 
+ * + * @lucene.experimental + */ + public static class FaissException extends RuntimeException { + // See error codes defined in c_api/error_c.h + enum ErrorCode { + /// No error + OK(0), + /// Any exception other than Faiss or standard C++ library exceptions + UNKNOWN_EXCEPT(-1), + /// Faiss library exception + FAISS_EXCEPT(-2), + /// Standard C++ library exception + STD_EXCEPT(-4); + + private final int code; + + ErrorCode(int code) { + this.code = code; + } + + static ErrorCode fromCode(int code) { + return Arrays.stream(ErrorCode.values()) + .filter(errorCode -> errorCode.code == code) + .findFirst() + .orElseThrow(); + } + } + + public FaissException(int code) { + super(String.format(Locale.ROOT, "Faiss library ran into %s", ErrorCode.fromCode(code))); + } + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/package-info.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/package-info.java new file mode 100644 index 000000000000..e63fa3070f96 --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/package-info.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +/** + * Faiss is "a library for efficient + * similarity search and clustering of dense vectors", with support for various vector + * transforms, indexing algorithms, quantization techniques, etc. This package provides a pluggable + * Faiss-based format to perform vector searches in Lucene, via {@link + * org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat}. + * + *

To use this format: Install pytorch/faiss-cpu v{@value + * org.apache.lucene.sandbox.codecs.faiss.LibFaissC#LIBRARY_VERSION} from Conda and place the shared libraries (including + * dependencies) in a directory listed in the {@code $LD_LIBRARY_PATH} environment variable or the {@code -Djava.library.path} + * JVM argument. + * + *

Important: Ensure that the licenses of the Conda distribution and its channels are acceptable for + * your use. pytorch and conda-forge are community-maintained channels with + * permissive licenses. + * + *

Sample setup: + * + *

+ * + * @lucene.experimental + */ +package org.apache.lucene.sandbox.codecs.faiss; diff --git a/lucene/sandbox/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/lucene/sandbox/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat new file mode 100644 index 000000000000..29a44d2ecfa8 --- /dev/null +++ b/lucene/sandbox/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java new file mode 100644 index 000000000000..4239e3d0b3b1 --- /dev/null +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static org.apache.lucene.index.VectorEncoding.FLOAT32; +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; + +import java.io.IOException; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.junit.BeforeClass; +import org.junit.Ignore; + +/** + * Tests for {@link FaissKnnVectorsFormat}. 
Will run only if required shared libraries (including + * dependencies) are present at runtime, or the {@value #FAISS_RUN_TESTS} JVM arg is set to {@code + * true} + */ +public class TestFaissKnnVectorsFormat extends BaseKnnVectorsFormatTestCase { + private static final String FAISS_RUN_TESTS = "tests.faiss.run"; + + private static final VectorEncoding[] SUPPORTED_ENCODINGS = {FLOAT32}; + private static final VectorSimilarityFunction[] SUPPORTED_FUNCTIONS = {DOT_PRODUCT, EUCLIDEAN}; + + @BeforeClass + public static void maybeSuppress() throws ClassNotFoundException { + // Explicitly run tests + if (Boolean.getBoolean(FAISS_RUN_TESTS)) { + return; + } + + // Otherwise check if dependencies are present + boolean faissLibraryPresent; + try { + Class.forName("org.apache.lucene.sandbox.codecs.faiss.LibFaissC"); + faissLibraryPresent = true; + } catch (UnsatisfiedLinkError _) { + faissLibraryPresent = false; + } + assumeTrue("Native libraries present", faissLibraryPresent); + } + + @Override + protected VectorEncoding randomVectorEncoding() { + return SUPPORTED_ENCODINGS[random().nextInt(SUPPORTED_ENCODINGS.length)]; + } + + @Override + protected VectorSimilarityFunction randomSimilarity() { + return SUPPORTED_FUNCTIONS[random().nextInt(SUPPORTED_FUNCTIONS.length)]; + } + + @Override + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(new FaissKnnVectorsFormat()); + } + + @Override + public void testRecall() throws IOException { + // only supports some functions + for (VectorSimilarityFunction similarity : SUPPORTED_FUNCTIONS) { + assertRecall(similarity, 0.5, 1.0); + } + } + + @Override + @Ignore // does not honour visitedLimit + public void testSearchWithVisitedLimit() {} + + @Override + @Ignore // does not support byte vectors + public void testByteVectorScorerIteration() {} + + @Override + @Ignore // does not support byte vectors + public void testMismatchedFields() {} + + @Override + @Ignore // does not support byte vectors + public void 
testSortedIndexBytes() {} + + @Override + @Ignore // does not support byte vectors + public void testRandomBytes() {} + + @Override + @Ignore // does not support byte vectors + public void testEmptyByteVectorData() {} + + @Override + @Ignore // does not support byte vectors + public void testMergingWithDifferentByteKnnFields() {} +}