diff --git a/ci/build_java.sh b/ci/build_java.sh
index 1338da8..43536e2 100755
--- a/ci/build_java.sh
+++ b/ci/build_java.sh
@@ -32,6 +32,27 @@ set -u
rapids-print-env
+# Locates the libcuvs.so file and prepends its directory to LD_LIBRARY_PATH
+rapids-logger "Find libcuvs.so and prepend its directory to LD_LIBRARY_PATH"
+
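+# cuvs-java loads libcuvs at runtime, so its directory needs to be on
+# LD_LIBRARY_PATH for the Java build and tests to find it.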
+CONDA_PKG_CACHE_DIR="/opt/conda/pkgs" # default package cache reported by `conda info`; assumed stable
+if [ -d "$CONDA_PKG_CACHE_DIR" ]; then
+ echo "==> Directory '$CONDA_PKG_CACHE_DIR' exists."
+ LIBCUVS_SO_FILE="libcuvs.so"
+ # Use the first match in case several cached package versions contain the library.
+ LIBCUVS_PATH=$(find "$CONDA_PKG_CACHE_DIR" -name "$LIBCUVS_SO_FILE" -type f | head -n 1)
+ if [ -z "$LIBCUVS_PATH" ]; then
+ echo "==> Could not find the so file. Not updating LD_LIBRARY_PATH"
+ exit 1
+ else
+ LIBCUVS_DIR=$(dirname "$LIBCUVS_PATH")
+ export LD_LIBRARY_PATH="$LIBCUVS_DIR:$LD_LIBRARY_PATH"
+ echo "LD_LIBRARY_PATH is: $LD_LIBRARY_PATH"
+ fi
+else
+ echo "==> Directory '$CONDA_PKG_CACHE_DIR' does not exist. Not updating LD_LIBRARY_PATH"
+ exit 1
+fi
+
rapids-logger "Run Java build"
bash ./build.sh "${EXTRA_BUILD_ARGS[@]}"
diff --git a/conda/environments/all_cuda-129_arch-aarch64.yaml b/conda/environments/all_cuda-129_arch-aarch64.yaml
index dd3018c..0b31020 100644
--- a/conda/environments/all_cuda-129_arch-aarch64.yaml
+++ b/conda/environments/all_cuda-129_arch-aarch64.yaml
@@ -2,6 +2,7 @@
# To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`.
channels:
- conda-forge
+- rapidsai-nightly
dependencies:
- cuda-cudart-dev
- cuda-nvtx-dev
@@ -11,6 +12,7 @@ dependencies:
- libcurand-dev
- libcusolver-dev
- libcusparse-dev
+- libcuvs
- maven
- openjdk=22.*
name: all_cuda-129_arch-aarch64
diff --git a/conda/environments/all_cuda-129_arch-x86_64.yaml b/conda/environments/all_cuda-129_arch-x86_64.yaml
index 8d06710..a8f0444 100644
--- a/conda/environments/all_cuda-129_arch-x86_64.yaml
+++ b/conda/environments/all_cuda-129_arch-x86_64.yaml
@@ -2,6 +2,7 @@
# To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`.
channels:
- conda-forge
+- rapidsai-nightly
dependencies:
- cuda-cudart-dev
- cuda-nvtx-dev
@@ -11,6 +12,7 @@ dependencies:
- libcurand-dev
- libcusolver-dev
- libcusparse-dev
+- libcuvs
- maven
- openjdk=22.*
name: all_cuda-129_arch-x86_64
diff --git a/conda/environments/all_cuda-130_arch-aarch64.yaml b/conda/environments/all_cuda-130_arch-aarch64.yaml
index 77eaa69..50e93bd 100644
--- a/conda/environments/all_cuda-130_arch-aarch64.yaml
+++ b/conda/environments/all_cuda-130_arch-aarch64.yaml
@@ -2,6 +2,7 @@
# To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`.
channels:
- conda-forge
+- rapidsai-nightly
dependencies:
- cuda-cudart-dev
- cuda-nvtx-dev
@@ -11,6 +12,7 @@ dependencies:
- libcurand-dev
- libcusolver-dev
- libcusparse-dev
+- libcuvs
- maven
- openjdk=22.*
name: all_cuda-130_arch-aarch64
diff --git a/conda/environments/all_cuda-130_arch-x86_64.yaml b/conda/environments/all_cuda-130_arch-x86_64.yaml
index ee8a720..c1e7ab3 100644
--- a/conda/environments/all_cuda-130_arch-x86_64.yaml
+++ b/conda/environments/all_cuda-130_arch-x86_64.yaml
@@ -2,6 +2,7 @@
# To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`.
channels:
- conda-forge
+- rapidsai-nightly
dependencies:
- cuda-cudart-dev
- cuda-nvtx-dev
@@ -11,6 +12,7 @@ dependencies:
- libcurand-dev
- libcusolver-dev
- libcusparse-dev
+- libcuvs
- maven
- openjdk=22.*
name: all_cuda-130_arch-x86_64
diff --git a/dependencies.yaml b/dependencies.yaml
index 4c6cec0..ee06135 100644
--- a/dependencies.yaml
+++ b/dependencies.yaml
@@ -21,6 +21,7 @@ files:
- java
channels:
- conda-forge
+ - rapidsai-nightly
dependencies:
checks:
common:
@@ -66,6 +67,7 @@ dependencies:
- libcurand-dev
- libcusolver-dev
- libcusparse-dev
+ - libcuvs
java:
common:
- output_types: conda
diff --git a/pom.xml b/pom.xml
index 0737819..6400cc4 100644
--- a/pom.xml
+++ b/pom.xml
@@ -69,7 +69,7 @@
 <groupId>com.nvidia.cuvs</groupId>
 <artifactId>cuvs-java</artifactId>
- <version>25.8.0-4f53f-SNAPSHOT</version>
+ <version>25.10.0-2c0e1-SNAPSHOT</version>
diff --git a/src/main/java/com/nvidia/cuvs/lucene/CuVSIndex.java b/src/main/java/com/nvidia/cuvs/lucene/CuVSIndex.java
index 78c5dc1..1151980 100644
--- a/src/main/java/com/nvidia/cuvs/lucene/CuVSIndex.java
+++ b/src/main/java/com/nvidia/cuvs/lucene/CuVSIndex.java
@@ -15,8 +15,6 @@
*/
package com.nvidia.cuvs.lucene;
-import static com.nvidia.cuvs.lucene.CuVSVectorsReader.handleThrowable;
-
import com.nvidia.cuvs.BruteForceIndex;
import com.nvidia.cuvs.CagraIndex;
import com.nvidia.cuvs.HnswIndex;
@@ -103,16 +101,16 @@ public void close() throws IOException {
private void destroyIndices() throws IOException {
try {
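+ // close() replaces the older destroyIndex() call in the cuvs-java API.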
if (cagraIndex != null) {
- cagraIndex.destroyIndex();
+ cagraIndex.close();
}
if (bruteforceIndex != null) {
- bruteforceIndex.destroyIndex();
+ bruteforceIndex.close();
}
if (hnswIndex != null) {
- hnswIndex.destroyIndex();
+ hnswIndex.close();
}
} catch (Throwable t) {
- handleThrowable(t);
+ Utils.handleThrowable(t);
}
}
}
diff --git a/src/main/java/com/nvidia/cuvs/lucene/CuVSKnnFloatVectorQuery.java b/src/main/java/com/nvidia/cuvs/lucene/CuVSKnnFloatVectorQuery.java
index fa1e71d..8caf30a 100644
--- a/src/main/java/com/nvidia/cuvs/lucene/CuVSKnnFloatVectorQuery.java
+++ b/src/main/java/com/nvidia/cuvs/lucene/CuVSKnnFloatVectorQuery.java
@@ -19,6 +19,7 @@
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.KnnFloatVectorQuery;
+import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.Bits;
@@ -29,8 +30,9 @@ public class CuVSKnnFloatVectorQuery extends KnnFloatVectorQuery {
private final int iTopK;
private final int searchWidth;
- public CuVSKnnFloatVectorQuery(String field, float[] target, int k, int iTopK, int searchWidth) {
- super(field, target, k);
+ public CuVSKnnFloatVectorQuery(
+ String field, float[] target, int k, Query filter, int iTopK, int searchWidth) {
+ super(field, target, k, filter);
this.iTopK = iTopK;
this.searchWidth = searchWidth;
}
@@ -46,7 +48,7 @@ protected TopDocs approximateSearch(
PerLeafCuVSKnnCollector results = new PerLeafCuVSKnnCollector(k, iTopK, searchWidth);
LeafReader reader = context.reader();
- reader.searchNearestVectors(field, this.getTargetCopy(), results, null);
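+ // acceptDocs carries the filter matches intersected with live docs, so
+ // deleted and filtered-out documents are excluded from the CuVS search.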
+ reader.searchNearestVectors(field, this.getTargetCopy(), results, acceptDocs);
return results.topDocs();
}
}
diff --git a/src/main/java/com/nvidia/cuvs/lucene/CuVSVectorsFormat.java b/src/main/java/com/nvidia/cuvs/lucene/CuVSVectorsFormat.java
index f82198b..1f7e439 100644
--- a/src/main/java/com/nvidia/cuvs/lucene/CuVSVectorsFormat.java
+++ b/src/main/java/com/nvidia/cuvs/lucene/CuVSVectorsFormat.java
@@ -86,6 +86,12 @@ public CuVSVectorsFormat(
}
private static CuVSResources cuVSResourcesOrNull() {
+ try {
+ // nocommit: temporary workaround so CI passes; loading cudart should move
+ // into cuvs-java itself.
+ System.loadLibrary("cudart");
+ } catch (UnsatisfiedLinkError e) {
+ LOG.warning("Could not load CUDA runtime library: " + e.getMessage());
+ }
try {
resources = CuVSResources.create();
return resources;
diff --git a/src/main/java/com/nvidia/cuvs/lucene/CuVSVectorsReader.java b/src/main/java/com/nvidia/cuvs/lucene/CuVSVectorsReader.java
index c770015..4118a0a 100644
--- a/src/main/java/com/nvidia/cuvs/lucene/CuVSVectorsReader.java
+++ b/src/main/java/com/nvidia/cuvs/lucene/CuVSVectorsReader.java
@@ -271,7 +271,7 @@ private CuVSIndex loadCuVSIndex(FieldEntry fieldEntry) throws IOException {
}
}
} catch (Throwable t) {
- handleThrowable(t);
+ Utils.handleThrowable(t);
}
return new CuVSIndex(cagraIndex, bruteForceIndex, hnswIndex);
}
@@ -367,7 +367,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
try {
searchResult = cagraIndex.search(query).getResults();
} catch (Throwable t) {
- handleThrowable(t);
+ Utils.handleThrowable(t);
}
// List expected to have only one entry because of single query "target".
assert searchResult.size() == 1;
@@ -385,7 +385,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
try {
searchResult = bruteforceIndex.search(query).getResults();
} catch (Throwable t) {
- handleThrowable(t);
+ Utils.handleThrowable(t);
}
assert searchResult.size() == 1;
result = searchResult.getFirst();
@@ -472,12 +472,15 @@ static void checkVersion(int versionMeta, int versionVectorData, IndexInput in)
}
}
- static void handleThrowable(Throwable t) throws IOException {
- switch (t) {
- case IOException ioe -> throw ioe;
- case Error error -> throw error;
- case RuntimeException re -> throw re;
- case null, default -> throw new RuntimeException("UNEXPECTED: exception type", t);
- }
+ public FieldInfos getFieldInfos() {
+ return fieldInfos;
+ }
+
+ public IntObjectHashMap<CuVSIndex> getCuvsIndexes() {
+ return cuvsIndices;
+ }
+
+ public IntObjectHashMap<FieldEntry> getFieldEntries() {
+ return fields;
}
}
diff --git a/src/main/java/com/nvidia/cuvs/lucene/CuVSVectorsWriter.java b/src/main/java/com/nvidia/cuvs/lucene/CuVSVectorsWriter.java
index 3e22b4d..7861ad1 100644
--- a/src/main/java/com/nvidia/cuvs/lucene/CuVSVectorsWriter.java
+++ b/src/main/java/com/nvidia/cuvs/lucene/CuVSVectorsWriter.java
@@ -20,7 +20,6 @@
import static com.nvidia.cuvs.lucene.CuVSVectorsFormat.CUVS_META_CODEC_EXT;
import static com.nvidia.cuvs.lucene.CuVSVectorsFormat.CUVS_META_CODEC_NAME;
import static com.nvidia.cuvs.lucene.CuVSVectorsFormat.VERSION_CURRENT;
-import static com.nvidia.cuvs.lucene.CuVSVectorsReader.handleThrowable;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
import static org.apache.lucene.index.VectorEncoding.FLOAT32;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
@@ -41,14 +40,18 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
+import java.util.function.Supplier;
import java.util.logging.Logger;
+import java.util.stream.IntStream;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
+import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnVectorValues;
@@ -57,6 +60,7 @@
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.Sorter.DocMap;
import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream;
@@ -231,7 +235,7 @@ private void writeCagraIndex(OutputStream os, CuVSMatrix dataset) throws Throwab
info("Cagra index created in " + elapsedMillis + "ms, with " + dataset.size() + " vectors");
Path tmpFile = Files.createTempFile(resources.tempDirectory(), "tmpindex", "cag");
index.serialize(os, tmpFile);
- index.destroyIndex();
+ index.close();
}
private void writeBruteForceIndex(OutputStream os, CuVSMatrix dataset) throws Throwable {
@@ -246,7 +250,7 @@ private void writeBruteForceIndex(OutputStream os, CuVSMatrix dataset) throws Th
long elapsedMillis = nanosToMillis(System.nanoTime() - startTime);
info("bf index created in " + elapsedMillis + "ms, with " + dataset.size() + " vectors");
index.serialize(os);
- index.destroyIndex();
+ index.close();
}
private void writeHNSWIndex(OutputStream os, CuVSMatrix dataset) throws Throwable {
@@ -261,7 +265,7 @@ private void writeHNSWIndex(OutputStream os, CuVSMatrix dataset) throws Throwabl
info("HNSW index created in " + elapsedMillis + "ms, with " + dataset.size() + " vectors");
Path tmpFile = Files.createTempFile("tmpindex", "hnsw");
index.serializeToHNSW(os, tmpFile);
- index.destroyIndex();
+ index.close();
}
@Override
@@ -277,37 +281,36 @@ public void flush(int maxDoc, DocMap sortMap) throws IOException {
}
private void writeField(CuVSFieldWriter fieldData) throws IOException {
- // TODO: Argh! https://github.com/rapidsai/cuvs/issues/698
+ // TODO: Loading all vectors into memory is inefficient. Is there a way to stream the vectors
+ // from the flat writer to the CuVSMatrix?
 List<float[]> vectors = fieldData.getVectors();
- CuVSMatrix.Builder builder =
- CuVSMatrix.builder(
- vectors.size(), fieldData.fieldInfo().getVectorDimension(), CuVSMatrix.DataType.FLOAT);
- for (float[] vec : vectors) builder.addVector(vec);
- writeFieldInternal(fieldData.fieldInfo(), builder.build());
+ writeFieldInternal(
+ fieldData.fieldInfo(),
+ () -> Utils.createFloatMatrix(vectors, fieldData.fieldInfo().getVectorDimension()),
+ vectors.size());
}
private void writeSortingField(CuVSFieldWriter fieldData, Sorter.DocMap sortMap)
throws IOException {
DocsWithFieldSet oldDocsWithFieldSet = fieldData.getDocsWithFieldSet();
final int[] new2OldOrd = new int[oldDocsWithFieldSet.cardinality()]; // new ord to old ord
-
mapOldOrdToNewOrd(oldDocsWithFieldSet, sortMap, null, new2OldOrd, null);
-
- float[][] oldVectors = fieldData.getVectors().toArray(float[][]::new);
- CuVSMatrix.Builder builder =
- CuVSMatrix.builder(
- fieldData.getVectors().size(),
- fieldData.fieldInfo().getVectorDimension(),
- CuVSMatrix.DataType.FLOAT);
- for (int i = 0; i < oldVectors.length; i++) {
- float[] vec = oldVectors[new2OldOrd[i]];
- builder.addVector(vec);
+ // TODO: Loading all vectors into memory is inefficient. Is there a way to stream the vectors
+ // from the flat writer to the CuVSMatrix?
+ List<float[]> sortedVectors = new ArrayList<>();
+ for (int i = 0; i < fieldData.getVectors().size(); i++) {
+ sortedVectors.add(fieldData.getVectors().get(new2OldOrd[i]));
}
- writeFieldInternal(fieldData.fieldInfo(), builder.build());
+ writeFieldInternal(
+ fieldData.fieldInfo(),
+ () -> Utils.createFloatMatrix(sortedVectors, fieldData.fieldInfo().getVectorDimension()),
+ sortedVectors.size());
}
- private void writeFieldInternal(FieldInfo fieldInfo, CuVSMatrix dataset) throws IOException {
- if (dataset.size() == 0) {
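+ // The dataset is passed as a Supplier so a CuVSMatrix is only materialized
+ // for the index types actually written; note the supplier may be invoked
+ // once per index type below.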
+ private void writeFieldInternal(
+ FieldInfo fieldInfo, Supplier<CuVSMatrix> datasetSupplier, int datasetSize)
+ throws IOException {
+ if (datasetSize == 0) {
writeEmpty(fieldInfo);
return;
}
@@ -317,7 +320,7 @@ private void writeFieldInternal(FieldInfo fieldInfo, CuVSMatrix dataset) throws
// workaround for the minimum number of vectors for Cagra
IndexType indexType =
- this.indexType.cagra() && dataset.size() < MIN_CAGRA_INDEX_SIZE
+ this.indexType.cagra() && datasetSize < MIN_CAGRA_INDEX_SIZE
? IndexType.BRUTE_FORCE
: this.indexType;
@@ -326,7 +329,7 @@ private void writeFieldInternal(FieldInfo fieldInfo, CuVSMatrix dataset) throws
if (indexType.cagra()) {
try {
var cagraIndexOutputStream = new IndexOutputOutputStream(cuvsIndex);
- writeCagraIndex(cagraIndexOutputStream, dataset);
+ writeCagraIndex(cagraIndexOutputStream, datasetSupplier.get());
} catch (Throwable t) {
handleThrowableWithIgnore(t, CANNOT_GENERATE_CAGRA);
// workaround for cuVS issue
@@ -338,16 +341,16 @@ private void writeFieldInternal(FieldInfo fieldInfo, CuVSMatrix dataset) throws
bruteForceIndexOffset = cuvsIndex.getFilePointer();
if (indexType.bruteForce()) {
var bruteForceIndexOutputStream = new IndexOutputOutputStream(cuvsIndex);
- writeBruteForceIndex(bruteForceIndexOutputStream, dataset);
+ writeBruteForceIndex(bruteForceIndexOutputStream, datasetSupplier.get());
bruteForceIndexLength = cuvsIndex.getFilePointer() - bruteForceIndexOffset;
}
hnswIndexOffset = cuvsIndex.getFilePointer();
if (indexType.hnsw()) {
var hnswIndexOutputStream = new IndexOutputOutputStream(cuvsIndex);
- if (dataset.size() > MIN_CAGRA_INDEX_SIZE) {
+ if (datasetSize > MIN_CAGRA_INDEX_SIZE) {
try {
- writeHNSWIndex(hnswIndexOutputStream, dataset);
+ writeHNSWIndex(hnswIndexOutputStream, datasetSupplier.get());
} catch (Throwable t) {
handleThrowableWithIgnore(t, CANNOT_GENERATE_CAGRA);
}
@@ -357,7 +360,7 @@ private void writeFieldInternal(FieldInfo fieldInfo, CuVSMatrix dataset) throws
writeMeta(
fieldInfo,
- (int) dataset.size(),
+ datasetSize,
cagraIndexOffset,
cagraIndexLength,
bruteForceIndexOffset,
@@ -365,7 +368,7 @@ private void writeFieldInternal(FieldInfo fieldInfo, CuVSMatrix dataset) throws
hnswIndexOffset,
hnswIndexLength);
} catch (Throwable t) {
- handleThrowable(t);
+ Utils.handleThrowable(t);
}
}
@@ -418,46 +421,184 @@ static void handleThrowableWithIgnore(Throwable t, String msg) throws IOExceptio
if (t.getMessage().contains(msg)) {
return;
}
- handleThrowable(t);
+ Utils.handleThrowable(t);
+ }
+
+ private void mergeCagraIndexes(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+ try {
+
+ List<CagraIndex> cagraIndexes = new ArrayList<>();
+ // We need this count so that the merged segment's meta information has the vector count.
+ int totalVectorCount = 0;
+
+ for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
+ KnnVectorsReader knnReader = mergeState.knnVectorsReaders[i];
+ // Access the CAGRA index for this field from the reader
+
+ if (knnReader != null) {
+ if (knnReader instanceof CuVSVectorsReader cvr) {
+ // The pattern variable cvr is non-null whenever instanceof matches.
+ totalVectorCount += cvr.getFieldEntries().get(fieldInfo.number).count();
+ CagraIndex cagraIndex = getCagraIndexFromReader(cvr, fieldInfo.name);
+ if (cagraIndex != null) {
+ cagraIndexes.add(cagraIndex);
+ }
+ } else {
+ // Should never happen: CuVS segments are always read by CuVSVectorsReader.
+ throw new RuntimeException(
+ "Reader is not of CuVSVectorsReader type. Instead it is: " + knnReader.getClass());
+ }
+ }
+ }
+ assert cagraIndexes.size() > 1;
+
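+ // Use cuVS's native CAGRA merge API rather than rebuilding from raw vectors.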
+ CagraIndex mergedIndex =
+ CagraIndex.merge(cagraIndexes.toArray(new CagraIndex[cagraIndexes.size()]));
+ writeMergedCagraIndex(fieldInfo, mergedIndex, totalVectorCount);
+ info("Successfully merged " + cagraIndexes.size() + " CAGRA indexes using native merge API");
+
+ } catch (Throwable t) {
+ Utils.handleThrowable(t);
+ }
}
/**
- * Copies the vector values into dst. Returns the actual number of vectors
- * copied.
+ * Fallback merge path that rebuilds indexes from the merged vectors.
+ * Used when a native CAGRA merge is not possible, and for non-CAGRA
+ * index types (e.g. the brute-force index).
*/
- private static int getVectorData(FloatVectorValues floatVectorValues, CuVSMatrix.Builder builder)
- throws IOException {
- DocsWithFieldSet docsWithField = new DocsWithFieldSet();
- int count = 0;
- KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator();
+ private void vectorBasedMerge(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+ if (fieldInfo.getVectorEncoding() != FLOAT32) {
+ throw new AssertionError("Only Float32 supported");
+ }
+ try {
+ // We need to compute the size of the number of merged documents up-front so that we can
+ // compute the CuVSMatrix capacity. TODO: Find a way to do this without merging twice.
+ final int numMergedDocs = getMergedDocsCount(fieldInfo, mergeState);
+
+ if (numMergedDocs != 0) {
+ writeFieldInternal(
+ fieldInfo,
+ () -> {
+ try {
+ return createMatrixFromMergedVectors(
+ KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(
+ fieldInfo, mergeState),
+ numMergedDocs);
+ } catch (IOException e) {
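+ // Supplier.get() cannot throw checked exceptions; tunnel the
+ // IOException out through a RuntimeException.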
+ throw new RuntimeException(e);
+ }
+ },
+ numMergedDocs);
+ } else {
+ writeEmpty(fieldInfo);
+ }
+ } catch (Throwable t) {
+ Utils.handleThrowable(t);
+ }
+ }
+
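+ // Note: this exhausts a merge iterator purely to count documents, so the
+ // merged values end up being iterated twice (see TODO above).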
+ private int getMergedDocsCount(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+ KnnVectorValues.DocIndexIterator iter =
+ KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)
+ .iterator();
+ int numMergedDocs = 0;
+ for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
+ numMergedDocs++;
+ }
+ return numMergedDocs;
+ }
+
+ /**
+ * Creates a CuVSMatrix from the merged vector values.
+ */
+ private CuVSMatrix createMatrixFromMergedVectors(
+ FloatVectorValues mergedVectorValues, int numMergedDocs) throws IOException {
+ List<float[]> vectors = new ArrayList<>(numMergedDocs);
+ KnnVectorValues.DocIndexIterator iter = mergedVectorValues.iterator();
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
- assert iter.index() == count;
- builder.addVector(floatVectorValues.vectorValue(iter.index())); // is this correct?
- docsWithField.add(docV);
- count++;
+ int ordinal = iter.index();
+ float[] vector = mergedVectorValues.vectorValue(ordinal);
+ vectors.add(vector.clone());
+ }
+ return Utils.createFloatMatrix(vectors, mergedVectorValues.dimension());
+ }
+
+ /**
+ * Extracts the CAGRA index for a specific field from a CuVSVectorsReader.
+ */
+ private CagraIndex getCagraIndexFromReader(CuVSVectorsReader reader, String fieldName) {
+ try {
+ IntObjectHashMap<CuVSIndex> cuvsIndices = reader.getCuvsIndexes();
+ FieldInfos fieldInfos = reader.getFieldInfos();
+
+ FieldInfo fieldInfo = fieldInfos.fieldInfo(fieldName);
+
+ if (fieldInfo != null) {
+ CuVSIndex cuvsIndex = cuvsIndices.get(fieldInfo.number);
+ if (cuvsIndex != null) {
+ return cuvsIndex.getCagraIndex();
+ }
+ }
+ } catch (Exception e) {
+ info("Failed to extract CAGRA index for field " + fieldName + ": " + e.getMessage());
+ }
+ return null;
+ }
+
+ /**
+ * Writes a pre-built merged CAGRA index to the output.
+ */
+ private void writeMergedCagraIndex(FieldInfo fieldInfo, CagraIndex mergedIndex, int vectorCount)
+ throws IOException {
+ try {
+ long cagraIndexOffset = cuvsIndex.getFilePointer();
+ var cagraIndexOutputStream = new IndexOutputOutputStream(cuvsIndex);
+
+ // Serialize the merged index
+ Path tmpFile = Files.createTempFile(resources.tempDirectory(), "mergedindex", "cag");
+ mergedIndex.serialize(cagraIndexOutputStream, tmpFile);
+ long cagraIndexLength = cuvsIndex.getFilePointer() - cagraIndexOffset;
+
+ // Write metadata (assuming no brute force or HNSW indexes for merged result)
+ writeMeta(fieldInfo, vectorCount, cagraIndexOffset, cagraIndexLength, 0L, 0L, 0L, 0L);
+
+ // Clean up the merged index
+ mergedIndex.close();
+ } catch (Throwable t) {
+ Utils.handleThrowable(t);
}
- return docsWithField.cardinality();
}
@Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
flatVectorsWriter.mergeOneField(fieldInfo, mergeState);
- try {
- final FloatVectorValues mergedVectorValues =
- switch (fieldInfo.getVectorEncoding()) {
- case BYTE -> throw new AssertionError("bytes not supported");
- case FLOAT32 ->
- KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
- };
-
- // Also will be replaced with the cuVS merge api
- CuVSMatrix.Builder builder =
- CuVSMatrix.builder(
- mergedVectorValues.size(), mergedVectorValues.dimension(), CuVSMatrix.DataType.FLOAT);
- getVectorData(mergedVectorValues, builder);
- writeFieldInternal(fieldInfo, builder.build());
- } catch (Throwable t) {
- handleThrowable(t);
+
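+ // The flat (raw) vectors were already merged above by flatVectorsWriter;
+ // the code below only rebuilds or natively merges the GPU-side indexes.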
+ if (indexType.cagra() && !indexType.bruteForce()) {
+ // Since CAGRA merge does not support merging of indexes with purging of deletes,
+ // we fallback to vector-based re-indexing. Issue:
+ // https://github.com/rapidsai/cuvs/issues/1253
+ // A null liveDocs entry means that segment has no deletions.
+ boolean hasDeletions =
+ IntStream.range(0, mergeState.liveDocs.length)
+ .anyMatch(
+ i ->
+ mergeState.liveDocs[i] != null
+ && IntStream.range(0, mergeState.maxDocs[i])
+ .anyMatch(j -> !mergeState.liveDocs[i].get(j)));
+
+ if (mergeState.knnVectorsReaders.length > 1 && !hasDeletions) {
+ mergeCagraIndexes(fieldInfo, mergeState);
+ } else {
+ // CAGRA's merge API does not handle the trivial case of merging 1 index.
+ vectorBasedMerge(fieldInfo, mergeState);
+ }
+
+ } else {
+ // If a brute-force index is requested, re-index from the raw vectors even
+ // when a CAGRA index exists.
+ vectorBasedMerge(fieldInfo, mergeState);
}
}
diff --git a/src/main/java/com/nvidia/cuvs/lucene/FilterCuVSProvider.java b/src/main/java/com/nvidia/cuvs/lucene/FilterCuVSProvider.java
index 5c3f7d1..05acbee 100644
--- a/src/main/java/com/nvidia/cuvs/lucene/FilterCuVSProvider.java
+++ b/src/main/java/com/nvidia/cuvs/lucene/FilterCuVSProvider.java
@@ -68,8 +68,25 @@ public CagraIndex mergeCagraIndexes(CagraIndex[] arg0) throws Throwable {
}
@Override
- public Builder newMatrixBuilder(int size, int dimensions, DataType dataType) {
- return delegate.newMatrixBuilder(size, dimensions, dataType);
+ public com.nvidia.cuvs.GPUInfoProvider gpuInfoProvider() {
+ return delegate.gpuInfoProvider();
+ }
+
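+ // The provider API now distinguishes host- and device-resident matrix
+ // builders; each override below simply delegates.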
+ @Override
+ public Builder newHostMatrixBuilder(long rows, long cols, DataType dataType) {
+ return delegate.newHostMatrixBuilder(rows, cols, dataType);
+ }
+
+ @Override
+ public Builder newDeviceMatrixBuilder(
+ CuVSResources resources, long rows, long cols, DataType dataType) {
+ return delegate.newDeviceMatrixBuilder(resources, rows, cols, dataType);
+ }
+
+ @Override
+ public Builder newDeviceMatrixBuilder(
+ CuVSResources resources, long rows, long cols, int maxRows, int maxCols, DataType dataType) {
+ return delegate.newDeviceMatrixBuilder(resources, rows, cols, maxRows, maxCols, dataType);
}
@Override
diff --git a/src/main/java/com/nvidia/cuvs/lucene/Utils.java b/src/main/java/com/nvidia/cuvs/lucene/Utils.java
new file mode 100644
index 0000000..8d1d4bd
--- /dev/null
+++ b/src/main/java/com/nvidia/cuvs/lucene/Utils.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.cuvs.lucene;
+
+import com.nvidia.cuvs.CuVSMatrix;
+import java.io.IOException;
+import java.util.List;
+
+public class Utils {
+
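+ /** Rethrows {@code t} as IOException, Error, or RuntimeException; anything else is wrapped in a RuntimeException. */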
+ static void handleThrowable(Throwable t) throws IOException {
+ switch (t) {
+ case IOException ioe -> throw ioe;
+ case Error error -> throw error;
+ case RuntimeException re -> throw re;
+ case null, default -> throw new RuntimeException("UNEXPECTED: exception type", t);
+ }
+ }
+
+ /**
+ * Builds a {@link CuVSMatrix} from a list of float vectors.
+ *
+ * Note: this can be memory-intensive, since all vectors are held on the heap at once.
+ * Prefer a {@link CuVSMatrix.Builder}, which copies vectors without keeping them all in heap memory.
+ *
+ * @param data The float vectors
+ * @param dimensions The number of float elements in each vector
+ * @return an instance of {@link CuVSMatrix}
+ */
+ static CuVSMatrix createFloatMatrix(List<float[]> data, int dimensions) {
+ // Convert the List<float[]> to float[][] for the ofArray method
+ float[][] vectors = data.toArray(new float[0][]);
+ return CuVSMatrix.ofArray(vectors);
+ }
+}
diff --git a/src/test/java/com/nvidia/cuvs/lucene/TestCuVSDeletedDocuments.java b/src/test/java/com/nvidia/cuvs/lucene/TestCuVSDeletedDocuments.java
new file mode 100644
index 0000000..ef17136
--- /dev/null
+++ b/src/test/java/com/nvidia/cuvs/lucene/TestCuVSDeletedDocuments.java
@@ -0,0 +1,342 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.cuvs.lucene;
+
+import static com.nvidia.cuvs.lucene.TestUtils.generateDataset;
+import static com.nvidia.cuvs.lucene.TestUtils.generateRandomVector;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import java.util.logging.Logger;
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.document.StringField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.KnnFloatVectorQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.analysis.MockAnalyzer;
+import org.apache.lucene.tests.analysis.MockTokenizer;
+import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.tests.util.LuceneTestCase.SuppressSysoutChecks;
+import org.apache.lucene.tests.util.TestUtil;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+@SuppressSysoutChecks(bugUrl = "prints info from within cuvs")
+public class TestCuVSDeletedDocuments extends LuceneTestCase {
+
+ protected static Logger log = Logger.getLogger(TestCuVSDeletedDocuments.class.getName());
+
+ static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new CuVSVectorsFormat());
+ private static Random random;
+
+ @BeforeClass
+ public static void beforeClass() throws Exception {
+ assumeTrue("cuvs not supported", CuVSVectorsFormat.supported());
+ random = random();
+ }
+
+ @Test
+ public void testVectorSearchWithDeletedDocuments() throws IOException {
+
+ try (Directory directory = newDirectory()) {
+ int datasetSize = random.nextInt(200, 1000); // 200-999 documents
+ int dimensions = random.nextInt(64, 256); // 64-255 dimensions
+ int topK = Math.min(random.nextInt(20) + 5, datasetSize / 2); // 5-24 results
+ float deletionProbability = random.nextFloat() * 0.4f + 0.1f; // 10-50% deletion rate
+
+ float[][] dataset = generateDataset(random, datasetSize, dimensions);
+ Set<Integer> deletedDocs = new HashSet<>();
+
+ // Create index with all documents having vectors
+ try (RandomIndexWriter writer = createWriter(directory)) {
+ for (int i = 0; i < datasetSize; i++) {
+ Document doc = new Document();
+ doc.add(new StringField("id", String.valueOf(i), Field.Store.YES));
+ doc.add(
+ new KnnFloatVectorField("vector", dataset[i], VectorSimilarityFunction.EUCLIDEAN));
+ writer.addDocument(doc);
+ }
+
+ // Delete documents randomly based on probability
+ for (int i = 0; i < datasetSize; i++) {
+ if (random.nextFloat() < deletionProbability) {
+ writer.deleteDocuments(new Term("id", String.valueOf(i)));
+ deletedDocs.add(i);
+ }
+ }
+ writer.commit();
+ }
+
+ // Search and verify deleted documents are not returned
+ try (DirectoryReader reader = DirectoryReader.open(directory)) {
+ IndexSearcher searcher = newSearcher(reader);
+ // Use a random vector for query
+ float[] queryVector = generateRandomVector(dimensions, random);
+
+ Query query = new KnnFloatVectorQuery("vector", queryVector, topK);
+ ScoreDoc[] hits = searcher.search(query, topK).scoreDocs;
+
+ // Verify we got results
+ assertTrue("Should have search results", hits.length > 0);
+
+ // Verify no deleted documents in results
+ for (ScoreDoc hit : hits) {
+ String docId = reader.storedFields().document(hit.doc).get("id");
+ int id = Integer.parseInt(docId);
+ assertFalse(
+ "Deleted document " + id + " should not appear in results", deletedDocs.contains(id));
+ log.info("Found non-deleted document: " + id + ", Score: " + hit.score);
+ }
+
+ // Verify deleted documents are truly deleted
+ for (int deletedId : deletedDocs) {
+ TopDocs result =
+ searcher.search(new TermQuery(new Term("id", String.valueOf(deletedId))), 1);
+ assertEquals(
+ "Deleted document " + deletedId + " should not be found",
+ 0,
+ result.totalHits.value());
+ }
+ }
+ }
+ }
+
+ @Test
+ public void testVectorSearchWithMixedDeletedAndMissingVectors() throws IOException {
+
+ try (Directory directory = newDirectory()) {
+ int datasetSize = random.nextInt(200) + 50; // 50-249 documents
+ int dimensions = random.nextInt(256) + 64; // 64-319 dimensions
+ int topK = Math.min(random.nextInt(20) + 5, datasetSize / 2); // 5-24 results
+ float vectorProbability = random.nextFloat() * 0.5f + 0.3f; // 30-80% have vectors
+ float deletionProbability = random.nextFloat() * 0.3f + 0.1f; // 10-40% deletion rate
+
+ float[][] dataset = generateDataset(random, datasetSize, dimensions);
+ Set<Integer> docsWithoutVectors = new HashSet<>();
+ Set<Integer> deletedDocs = new HashSet<>();
+
+ // Create index with mixed documents
+ try (RandomIndexWriter writer = createWriter(directory)) {
+ for (int i = 0; i < datasetSize; i++) {
+ Document doc = new Document();
+ doc.add(new StringField("id", String.valueOf(i), Field.Store.YES));
+ // Randomly assign categories
+ String category = random.nextBoolean() ? "A" : "B";
+ doc.add(new StringField("category", category, Field.Store.YES));
+
+ // Randomly decide whether to add vectors
+ if (random.nextFloat() < vectorProbability) {
+ doc.add(
+ new KnnFloatVectorField("vector", dataset[i], VectorSimilarityFunction.EUCLIDEAN));
+ } else {
+ docsWithoutVectors.add(i);
+ }
+ writer.addDocument(doc);
+ }
+
+ // Delete documents randomly
+ for (int i = 0; i < datasetSize; i++) {
+ if (random.nextFloat() < deletionProbability) {
+ writer.deleteDocuments(new Term("id", String.valueOf(i)));
+ deletedDocs.add(i);
+ }
+ }
+ writer.commit();
+ }
+
+ // Test vector search behavior
+ try (DirectoryReader reader = DirectoryReader.open(directory)) {
+ IndexSearcher searcher = newSearcher(reader);
+ float[] queryVector = generateRandomVector(dimensions, random);
+
+ Query query = new KnnFloatVectorQuery("vector", queryVector, topK);
+ ScoreDoc[] hits = searcher.search(query, topK).scoreDocs;
+
+ // Verify results
+ for (ScoreDoc hit : hits) {
+ String docId = reader.storedFields().document(hit.doc).get("id");
+ int id = Integer.parseInt(docId);
+ assertFalse("Deleted document should not appear", deletedDocs.contains(id));
+ assertFalse("Document without vector should not appear", docsWithoutVectors.contains(id));
+ log.info("Found document with vector: " + id + ", Score: " + hit.score);
+ }
+
+ // Test filtered search with deletions
+ Query filter = new TermQuery(new Term("category", "A"));
+ Query filteredQuery =
+ new CuVSKnnFloatVectorQuery("vector", queryVector, topK, filter, topK, 1);
+ ScoreDoc[] filteredHits = searcher.search(filteredQuery, topK).scoreDocs;
+
+ for (ScoreDoc hit : filteredHits) {
+ Document doc = reader.storedFields().document(hit.doc);
+ String category = doc.get("category");
+ assertEquals("Should only match category A", "A", category);
+ int id = Integer.parseInt(doc.get("id"));
+ assertFalse(
+ "Deleted document should not appear in filtered results", deletedDocs.contains(id));
+ }
+ }
+ }
+ }
+
+ @Test
+ public void testVectorSearchAfterAllDocumentsDeleted() throws IOException {
+
+ try (Directory directory = newDirectory()) {
+ int datasetSize = random.nextInt(20) + 5; // 5-24 documents for this test
+ int dimensions = random.nextInt(128) + 32; // 32-159 dimensions
+ int topK = Math.min(random.nextInt(10) + 5, datasetSize); // 5-14 results
+
+ float[][] dataset = generateDataset(random, datasetSize, dimensions);
+
+ // Create and delete all documents
+ try (IndexWriter writer = new IndexWriter(directory, createWriterConfig())) {
+ for (int i = 0; i < datasetSize; i++) {
+ Document doc = new Document();
+ doc.add(new StringField("id", String.valueOf(i), Field.Store.YES));
+ doc.add(
+ new KnnFloatVectorField("vector", dataset[i], VectorSimilarityFunction.EUCLIDEAN));
+ writer.addDocument(doc);
+ }
+ writer.commit();
+
+ // Delete all documents
+ for (int i = 0; i < datasetSize; i++) {
+ writer.deleteDocuments(new Term("id", String.valueOf(i)));
+ }
+ writer.commit();
+ writer.forceMerge(1); // Force merge to apply deletions
+ }
+
+ // Verify search returns no results
+ try (DirectoryReader reader = DirectoryReader.open(directory)) {
+ IndexSearcher searcher = newSearcher(reader);
+ float[] queryVector = generateRandomVector(dimensions, random);
+
+ Query query = new KnnFloatVectorQuery("vector", queryVector, topK);
+ TopDocs results = searcher.search(query, topK);
+
+ assertEquals(
+ "Should return no results when all documents are deleted",
+ 0,
+ results.totalHits.value());
+ }
+ }
+ }
+
+ @Test
+ public void testVectorSearchWithPartialDeletionAndReindexing() throws IOException {
+
+ try (Directory directory = newDirectory()) {
+ int datasetSize = random.nextInt(200) + 50; // 50-249 documents
+ int dimensions = random.nextInt(256) + 64; // 64-319 dimensions
+ int topK = Math.min(random.nextInt(20) + 5, datasetSize / 2); // 5-24 results
+ float deletionProbability = random.nextFloat() * 0.3f + 0.1f; // 10-40% deletion rate
+
+ float[][] dataset = generateDataset(random, datasetSize, dimensions);
+ List<Integer> activeDocIds = new ArrayList<>();
+
+ // Initial indexing
+ try (IndexWriter writer = new IndexWriter(directory, createWriterConfig())) {
+ int initialDocs = datasetSize / 2 + random.nextInt(datasetSize / 4); // 50-75% of dataset
+ for (int i = 0; i < initialDocs; i++) {
+ Document doc = new Document();
+ doc.add(new StringField("id", String.valueOf(i), Field.Store.YES));
+ doc.add(
+ new KnnFloatVectorField("vector", dataset[i], VectorSimilarityFunction.EUCLIDEAN));
+ writer.addDocument(doc);
+ activeDocIds.add(i);
+ }
+
+ // Delete some documents randomly
+ List<Integer> candidatesForDeletion = new ArrayList<>(activeDocIds);
+ for (int docId : candidatesForDeletion) {
+ if (random.nextFloat() < deletionProbability) {
+ writer.deleteDocuments(new Term("id", String.valueOf(docId)));
+ activeDocIds.remove(Integer.valueOf(docId));
+ }
+ }
+
+ // Add new documents with higher IDs
+ for (int i = initialDocs; i < datasetSize; i++) {
+ Document doc = new Document();
+ doc.add(new StringField("id", String.valueOf(i), Field.Store.YES));
+ doc.add(
+ new KnnFloatVectorField("vector", dataset[i], VectorSimilarityFunction.EUCLIDEAN));
+ writer.addDocument(doc);
+ activeDocIds.add(i);
+ }
+ writer.commit();
+ }
+
+ // Verify search behavior after deletions and additions
+ try (DirectoryReader reader = DirectoryReader.open(directory)) {
+ IndexSearcher searcher = newSearcher(reader);
+ float[] queryVector = generateRandomVector(dimensions, random);
+
+ Query query = new KnnFloatVectorQuery("vector", queryVector, topK);
+ ScoreDoc[] hits = searcher.search(query, topK).scoreDocs;
+
+ Set<Integer> resultIds = new HashSet<>();
+ for (ScoreDoc hit : hits) {
+ String docId = reader.storedFields().document(hit.doc).get("id");
+ int id = Integer.parseInt(docId);
+ resultIds.add(id);
+ assertTrue("Result should be from active documents", activeDocIds.contains(id));
+ }
+
+ log.info(
+ "Search returned "
+ + hits.length
+ + " results from "
+ + activeDocIds.size()
+ + " active documents");
+ }
+ }
+ }
+
+ private RandomIndexWriter createWriter(Directory directory) throws IOException {
+ return new RandomIndexWriter(
+ random(),
+ directory,
+ newIndexWriterConfig(new MockAnalyzer(random(), MockTokenizer.SIMPLE, true))
+ .setCodec(codec)
+ .setMergePolicy(newTieredMergePolicy()));
+ }
+
+ private IndexWriterConfig createWriterConfig() {
+ return newIndexWriterConfig(new MockAnalyzer(random(), MockTokenizer.SIMPLE, true))
+ .setCodec(codec)
+ .setMergePolicy(newTieredMergePolicy());
+ }
+}
diff --git a/src/test/java/com/nvidia/cuvs/lucene/TestCuVSGaps.java b/src/test/java/com/nvidia/cuvs/lucene/TestCuVSGaps.java
new file mode 100644
index 0000000..e27568b
--- /dev/null
+++ b/src/test/java/com/nvidia/cuvs/lucene/TestCuVSGaps.java
@@ -0,0 +1,197 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.cuvs.lucene;
+
+import static com.nvidia.cuvs.lucene.TestUtils.generateDataset;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.TreeMap;
+import java.util.logging.Logger;
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.document.StringField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.KnnFloatVectorQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.analysis.MockAnalyzer;
+import org.apache.lucene.tests.analysis.MockTokenizer;
+import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.util.English;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.tests.util.LuceneTestCase.SuppressSysoutChecks;
+import org.apache.lucene.tests.util.TestUtil;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+@SuppressSysoutChecks(bugUrl = "prints info from within cuvs")
+public class TestCuVSGaps extends LuceneTestCase {
+
+ protected static Logger log = Logger.getLogger(TestCuVSGaps.class.getName());
+
+ static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new CuVSVectorsFormat());
+ static IndexSearcher searcher;
+ static IndexReader reader;
+ static Directory directory;
+ static Random random;
+
+ static int DATASET_SIZE_LIMIT = 1000;
+ static int DIMENSIONS_LIMIT = 2048;
+ static int NUM_QUERIES_LIMIT = 10;
+ static int TOP_K_LIMIT = 64;
+
+ static int datasetSize;
+ static int dimension;
+ static float[][] dataset;
+
+ @BeforeClass
+ public static void beforeClass() throws Exception {
+ assertTrue("cuvs not supported", CuVSVectorsFormat.supported());
+ directory = newDirectory();
+ random = random();
+
+ RandomIndexWriter writer =
+ new RandomIndexWriter(
+ random(),
+ directory,
+ newIndexWriterConfig(new MockAnalyzer(random(), MockTokenizer.SIMPLE, true))
+ .setMaxBufferedDocs(TestUtil.nextInt(random(), 100, 1000))
+ .setCodec(codec)
+ .setMergePolicy(newTieredMergePolicy()));
+
+ log.info("Merge Policy: " + writer.w.getConfig().getMergePolicy());
+
+ datasetSize = random.nextInt(100, DATASET_SIZE_LIMIT);
+ dimension = random.nextInt(8, DIMENSIONS_LIMIT);
+ dataset = generateDataset(random, datasetSize, dimension);
+
+ // Create documents where only even-numbered documents have vectors
+ for (int i = 0; i < datasetSize; i++) {
+ Document doc = new Document();
+ doc.add(new StringField("id", String.valueOf(i), Field.Store.YES));
+ doc.add(newTextField("field", English.intToEnglish(i), Field.Store.YES));
+
+ // Only add vectors to even-numbered documents
+ if (i % 2 == 0) {
+ doc.add(new KnnFloatVectorField("vector", dataset[i], VectorSimilarityFunction.EUCLIDEAN));
+ }
+
+ writer.addDocument(doc);
+ }
+
+ reader = writer.getReader();
+ searcher = newSearcher(reader);
+ writer.close();
+ }
+
+ @AfterClass
+ public static void afterClass() throws Exception {
+ if (reader != null) reader.close();
+ if (directory != null) directory.close();
+ searcher = null;
+ reader = null;
+ directory = null;
+ log.info("Test finished");
+ }
+
+ @Test
+ public void testVectorSearchWithAlternatingDocuments() throws IOException {
+ assertTrue("cuvs not supported", CuVSVectorsFormat.supported());
+
+ // Use the first vector (from document 0) as query
+ float[] queryVector = dataset[0];
+ // Cap topK at the number of even-numbered (vectored) documents so the
+ // exact-count assertion below holds.
+ int topK = Math.min(random.nextInt(5, TOP_K_LIMIT), datasetSize / 2);
+
+ Query query = new KnnFloatVectorQuery("vector", queryVector, topK);
+ ScoreDoc[] hits = searcher.search(query, topK).scoreDocs;
+
+ // Verify we get exactly TOP_K results
+ assertEquals("Should return exactly " + topK + " results", topK, hits.length);
+
+ // Verify all returned documents have vectors (even-numbered IDs)
+ for (ScoreDoc hit : hits) {
+ String docId = reader.storedFields().document(hit.doc).get("id");
+ int id = Integer.parseInt(docId);
+ assertEquals("All results should be even-numbered (have vectors)", 0, id % 2);
+ log.info("Document ID: " + id + ", Score: " + hit.score);
+ }
+
+ // Verify the results match expected top-k based on Euclidean distance
+ List<Integer> expectedIds = calculateExpectedTopK(queryVector, topK, dataset);
+ for (int i = 0; i < hits.length; i++) {
+ String docId = reader.storedFields().document(hits[i].doc).get("id");
+ int id = Integer.parseInt(docId);
+ assertTrue("Result " + id + " should be in expected top-k results", expectedIds.contains(id));
+ }
+
+ log.info("Alternating document test passed with " + hits.length + " results");
+ }
+
+ @Test
+ public void testVectorSearchWithFilterAndAlternatingDocuments() throws IOException {
+ assumeTrue("cuvs not supported", CuVSVectorsFormat.supported());
+
+ // Use the first vector (from document 0) as query
+ float[] queryVector = dataset[0];
+ int topK = random.nextInt(5, TOP_K_LIMIT);
+
+ // Create a filter that matches only document 8, which is even-numbered and
+ // therefore has a vector.
+ Query filter = new TermQuery(new Term("id", "8"));
+
+ Query filteredQuery = new CuVSKnnFloatVectorQuery("vector", queryVector, topK, filter, topK, 1);
+ ScoreDoc[] filteredHits = searcher.search(filteredQuery, topK).scoreDocs;
+
+ // Should only get document 8 (the only one that matches the filter and has a vector)
+ assertEquals("Should return exactly 1 result", 1, filteredHits.length);
+
+ String docId = reader.storedFields().document(filteredHits[0].doc).get("id");
+ assertEquals("Should only return document 8", "8", docId);
+
+ log.info("Filtered alternating document test passed with " + filteredHits.length + " results");
+ }
+
+ public static List<Integer> calculateExpectedTopK(float[] query, int topK, float[][] dataset) {
+ Map<Integer, Double> distances = new TreeMap<>();
+
+ // Calculate distances only for documents that have vectors (even-numbered)
+ for (int i = 0; i < dataset.length; i += 2) {
+ double distance = 0;
+ for (int j = 0; j < dataset[0].length; j++) {
+ distance += (query[j] - dataset[i][j]) * (query[j] - dataset[i][j]);
+ }
+ distances.put(i, distance);
+ }
+
+ // Sort by distance and return top-k
+ return distances.entrySet().stream()
+ .sorted(Map.Entry.comparingByValue())
+ .map(Map.Entry::getKey)
+ .limit(topK)
+ .toList();
+ }
+}
diff --git a/src/test/java/com/nvidia/cuvs/lucene/TestCuVS.java b/src/test/java/com/nvidia/cuvs/lucene/TestCuVSRandomizedVectorSearch.java
similarity index 73%
rename from src/test/java/com/nvidia/cuvs/lucene/TestCuVS.java
rename to src/test/java/com/nvidia/cuvs/lucene/TestCuVSRandomizedVectorSearch.java
index 124c2f9..8802373 100644
--- a/src/test/java/com/nvidia/cuvs/lucene/TestCuVS.java
+++ b/src/test/java/com/nvidia/cuvs/lucene/TestCuVSRandomizedVectorSearch.java
@@ -15,6 +15,9 @@
*/
package com.nvidia.cuvs.lucene;
+import static com.nvidia.cuvs.lucene.TestUtils.generateDataset;
+import static com.nvidia.cuvs.lucene.TestUtils.generateQueries;
+
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
@@ -29,11 +32,13 @@
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.analysis.MockTokenizer;
@@ -47,9 +52,9 @@
import org.junit.Test;
@SuppressSysoutChecks(bugUrl = "prints info from within cuvs")
-public class TestCuVS extends LuceneTestCase {
+public class TestCuVSRandomizedVectorSearch extends LuceneTestCase {
- protected static Logger log = Logger.getLogger(TestCuVS.class.getName());
+ protected static Logger log = Logger.getLogger(TestCuVSRandomizedVectorSearch.class.getName());
static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new CuVSVectorsFormat());
static IndexSearcher searcher;
@@ -60,12 +65,11 @@ public class TestCuVS extends LuceneTestCase {
static int DIMENSIONS_LIMIT = 2048;
static int NUM_QUERIES_LIMIT = 10;
static int TOP_K_LIMIT = 64; // TODO This fails beyond 64
-
- public static float[][] dataset;
+ static float[][] dataset;
@BeforeClass
public static void beforeClass() throws Exception {
- assumeTrue("cuvs not supported", CuVSVectorsFormat.supported());
+ assertTrue("cuvs not supported", CuVSVectorsFormat.supported());
directory = newDirectory();
RandomIndexWriter writer =
@@ -88,7 +92,8 @@ public static void beforeClass() throws Exception {
doc.add(new StringField("id", String.valueOf(i), Field.Store.YES));
doc.add(newTextField("field", English.intToEnglish(i), Field.Store.YES));
boolean skipVector =
- random.nextInt(10) < 0; // disable testing with holes for now, there's some bug.
+ random.nextInt(10)
+ < 4; // roughly 40% of documents get no vector, to exercise missing-vector handling
if (!skipVector
|| datasetSize < 100) { // about 10th of the documents shouldn't have a single vector
doc.add(new KnnFloatVectorField("vector", dataset[i], VectorSimilarityFunction.EUCLIDEAN));
@@ -146,28 +151,6 @@ public void testVectorSearch() throws IOException {
}
}
- private static float[][] generateQueries(Random random, int dimensions, int numQueries) {
- // Generate random query vectors
- float[][] queries = new float[numQueries][dimensions];
- for (int i = 0; i < numQueries; i++) {
- for (int j = 0; j < dimensions; j++) {
- queries[i][j] = random.nextFloat() * 100;
- }
- }
- return queries;
- }
-
- private static float[][] generateDataset(Random random, int datasetSize, int dimensions) {
- // Generate a random dataset
- float[][] dataset = new float[datasetSize][dimensions];
- for (int i = 0; i < datasetSize; i++) {
- for (int j = 0; j < dimensions; j++) {
- dataset[i][j] = random.nextFloat() * 100;
- }
- }
- return dataset;
- }
-
 private static List<List<Integer>> generateExpectedResults(
int topK, float[][] dataset, float[][] queries) {
 List<List<Integer>> neighborsResult = new ArrayList<>();
@@ -192,15 +175,50 @@ private static List> generateExpectedResults(
.sorted(Map.Entry.comparingByValue())
.map(Map.Entry::getKey)
.toList();
- neighborsResult.add(
- neighbors.subList(
- 0,
- Math.min(
- topK * 3,
- dataset.length))); // generate double the topK results in the expected array
+ neighborsResult.add(neighbors.subList(0, Math.min(topK * 3, dataset.length)));
}
log.info("Expected results generated successfully.");
return neighborsResult;
}
+
+ @Test
+ public void testVectorSearchWithFilter() throws IOException {
+ assertTrue("cuvs not supported", CuVSVectorsFormat.supported());
+
+ Random random = random();
+ int topK = Math.min(random.nextInt(TOP_K_LIMIT) + 1, dataset.length);
+
+ // Find a document that has a vector by doing a search first
+ Query unfiltered = new KnnFloatVectorQuery("vector", dataset[0], 1);
+ ScoreDoc[] unfilteredHits = searcher.search(unfiltered, 1).scoreDocs;
+
+ // Skip test if no vectors found at all
+ assumeTrue(
+ "Need at least one document with vector for filtering test", unfilteredHits.length > 0);
+
+ String targetDocId = reader.storedFields().document(unfilteredHits[0].doc).get("id");
+ float[] queryVector = dataset[0];
+
+ // Create a filter that matches only the document we know has a vector
+ Query filter = new TermQuery(new Term("id", targetDocId));
+
+ // Test the new constructor with filter
+ Query filteredQuery = new CuVSKnnFloatVectorQuery("vector", queryVector, topK, filter, topK, 1);
+
+ ScoreDoc[] filteredHits = searcher.search(filteredQuery, topK).scoreDocs;
+
+ // Ensure we got some results
+ assertTrue("Should have at least one result", filteredHits.length > 0);
+
+ // Verify that all results match the filter
+ for (ScoreDoc hit : filteredHits) {
+ String docId = reader.storedFields().document(hit.doc).get("id");
+ assertEquals("All results should match the filter", targetDocId, docId);
+ }
+
+ log.info("Prefiltering test passed with " + filteredHits.length + " results");
+ }
}
diff --git a/src/test/java/com/nvidia/cuvs/lucene/TestCuVSVectorsFormat.java b/src/test/java/com/nvidia/cuvs/lucene/TestCuVSVectorsFormat.java
index d19180e..69aeb29 100644
--- a/src/test/java/com/nvidia/cuvs/lucene/TestCuVSVectorsFormat.java
+++ b/src/test/java/com/nvidia/cuvs/lucene/TestCuVSVectorsFormat.java
@@ -31,15 +31,17 @@
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.apache.lucene.tests.util.LuceneTestCase.SuppressSysoutChecks;
import org.apache.lucene.tests.util.TestUtil;
import org.junit.BeforeClass;
import org.junit.Ignore;
+@SuppressSysoutChecks(bugUrl = "")
public class TestCuVSVectorsFormat extends BaseKnnVectorsFormatTestCase {
@BeforeClass
public static void beforeClass() {
- assumeTrue("cuvs is not supported", CuVSVectorsFormat.supported());
+ assertTrue("cuvs is not supported", CuVSVectorsFormat.supported());
}
@Override
diff --git a/src/test/java/com/nvidia/cuvs/lucene/TestIndexOutputOutputStream.java b/src/test/java/com/nvidia/cuvs/lucene/TestIndexOutputOutputStream.java
index d5f46d7..5f6ff3a 100644
--- a/src/test/java/com/nvidia/cuvs/lucene/TestIndexOutputOutputStream.java
+++ b/src/test/java/com/nvidia/cuvs/lucene/TestIndexOutputOutputStream.java
@@ -21,7 +21,9 @@
import java.util.Random;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.tests.util.LuceneTestCase.SuppressSysoutChecks;
+@SuppressSysoutChecks(bugUrl = "")
public class TestIndexOutputOutputStream extends LuceneTestCase {
public void testBasic() throws IOException {
@@ -33,8 +35,8 @@ public void testBasic() throws IOException {
out.close();
}
- try (var indexIn = dir.openInput("test", IOContext.DEFAULT)) {
- var in = new IndexInputInputStream(indexIn);
+ try (var indexIn = dir.openInput("test", IOContext.DEFAULT);
+ var in = new IndexInputInputStream(indexIn)) {
byte[] ba = new byte[6];
assertEquals(6, in.read(ba));
assertArrayEquals(new byte[] {0x56, 0x10, 0x11, 0x12, 0x13, 0x14}, ba);
@@ -76,9 +78,8 @@ public void testWithRandom() throws IOException {
out.close();
}
- try (var indexIn = dir.openInput("test", IOContext.DEFAULT)) {
- // TODO: close this stream properly in a subsequent PR.
- var in = new IndexInputInputStream(indexIn);
+ try (var indexIn = dir.openInput("test", IOContext.DEFAULT);
+ var in = new IndexInputInputStream(indexIn)) {
int i = 0;
while (i < data.length) {
if (random.nextBoolean()) {
diff --git a/src/test/java/com/nvidia/cuvs/lucene/TestMerge.java b/src/test/java/com/nvidia/cuvs/lucene/TestMerge.java
new file mode 100644
index 0000000..a424492
--- /dev/null
+++ b/src/test/java/com/nvidia/cuvs/lucene/TestMerge.java
@@ -0,0 +1,1169 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.cuvs.lucene;
+
+import static org.apache.lucene.tests.util.TestUtil.alwaysKnnVectorsFormat;
+
+import com.nvidia.cuvs.lucene.CuVSVectorsWriter.IndexType;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.logging.Logger;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.document.NumericDocValuesField;
+import org.apache.lucene.document.SortedDocValuesField;
+import org.apache.lucene.document.StringField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.SortedDocValues;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.TieredMergePolicy;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.KnnFloatVectorQuery;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.Sort;
+import org.apache.lucene.search.SortField;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.tests.util.LuceneTestCase.SuppressSysoutChecks;
+import org.apache.lucene.util.BytesRef;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+/**
+ * Comprehensive tests for merge functionality with CuVS indexes.
+ * Tests merge operations across different index types including brute force,
+ * CAGRA, and combined index configurations to ensure proper vector handling
+ * and search functionality after segment merging.
+ */
+@SuppressSysoutChecks(bugUrl = "CuVS native library produces verbose logging output")
+public class TestMerge extends LuceneTestCase {
+
+ private static final Logger log = Logger.getLogger(TestMerge.class.getName());
+
+ private static final int MIN_VECTOR_DIMENSION = 64;
+ private static final int MAX_VECTOR_DIMENSION = 256;
+ private static final int TOP_K_LIMIT = 64;
+
+ @BeforeClass
+ public static void beforeClass() {
+ assertTrue("cuVS is not supported", CuVSVectorsFormat.supported());
+ }
+
+ private Directory directory;
+ private int vectorDimension;
+
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ directory = newDirectory();
+
+ // Randomize vector dimension for each test
+ vectorDimension =
+ MIN_VECTOR_DIMENSION + random().nextInt(MAX_VECTOR_DIMENSION - MIN_VECTOR_DIMENSION + 1);
+ // Ensure dimension is multiple of 4 for better performance
+ vectorDimension = (vectorDimension / 4) * 4;
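+ // (The multiple-of-4 alignment is assumed to help cuVS kernel performance;
+ // it is a heuristic here, not a documented requirement.)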
+
+ log.info("Using randomized vector dimension: " + vectorDimension);
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ if (directory != null) {
+ directory.close();
+ }
+ super.tearDown();
+ }
+
+ /**
+ * Test merging many documents across multiple segments
+ **/
+ @Test
+ public void testMergeManyDocumentsMultipleSegments() throws IOException {
+ log.info("Starting testMergeManyDocumentsMultipleSegments");
+
+ // Randomize configuration parameters
+ int maxBufferedDocs = 5 + random().nextInt(16); // 5-20 docs per buffer
+ int totalBatches = 8 + random().nextInt(8); // 8-15 batches
+ int docsPerBatch = 15 + random().nextInt(11); // 15-25 docs per batch
+ int totalDocuments = totalBatches * docsPerBatch;
+
+ // Randomize vector presence probability (60-85%)
+ double vectorProbability = 0.6 + (random().nextDouble() * 0.25);
+
+ log.info(
+ "Randomized parameters: maxBufferedDocs="
+ + maxBufferedDocs
+ + ", totalBatches="
+ + totalBatches
+ + ", docsPerBatch="
+ + docsPerBatch
+ + ", totalDocuments="
+ + totalDocuments
+ + ", vectorProbability="
+ + vectorProbability);
+
+ IndexWriterConfig config =
+ new IndexWriterConfig()
+ .setCodec(alwaysKnnVectorsFormat(new CuVSVectorsFormat()))
+ .setMaxBufferedDocs(maxBufferedDocs) // Randomized buffer size
+ .setRAMBufferSizeMB(IndexWriterConfig.DISABLE_AUTO_FLUSH);
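+
+ // With the RAM-based flush disabled, segment boundaries are driven only by
+ // maxBufferedDocs and the explicit commit() per batch, keeping the
+ // pre-merge segment count predictable.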
+
+ List<float[]> expectedVectors = new ArrayList<>();
+ List<Integer> expectedDocIds = new ArrayList<>();
+ int documentsWithVectors = 0;
+
+ try (IndexWriter writer = new IndexWriter(directory, config)) {
+ // Add documents in multiple batches to create many segments
+ for (int batch = 0; batch < totalBatches; batch++) {
+ for (int i = 0; i < docsPerBatch; i++) {
+ int docId = batch * docsPerBatch + i;
+ Document doc = new Document();
+ doc.add(new StringField("id", String.valueOf(docId), Field.Store.YES));
+ doc.add(new NumericDocValuesField("batch", batch));
+
+ // Randomly decide if document has vector
+ if (random().nextDouble() < vectorProbability) {
+ float[] vector = generateRandomVector(vectorDimension, random());
+ doc.add(new KnnFloatVectorField("vector", vector, VectorSimilarityFunction.COSINE));
+ expectedVectors.add(vector);
+ expectedDocIds.add(docId);
+ documentsWithVectors++;
+ }
+
+ writer.addDocument(doc);
+ }
+ writer.commit(); // Create a new segment
+ }
+
+ int documentsWithoutVectors = totalDocuments - documentsWithVectors;
+ log.info("Created " + totalDocuments + " documents in " + totalBatches + " segments");
+ log.info("Documents with vectors: " + documentsWithVectors);
+ log.info("Documents without vectors: " + documentsWithoutVectors);
+
+ // Force merge to trigger merge logic
+ writer.forceMerge(1);
+ log.info("Forced merge to single segment completed");
+ }
+
+ // Verify the merged index
+ try (DirectoryReader reader = DirectoryReader.open(directory)) {
+ assertEquals("Should have exactly one segment after merge", 1, reader.leaves().size());
+
+ LeafReader leafReader = reader.leaves().get(0).reader();
+ assertEquals("Total documents should match", totalDocuments, leafReader.maxDoc());
+
+ // Verify vector search works correctly after merge
+ if (documentsWithVectors > 0) {
+ IndexSearcher searcher = new IndexSearcher(reader);
+ float[] queryVector = generateRandomVector(vectorDimension, random());
+
+ // Randomize search parameters
+ int searchK =
+ Math.min(5 + random().nextInt(10), Math.min(documentsWithVectors, TOP_K_LIMIT));
+
+ KnnFloatVectorQuery query = new KnnFloatVectorQuery("vector", queryVector, searchK);
+ TopDocs results = searcher.search(query, searchK);
+
+ assertTrue("Should find some results after merge", results.scoreDocs.length > 0);
+ assertTrue(
+ "Should find reasonable number of results",
+ results.scoreDocs.length <= documentsWithVectors);
+
+ log.info(
+ "Vector search returned "
+ + results.scoreDocs.length
+ + " results out of "
+ + documentsWithVectors
+ + " documents with vectors");
+
+ // Verify all returned documents have valid IDs
+ for (ScoreDoc scoreDoc : results.scoreDocs) {
+ int docId = Integer.parseInt(searcher.storedFields().document(scoreDoc.doc).get("id"));
+ assertTrue("Document ID should be valid", docId >= 0 && docId < totalDocuments);
+ }
+ } else {
+ log.info("No documents with vectors - skipping vector search verification");
+ }
+
+ log.info("Merge verification completed successfully");
+ }
+ }
+
+ /**
+ * Test merging with index sorting enabled using text-based sorting and SortingMergePolicy
+ **/
+ @Test
+ public void testMergeWithIndexSorting() throws IOException {
+ log.info("Starting testMergeWithIndexSorting with text-based sorting");
+
+ // Randomize sort field type
+ SortField.Type sortType = random().nextBoolean() ? SortField.Type.STRING : SortField.Type.LONG;
+ String sortFieldName = sortType == SortField.Type.STRING ? "text_sort_key" : "numeric_sort_key";
+
+ // Configure index sorting by a randomized field
+ Sort indexSort = new Sort(new SortField(sortFieldName, sortType));
+
+ // Randomize merge policy parameters
+ TieredMergePolicy mergePolicy = new TieredMergePolicy();
+ mergePolicy.setMaxMergedSegmentMB(128 + random().nextInt(257)); // 128-384 MB
+ mergePolicy.setSegmentsPerTier(3 + random().nextInt(4)); // 3-6 segments per tier
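+ // Fewer segments per tier lets the merge policy kick in naturally before the
+ // explicit forceMerge(1), so both merge paths can be exercised.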
+
+ // Randomize writer configuration parameters
+ int maxBufferedDocs = 10 + random().nextInt(16); // 10-25 docs per buffer
+ int totalDocuments = 80 + random().nextInt(81); // 80-160 documents
+ int segmentSize = 15 + random().nextInt(11); // 15-25 docs per segment
+ double vectorProbability = 0.65 + (random().nextDouble() * 0.25); // 65-90% have vectors
+
+ log.info(
+ "Randomized sorting parameters: sortType=" + sortType + ", sortFieldName=" + sortFieldName);
+ log.info(
+ "Randomized config: maxBufferedDocs="
+ + maxBufferedDocs
+ + ", totalDocuments="
+ + totalDocuments
+ + ", segmentSize="
+ + segmentSize
+ + ", vectorProbability="
+ + vectorProbability);
+
+ IndexWriterConfig config =
+ new IndexWriterConfig()
+ .setCodec(alwaysKnnVectorsFormat(new CuVSVectorsFormat()))
+ .setIndexSort(indexSort) // This automatically enables sorting during merges
+ .setMergePolicy(mergePolicy)
+ .setMaxBufferedDocs(maxBufferedDocs)
+ .setRAMBufferSizeMB(IndexWriterConfig.DISABLE_AUTO_FLUSH);
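+
+ // Because the index sort is set on the writer, every merge (including the
+ // forceMerge below) must emit documents in sort-key order; the verification
+ // pass later relies on that guarantee.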
+
+
+ try (IndexWriter writer = new IndexWriter(directory, config)) {
+ // Create documents with randomized sort keys
+ for (int i = 0; i < totalDocuments; i++) {
+ float[] vector = null;
+
+ // Randomly decide if document has vector
+ if (random().nextDouble() < vectorProbability) {
+ vector = generateRandomVector(vectorDimension, random());
+ }
+
+ Document doc = new Document();
+ doc.add(new StringField("id", String.valueOf(i), Field.Store.YES));
+ doc.add(new StringField("original_order", String.valueOf(i), Field.Store.YES));
+
+ // Add sort field based on randomized type
+ if (sortType == SortField.Type.STRING) {
+ // Randomize text sort key length (4-12 characters)
+ int keyLength = 4 + random().nextInt(9);
+ String textSortKey = generateRandomText(random(), keyLength);
+ doc.add(new SortedDocValuesField(sortFieldName, new BytesRef(textSortKey)));
+ doc.add(new StringField(sortFieldName + "_stored", textSortKey, Field.Store.YES));
+ } else {
+ // Use numeric sort key with wider range
+ long numericSortKey = random().nextLong() % 100000; // Can be negative for more variety
+ doc.add(new NumericDocValuesField(sortFieldName, numericSortKey));
+ doc.add(
+ new StringField(
+ sortFieldName + "_stored", String.valueOf(numericSortKey), Field.Store.YES));
+ }
+
+ if (vector != null) {
+ doc.add(new KnnFloatVectorField("vector", vector, VectorSimilarityFunction.COSINE));
+ }
+
+ writer.addDocument(doc);
+
+ // Commit based on randomized segment size
+ if ((i + 1) % segmentSize == 0) {
+ writer.commit();
+ log.info(
+ "Committed segment "
+ + ((i + 1) / segmentSize)
+ + " with "
+ + (i + 1)
+ + " total documents");
+ }
+ }
+
+ log.info("Created " + totalDocuments + " documents with text-based index sorting");
+
+ // Force merge with sorting - this will use the sorting merge policy
+ writer.forceMerge(1);
+ log.info("Forced merge with text-based sorting completed");
+ }
+
+ // Verify the merged and sorted index
+ try (DirectoryReader reader = DirectoryReader.open(directory)) {
+ assertEquals("Should have exactly one segment after merge", 1, reader.leaves().size());
+
+ LeafReader leafReader = reader.leaves().get(0).reader();
+ assertEquals("Total documents should match", totalDocuments, leafReader.maxDoc());
+
+ // Verify documents are sorted correctly by the randomized sort field
+ log.info(
+ "Verifying document sorting order using sortType: "
+ + sortType
+ + ", field: "
+ + sortFieldName);
+
+ if (sortType == SortField.Type.STRING) {
+ // Verify string-based sorting
+ String previousSortKey = "";
+ SortedDocValues sortedValues = leafReader.getSortedDocValues(sortFieldName);
+
+ for (int docId = 0; docId < leafReader.maxDoc(); docId++) {
+ String currentSortKey = "";
+ if (sortedValues != null && sortedValues.advanceExact(docId)) {
+ currentSortKey = sortedValues.lookupOrd(sortedValues.ordValue()).utf8ToString();
+ }
+
+ assertTrue(
+ "Documents should be sorted by "
+ + sortFieldName
+ + ": '"
+ + previousSortKey
+ + "' should be <= '"
+ + currentSortKey
+ + "'",
+ previousSortKey.compareTo(currentSortKey) <= 0);
+ previousSortKey = currentSortKey;
+
+ // Log first 10 documents to verify sorting
+ if (docId < 10) {
+ IndexSearcher searcher = new IndexSearcher(reader);
+ String originalOrder = searcher.storedFields().document(docId).get("original_order");
+ log.info(
+ "DocId: "
+ + docId
+ + ", OriginalOrder: "
+ + originalOrder
+ + ", SortKey: '"
+ + currentSortKey
+ + "'");
+ }
+ }
+ } else {
+ // Verify numeric-based sorting
+ long previousSortKey = Long.MIN_VALUE;
+ var numericValues = leafReader.getNumericDocValues(sortFieldName);
+
+ for (int docId = 0; docId < leafReader.maxDoc(); docId++) {
+ long currentSortKey = Long.MIN_VALUE;
+ if (numericValues != null && numericValues.advanceExact(docId)) {
+ currentSortKey = numericValues.longValue();
+ }
+
+ assertTrue(
+ "Documents should be sorted by "
+ + sortFieldName
+ + ": "
+ + previousSortKey
+ + " should be <= "
+ + currentSortKey,
+ previousSortKey <= currentSortKey);
+ previousSortKey = currentSortKey;
+
+ // Log first 10 documents to verify sorting
+ if (docId < 10) {
+ IndexSearcher searcher = new IndexSearcher(reader);
+ String originalOrder = searcher.storedFields().document(docId).get("original_order");
+ log.info(
+ "DocId: "
+ + docId
+ + ", OriginalOrder: "
+ + originalOrder
+ + ", SortKey: "
+ + currentSortKey);
+ }
+ }
+ }
+
+ // Count total vectors by checking if vector field exists and has values
+ var vectorValues = leafReader.getFloatVectorValues("vector");
+ int documentsWithVectors = vectorValues != null ? vectorValues.size() : 0;
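+ // FloatVectorValues.size() counts only documents that actually carry a
+ // vector for this field, so it doubles as the post-merge vector count.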
+
+ log.info("Found " + documentsWithVectors + " documents with vectors after sorted merge");
+
+ // Test vector search on sorted index
+ if (documentsWithVectors > 0) {
+ IndexSearcher searcher = new IndexSearcher(reader);
+ float[] queryVector = generateRandomVector(vectorDimension, random());
+
+ KnnFloatVectorQuery query =
+ new KnnFloatVectorQuery("vector", queryVector, Math.min(10, documentsWithVectors));
+ TopDocs results = searcher.search(query, 10);
+
+ assertTrue("Should find results in sorted index", results.scoreDocs.length > 0);
+ log.info("Vector search on sorted index returned " + results.scoreDocs.length + " results");
+
+ // Verify that returned documents maintain sort order if we check their sort keys
+ log.info("Verifying vector search results maintain sorting consistency...");
+ for (int i = 0; i < Math.min(3, results.scoreDocs.length); i++) {
+ ScoreDoc scoreDoc = results.scoreDocs[i];
+ String originalOrder =
+ searcher.storedFields().document(scoreDoc.doc).get("original_order");
+ String sortKey =
+ searcher.storedFields().document(scoreDoc.doc).get(sortFieldName + "_stored");
+ log.info(
+ "Result "
+ + i
+ + ": DocId="
+ + scoreDoc.doc
+ + ", OriginalOrder="
+ + originalOrder
+ + ", SortKey='"
+ + sortKey
+ + "', Score="
+ + scoreDoc.score);
+ }
+ }
+
+ log.info("Text-based index sorting verification completed successfully");
+ }
+ }
+
+ /**
+ * Test merging segments with various patterns of missing vectors
+ **/
+ @Test
+ public void testMergeWithMissingVectors() throws IOException {
+ log.info("Starting testMergeWithMissingVectors");
+
+ // Randomize configuration
+ int maxBufferedDocs = 10 + random().nextInt(11); // 10-20 docs per buffer
+ int numSegments = 3 + random().nextInt(3); // 3-5 segments
+
+ IndexWriterConfig config =
+ new IndexWriterConfig()
+ .setCodec(alwaysKnnVectorsFormat(new CuVSVectorsFormat()))
+ .setMaxBufferedDocs(maxBufferedDocs)
+ .setRAMBufferSizeMB(IndexWriterConfig.DISABLE_AUTO_FLUSH);
+
+ log.info(
+ "Randomized parameters: maxBufferedDocs="
+ + maxBufferedDocs
+ + ", numSegments="
+ + numSegments);
+
+ int totalExpectedVectors = 0;
+ int totalDocuments = 0;
+
+ try (IndexWriter writer = new IndexWriter(directory, config)) {
+ for (int seg = 0; seg < numSegments; seg++) {
+ // Randomize segment characteristics
+ int docsInSegment = 15 + random().nextInt(16); // 15-30 docs per segment
+ double vectorProbability = random().nextDouble(); // 0-100% vector probability
+ String segmentType = "seg_" + seg + "_prob_" + String.format("%.2f", vectorProbability);
+
+ int segmentVectorCount = 0;
+
+ for (int i = 0; i < docsInSegment; i++) {
+ Document doc = new Document();
+ doc.add(new StringField("id", "seg" + seg + "_" + i, Field.Store.YES));
+ doc.add(new StringField("segment", segmentType, Field.Store.YES));
+ doc.add(new NumericDocValuesField("segment_num", seg));
+ doc.add(new NumericDocValuesField("doc_in_segment", i));
+
+ // Randomly add vector based on segment's probability
+ if (random().nextDouble() < vectorProbability) {
+ float[] vector = generateRandomVector(vectorDimension, random());
+ doc.add(new KnnFloatVectorField("vector", vector, VectorSimilarityFunction.COSINE));
+ segmentVectorCount++;
+ }
+
+ writer.addDocument(doc);
+ }
+
+ writer.commit();
+ totalDocuments += docsInSegment;
+ totalExpectedVectors += segmentVectorCount;
+
+ log.info(
+ "Created segment "
+ + seg
+ + ": "
+ + docsInSegment
+ + " documents, "
+ + segmentVectorCount
+ + " with vectors (probability: "
+ + String.format("%.2f", vectorProbability)
+ + ")");
+ }
+
+ // Force merge all segments
+ writer.forceMerge(1);
+ log.info("Forced merge of " + numSegments + " segments completed");
+ }
+
+ // Verify the merged index handles missing vectors correctly
+ try (DirectoryReader reader = DirectoryReader.open(directory)) {
+ assertEquals("Should have exactly one segment after merge", 1, reader.leaves().size());
+
+ LeafReader leafReader = reader.leaves().get(0).reader();
+ assertEquals("Total documents should match", totalDocuments, leafReader.maxDoc());
+
+ // Count actual vectors in merged index
+ var vectorValues = leafReader.getFloatVectorValues("vector");
+ int actualVectorCount = vectorValues != null ? vectorValues.size() : 0;
+
+ log.info(
+ "Total documents: "
+ + totalDocuments
+ + ", Expected vectors: "
+ + totalExpectedVectors
+ + ", Actual vectors: "
+ + actualVectorCount);
+
+ assertEquals("Vector count should match expected", totalExpectedVectors, actualVectorCount);
+
+ // Test vector search if we have vectors
+ if (actualVectorCount > 0) {
+ IndexSearcher searcher = new IndexSearcher(reader);
+ float[] queryVector = generateRandomVector(vectorDimension, random());
+
+ // Randomize search parameters
+ int searchK = Math.min(5 + random().nextInt(10), Math.min(actualVectorCount, TOP_K_LIMIT));
+
+ KnnFloatVectorQuery vectorQuery = new KnnFloatVectorQuery("vector", queryVector, searchK);
+ TopDocs vectorResults = searcher.search(vectorQuery, searchK);
+
+ assertTrue("Should find some vector results", vectorResults.scoreDocs.length > 0);
+ assertTrue(
+ "Should not find more vectors than exist",
+ vectorResults.scoreDocs.length <= actualVectorCount);
+
+ log.info(
+ "Found "
+ + vectorResults.scoreDocs.length
+ + " vector results out of "
+ + actualVectorCount
+ + " available");
+ } else {
+ log.info("No vectors in merged index - skipping vector search");
+ }
+
+ log.info("Missing vectors test completed successfully");
+ }
+ }
+
+ /**
+ * Test merge behavior with document deletions
+ **/
+ @Test
+ public void testMergeWithDeletions() throws IOException {
+ log.info("Starting testMergeWithDeletions");
+
+ // Randomize configuration parameters
+ int maxBufferedDocs = 15 + random().nextInt(11); // 15-25 docs per buffer
+ int numSegments = 3 + random().nextInt(4); // 3-6 segments
+ int docsPerSegment = 20 + random().nextInt(21); // 20-40 docs per segment
+ double vectorProbability = 0.7 + (random().nextDouble() * 0.25); // 70-95% have vectors
+ double deletionProbability = 0.2 + (random().nextDouble() * 0.3); // 20-50% deletion rate
+
+ log.info(
+ "Randomized parameters: maxBufferedDocs="
+ + maxBufferedDocs
+ + ", numSegments="
+ + numSegments
+ + ", docsPerSegment="
+ + docsPerSegment
+ + ", vectorProbability="
+ + vectorProbability
+ + ", deletionProbability="
+ + deletionProbability);
+
+ IndexWriterConfig config =
+ new IndexWriterConfig()
+ .setCodec(alwaysKnnVectorsFormat(new CuVSVectorsFormat()))
+ .setMaxBufferedDocs(maxBufferedDocs)
+ .setRAMBufferSizeMB(IndexWriterConfig.DISABLE_AUTO_FLUSH);
+
+ List<Integer> expectedRemainingDocs = new ArrayList<>();
+ List<Integer> deletedDocs = new ArrayList<>();
+ int totalDocuments = numSegments * docsPerSegment;
+
+ try (IndexWriter writer = new IndexWriter(directory, config)) {
+ // Create multiple segments with documents
+ for (int seg = 0; seg < numSegments; seg++) {
+ for (int i = 0; i < docsPerSegment; i++) {
+ int docId = seg * docsPerSegment + i;
+ Document doc = new Document();
+ doc.add(new StringField("id", String.valueOf(docId), Field.Store.YES));
+ doc.add(new StringField("segment", "seg_" + seg, Field.Store.YES));
+ doc.add(new NumericDocValuesField("doc_num", docId));
+ doc.add(new NumericDocValuesField("segment_num", seg));
+
+ // Randomly add vectors
+ if (random().nextDouble() < vectorProbability) {
+ float[] vector = generateRandomVector(vectorDimension, random());
+ doc.add(new KnnFloatVectorField("vector", vector, VectorSimilarityFunction.COSINE));
+ }
+
+ writer.addDocument(doc);
+ }
+ writer.commit();
+ }
+
+ log.info(
+ "Created "
+ + numSegments
+ + " segments with "
+ + docsPerSegment
+ + " documents each ("
+ + totalDocuments
+ + " total)");
+
+ // Delete documents randomly and track which ones are deleted
+ int deletedCount = 0;
+ for (int docId = 0; docId < totalDocuments; docId++) {
+ if (random().nextDouble() < deletionProbability) {
+ writer.deleteDocuments(new Term("id", String.valueOf(docId)));
+ deletedDocs.add(docId);
+ deletedCount++;
+ } else {
+ expectedRemainingDocs.add(docId);
+ }
+ }
+
+ log.info(
+ "Deleted "
+ + deletedCount
+ + " documents ("
+ + String.format("%.1f", (100.0 * deletedCount / totalDocuments))
+ + "%), remaining: "
+ + expectedRemainingDocs.size());
+
+ writer.commit();
+
+ // Force merge to apply deletions
+ writer.forceMerge(1);
+ log.info("Forced merge with deletions completed");
+ }
+
+ // Verify the merged index correctly handles deletions
+ try (DirectoryReader reader = DirectoryReader.open(directory)) {
+ assertEquals("Should have exactly one segment after merge", 1, reader.leaves().size());
+
+ LeafReader leafReader = reader.leaves().get(0).reader();
+ int expectedRemaining = expectedRemainingDocs.size();
+ assertEquals(
+ "Should have correct number of documents after deletions",
+ expectedRemaining,
+ leafReader.maxDoc());
+
+ // Verify that deleted documents are not present
+ IndexSearcher searcher = new IndexSearcher(reader);
+
+ // Test that we can find expected remaining documents
+ for (int i = 0; i < Math.min(10, expectedRemainingDocs.size()); i++) {
+ int docId = expectedRemainingDocs.get(i);
+ TopDocs result = searcher.search(new TermQuery(new Term("id", String.valueOf(docId))), 1);
+ assertEquals("Should find remaining document " + docId, 1, (int) result.totalHits.value());
+ }
+
+ // Test that actually deleted documents are not found
+ int deletedDocsToCheck = Math.min(10, deletedDocs.size()); // Check up to 10 deleted docs
+ for (int i = 0; i < deletedDocsToCheck; i++) {
+ int docId = deletedDocs.get(i);
+ TopDocs result = searcher.search(new TermQuery(new Term("id", String.valueOf(docId))), 1);
+ assertEquals(
+ "Should not find deleted document " + docId, 0, (int) result.totalHits.value());
+ }
+
+ // Test vector search works after deletions
+ float[] queryVector = generateRandomVector(vectorDimension, random());
+ KnnFloatVectorQuery vectorQuery = new KnnFloatVectorQuery("vector", queryVector, 10);
+ TopDocs vectorResults = searcher.search(vectorQuery, 10);
+
+ assertTrue(
+ "Should find some vector results after deletions", vectorResults.scoreDocs.length > 0);
+
+ log.info("Found " + vectorResults.scoreDocs.length + " vector results after deletions");
+ log.info("Deletion merge verification completed successfully");
+ }
+ }
+
+ /**
+ * Test merging segments for {@link IndexType#BRUTE_FORCE}
+ **/
+ @Test
+ public void testMergeBruteForceIndex() throws IOException {
+ log.info("Starting testMergeBruteForceIndex");
+
+ // Randomize configuration parameters
+ int maxBufferedDocs = 8 + random().nextInt(8); // 8-15 docs per buffer
+ int numSegments = 3 + random().nextInt(3); // 3-5 segments
+ int docsPerSegment = 12 + random().nextInt(9); // 12-20 docs per segment
+ double vectorProbability = 0.8 + (random().nextDouble() * 0.2); // 80-100% have vectors
+
+ log.info(
+ "Randomized parameters: maxBufferedDocs="
+ + maxBufferedDocs
+ + ", numSegments="
+ + numSegments
+ + ", docsPerSegment="
+ + docsPerSegment
+ + ", vectorProbability="
+ + vectorProbability);
+
+ // Configure with brute force index type
+ CuVSVectorsFormat bruteForceFormat =
+ new CuVSVectorsFormat(
+ 32, // writer threads
+ 128, // intermediate graph degree
+ 64, // graph degree
+ IndexType.BRUTE_FORCE); // Use brute force index
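+ // The graph-degree arguments above are presumably inert for a pure
+ // brute-force index and are supplied only to satisfy this constructor form.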
+
+ IndexWriterConfig config =
+ new IndexWriterConfig()
+ .setCodec(alwaysKnnVectorsFormat(bruteForceFormat))
+ .setMaxBufferedDocs(maxBufferedDocs)
+ .setRAMBufferSizeMB(IndexWriterConfig.DISABLE_AUTO_FLUSH);
+
+ int totalDocuments = numSegments * docsPerSegment;
+ int totalExpectedVectors = 0;
+
+ try (IndexWriter writer = new IndexWriter(directory, config)) {
+ // Create multiple segments with brute force index
+ for (int seg = 0; seg < numSegments; seg++) {
+ int segmentVectorCount = 0;
+
+ for (int i = 0; i < docsPerSegment; i++) {
+ int docId = seg * docsPerSegment + i;
+ Document doc = new Document();
+ doc.add(new StringField("id", String.valueOf(docId), Field.Store.YES));
+ doc.add(new StringField("segment", "seg_" + seg, Field.Store.YES));
+ doc.add(new NumericDocValuesField("segment_num", seg));
+ doc.add(new NumericDocValuesField("doc_in_segment", i));
+
+ // Randomly add vectors based on probability
+ if (random().nextDouble() < vectorProbability) {
+ float[] vector = generateRandomVector(vectorDimension, random());
+ doc.add(new KnnFloatVectorField("vector", vector, VectorSimilarityFunction.COSINE));
+ segmentVectorCount++;
+ }
+
+ writer.addDocument(doc);
+ }
+
+ writer.commit();
+ totalExpectedVectors += segmentVectorCount;
+
+ log.info(
+ "Created brute force segment "
+ + seg
+ + ": "
+ + docsPerSegment
+ + " documents, "
+ + segmentVectorCount
+ + " with vectors");
+ }
+
+ log.info(
+ "Created "
+ + numSegments
+ + " brute force segments with "
+ + totalDocuments
+ + " total documents and "
+ + totalExpectedVectors
+ + " vectors");
+
+ // Force merge all brute force segments
+ writer.forceMerge(1);
+ log.info("Forced merge of brute force segments completed");
+ }
+
+ // Verify the merged brute force index
+ try (DirectoryReader reader = DirectoryReader.open(directory)) {
+ assertEquals("Should have exactly one segment after merge", 1, reader.leaves().size());
+
+ LeafReader leafReader = reader.leaves().get(0).reader();
+ assertEquals("Total documents should match", totalDocuments, leafReader.maxDoc());
+
+ // Count actual vectors in merged index
+ var vectorValues = leafReader.getFloatVectorValues("vector");
+ int actualVectorCount = vectorValues != null ? vectorValues.size() : 0;
+
+ log.info(
+ "Brute force merge results: Total documents: "
+ + totalDocuments
+ + ", Expected vectors: "
+ + totalExpectedVectors
+ + ", Actual vectors: "
+ + actualVectorCount);
+
+ assertEquals("Vector count should match expected", totalExpectedVectors, actualVectorCount);
+
+ // Test brute force vector search (exact search)
+ if (actualVectorCount > 0) {
+ IndexSearcher searcher = new IndexSearcher(reader);
+ float[] queryVector = generateRandomVector(vectorDimension, random());
+
+ // Search for reasonable number of results
+ int searchK = Math.min(8 + random().nextInt(8), Math.min(actualVectorCount, TOP_K_LIMIT));
+
+ KnnFloatVectorQuery vectorQuery = new KnnFloatVectorQuery("vector", queryVector, searchK);
+ TopDocs vectorResults = searcher.search(vectorQuery, searchK);
+
+ assertTrue(
+ "Should find some vector results in brute force index",
+ vectorResults.scoreDocs.length > 0);
+ assertTrue(
+ "Should not find more vectors than exist",
+ vectorResults.scoreDocs.length <= actualVectorCount);
+
+ log.info(
+ "Brute force search found "
+ + vectorResults.scoreDocs.length
+ + " results out of "
+ + actualVectorCount
+ + " available vectors");
+
+ // Verify all returned documents are valid
+ for (ScoreDoc scoreDoc : vectorResults.scoreDocs) {
+ String docId = searcher.storedFields().document(scoreDoc.doc).get("id");
+ assertNotNull("Document should have valid ID", docId);
+ assertTrue("Score should be positive", scoreDoc.score > 0);
+ }
+ } else {
+ log.info("No vectors in brute force merged index - skipping vector search");
+ }
+
+ log.info("Brute force merge verification completed successfully");
+ }
+ }
+
+ /**
+ * Test merging segments for {@link IndexType#CAGRA_AND_BRUTE_FORCE}
+ **/
+ @Test
+ public void testMergeCagraAndBruteForceIndex() throws IOException {
+ log.info("Starting testMergeCagraAndBruteForceIndex");
+
+ // Use moderate dataset size
+ int maxBufferedDocs = 15 + random().nextInt(10); // 15-24 docs per buffer
+ int numSegments =
+ 4; // Fixed at 4 segments; segments too small for a CAGRA graph fall back to brute force
+ int docsPerSegment = 20 + random().nextInt(11); // 20-30 docs per segment
+ double vectorProbability = 0.9 + (random().nextDouble() * 0.1); // 90-100% have vectors
+
+ log.info(
+ "Randomized parameters: maxBufferedDocs="
+ + maxBufferedDocs
+ + ", numSegments="
+ + numSegments
+ + ", docsPerSegment="
+ + docsPerSegment
+ + ", vectorProbability="
+ + vectorProbability);
+
+ // Configure with CAGRA + brute force combined index type
+ CuVSVectorsFormat combinedFormat =
+ new CuVSVectorsFormat(
+ 32, // writer threads
+ 128, // intermediate graph degree
+ 64, // graph degree
+ IndexType.CAGRA_AND_BRUTE_FORCE); // Use combined CAGRA + brute force
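+ // Assumption: the combined type writes both a CAGRA graph and a brute-force
+ // structure, letting segments too small for CAGRA fall back to exact search
+ // (see the segment-count comment above).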
+
+ IndexWriterConfig config =
+ new IndexWriterConfig()
+ .setCodec(alwaysKnnVectorsFormat(combinedFormat))
+ .setMaxBufferedDocs(maxBufferedDocs)
+ .setRAMBufferSizeMB(IndexWriterConfig.DISABLE_AUTO_FLUSH);
+
+ int totalDocuments = numSegments * docsPerSegment;
+ int totalExpectedVectors = 0;
+
+ try (IndexWriter writer = new IndexWriter(directory, config)) {
+ // Create segments that will result in mixed index types during merge
+ for (int seg = 0; seg < numSegments; seg++) {
+ int segmentVectorCount = 0;
+
+ for (int i = 0; i < docsPerSegment; i++) {
+ int docId = seg * docsPerSegment + i;
+ Document doc = new Document();
+ doc.add(new StringField("id", String.valueOf(docId), Field.Store.YES));
+ doc.add(new StringField("segment", "mixed_seg_" + seg, Field.Store.YES));
+ doc.add(new StringField("index_type", "cagra_and_brute_force", Field.Store.YES));
+ doc.add(new NumericDocValuesField("segment_num", seg));
+ doc.add(new NumericDocValuesField("doc_in_segment", i));
+
+ // Add vectors based on probability
+ if (random().nextDouble() < vectorProbability) {
+ float[] vector = generateRandomVector(vectorDimension, random());
+ doc.add(new KnnFloatVectorField("vector", vector, VectorSimilarityFunction.COSINE));
+ segmentVectorCount++;
+ }
+
+ writer.addDocument(doc);
+ }
+
+ writer.commit();
+ totalExpectedVectors += segmentVectorCount;
+
+ log.info(
+ "Created CAGRA+brute force segment "
+ + seg
+ + ": "
+ + docsPerSegment
+ + " documents, "
+ + segmentVectorCount
+ + " with vectors");
+ }
+
+ log.info(
+ "Created "
+ + numSegments
+ + " CAGRA+brute force segments with "
+ + totalDocuments
+ + " total documents and "
+ + totalExpectedVectors
+ + " vectors");
+
+ // Force merge all CAGRA+brute force segments
+ writer.forceMerge(1);
+ log.info("Forced merge of CAGRA+brute force segments completed");
+ }
+
+ // Verify the merged CAGRA+brute force index
+ try (DirectoryReader reader = DirectoryReader.open(directory)) {
+ assertEquals("Should have exactly one segment after merge", 1, reader.leaves().size());
+
+ LeafReader leafReader = reader.leaves().get(0).reader();
+ assertEquals("Total documents should match", totalDocuments, leafReader.maxDoc());
+
+ // Count actual vectors in merged index
+ var vectorValues = leafReader.getFloatVectorValues("vector");
+ int actualVectorCount = vectorValues != null ? vectorValues.size() : 0;
+
+ log.info(
+ "CAGRA+brute force merge results: Total documents: "
+ + totalDocuments
+ + ", Expected vectors: "
+ + totalExpectedVectors
+ + ", Actual vectors: "
+ + actualVectorCount);
+
+ assertEquals("Vector count should match expected", totalExpectedVectors, actualVectorCount);
+
+ // Test CAGRA+brute force index vector search
+ if (actualVectorCount > 0) {
+ IndexSearcher searcher = new IndexSearcher(reader);
+ float[] queryVector = generateRandomVector(vectorDimension, random());
+
+ // Search for reasonable number of results
+ int searchK = Math.min(12 + random().nextInt(8), Math.min(actualVectorCount, TOP_K_LIMIT));
+
+ KnnFloatVectorQuery vectorQuery = new KnnFloatVectorQuery("vector", queryVector, searchK);
+ TopDocs vectorResults = searcher.search(vectorQuery, searchK);
+
+ assertTrue(
+ "Should find some vector results in CAGRA+brute force index",
+ vectorResults.scoreDocs.length > 0);
+ assertTrue(
+ "Should not find more vectors than exist",
+ vectorResults.scoreDocs.length <= actualVectorCount);
+
+ log.info(
+ "CAGRA+brute force index search found "
+ + vectorResults.scoreDocs.length
+ + " results out of "
+ + actualVectorCount
+ + " available vectors");
+
+ // Verify all returned documents are valid and have expected metadata
+ for (ScoreDoc scoreDoc : vectorResults.scoreDocs) {
+ Document resultDoc = searcher.storedFields().document(scoreDoc.doc);
+ String docId = resultDoc.get("id");
+ String indexType = resultDoc.get("index_type");
+
+ assertNotNull("Document should have valid ID", docId);
+ assertEquals(
+ "Document should be marked as CAGRA+brute force index type",
+ "cagra_and_brute_force",
+ indexType);
+ assertTrue("Score should be positive", scoreDoc.score > 0);
+ }
+
+ // Test that the CAGRA+brute force index handles both approximate and exact search
+ // consistently
+ for (int trial = 0; trial < 3; trial++) {
+ float[] trialQueryVector = generateRandomVector(vectorDimension, random());
+ KnnFloatVectorQuery trialQuery =
+ new KnnFloatVectorQuery("vector", trialQueryVector, Math.min(5, actualVectorCount));
+ TopDocs trialResults = searcher.search(trialQuery, Math.min(5, actualVectorCount));
+
+ assertTrue("Trial " + trial + " should find results", trialResults.scoreDocs.length > 0);
+ log.info("Trial " + trial + " found " + trialResults.scoreDocs.length + " results");
+ }
+ } else {
+ log.info("No vectors in CAGRA+brute force merged index - skipping vector search");
+ }
+
+ log.info("CAGRA+brute force merge verification completed successfully");
+ }
+ }
+
+ /**
+ * Test large scale merge to stress test the system
+ **/
+ @Test
+ public void testLargeScaleMerge() throws IOException {
+ assumeTrue(
+ "testLargeScaleMerge requires -DlargeScale=true",
+ Boolean.parseBoolean(System.getProperty("largeScale", "false")));
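+ // Opt-in stress test: enable it by passing -DlargeScale=true to the test
+ // runner's JVM.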
+
+ log.info("Starting testLargeScaleMerge");
+
+ // Randomize large scale parameters
+ int maxBufferedDocs = 40 + random().nextInt(21); // 40-60 docs per buffer
+ int segmentCount = 15 + random().nextInt(11); // 15-25 segments
+ int docsPerSegment = 30 + random().nextInt(21); // 30-50 docs per segment
+ int totalDocuments = segmentCount * docsPerSegment;
+
+ log.info(
+ "Randomized large scale parameters: maxBufferedDocs="
+ + maxBufferedDocs
+ + ", segmentCount="
+ + segmentCount
+ + ", docsPerSegment="
+ + docsPerSegment
+ + ", totalDocuments="
+ + totalDocuments);
+
+ IndexWriterConfig config =
+ new IndexWriterConfig()
+ .setCodec(alwaysKnnVectorsFormat(new CuVSVectorsFormat()))
+ .setMaxBufferedDocs(maxBufferedDocs)
+ .setRAMBufferSizeMB(IndexWriterConfig.DISABLE_AUTO_FLUSH);
+
+ try (IndexWriter writer = new IndexWriter(directory, config)) {
+ for (int seg = 0; seg < segmentCount; seg++) {
+ log.info("Creating segment " + (seg + 1) + "/" + segmentCount);
+
+ // Randomize vector probability per segment
+ double vectorProbability =
+ 0.5 + (random().nextDouble() * 0.4); // 50-90% vectors per segment
+
+ for (int i = 0; i < docsPerSegment; i++) {
+ int docId = seg * docsPerSegment + i;
+ Document doc = new Document();
+ doc.add(new StringField("id", String.valueOf(docId), Field.Store.YES));
+ doc.add(new NumericDocValuesField("segment", seg));
+ doc.add(new NumericDocValuesField("position", i));
+
+ // Add vector based on segment's randomized probability
+ if (random().nextDouble() < vectorProbability) {
+ float[] vector = generateRandomVector(vectorDimension, random());
+ doc.add(new KnnFloatVectorField("vector", vector, VectorSimilarityFunction.COSINE));
+ }
+
+ writer.addDocument(doc);
+ }
+ writer.commit();
+ }
+
+ log.info("Created " + segmentCount + " segments with " + totalDocuments + " total documents");
+
+ // Force merge all segments
+ long startTime = System.currentTimeMillis();
+ writer.forceMerge(1);
+ long mergeTime = System.currentTimeMillis() - startTime;
+
+ log.info("Large scale merge completed in " + mergeTime + "ms");
+ }
+
+ // Verify the large merged index
+ try (DirectoryReader reader = DirectoryReader.open(directory)) {
+ assertEquals("Should have exactly one segment after merge", 1, reader.leaves().size());
+
+ LeafReader leafReader = reader.leaves().get(0).reader();
+ assertEquals("Total documents should match", totalDocuments, leafReader.maxDoc());
+
+ // Test vector search performance
+ var vectorValues = leafReader.getFloatVectorValues("vector");
+ int actualVectorCount = vectorValues != null ? vectorValues.size() : 0;
+
+ if (actualVectorCount > 0) {
+ IndexSearcher searcher = new IndexSearcher(reader);
+ float[] queryVector = generateRandomVector(vectorDimension, random());
+
+ // Randomize search parameters for large scale test
+ int searchK =
+ Math.min(20 + random().nextInt(31), Math.min(actualVectorCount, TOP_K_LIMIT)); // 20-50
+
+ long searchStart = System.currentTimeMillis();
+ KnnFloatVectorQuery vectorQuery = new KnnFloatVectorQuery("vector", queryVector, searchK);
+ TopDocs vectorResults = searcher.search(vectorQuery, searchK);
+ long searchTime = System.currentTimeMillis() - searchStart;
+
+ assertTrue("Should find vector results in large index", vectorResults.scoreDocs.length > 0);
+ log.info(
+ "Vector search in large index returned "
+ + vectorResults.scoreDocs.length
+ + " results out of "
+ + actualVectorCount
+ + " vectors in "
+ + searchTime
+ + "ms");
+ } else {
+ log.info("No vectors in large merged index - skipping vector search");
+ }
+
+ log.info("Large scale merge verification completed successfully");
+ }
+ }
+
+ /** Helper method to generate random vectors */
+ private float[] generateRandomVector(int dimension, Random random) {
+ float[] vector = new float[dimension];
+ for (int i = 0; i < dimension; i++) {
+ vector[i] = (float) random.nextGaussian();
+ }
+ // Normalize to unit length; normalized Gaussian samples are uniform on the
+ // hypersphere, a good fit for the COSINE similarity used throughout
+ float norm = 0.0f;
+ for (float v : vector) {
+ norm += v * v;
+ }
+ norm = (float) Math.sqrt(norm);
+ if (norm > 0) {
+ for (int i = 0; i < dimension; i++) {
+ vector[i] /= norm;
+ }
+ }
+ return vector;
+ }
+
+ /** Helper method to generate random text strings for sorting */
+ private String generateRandomText(Random random, int length) {
+ StringBuilder sb = new StringBuilder(length);
+ String chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
+ for (int i = 0; i < length; i++) {
+ sb.append(chars.charAt(random.nextInt(chars.length())));
+ }
+ return sb.toString();
+ }
+}
diff --git a/src/test/java/com/nvidia/cuvs/lucene/TestUtils.java b/src/test/java/com/nvidia/cuvs/lucene/TestUtils.java
new file mode 100644
index 0000000..8bd8339
--- /dev/null
+++ b/src/test/java/com/nvidia/cuvs/lucene/TestUtils.java
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.cuvs.lucene;
+
+import java.util.Random;
+
+public class TestUtils {
+
+ public static float[][] generateDataset(Random random, int size, int dimensions) {
+ float[][] dataset = new float[size][dimensions];
+ for (int i = 0; i < size; i++) {
+ for (int j = 0; j < dimensions; j++) {
+ dataset[i][j] = random.nextFloat() * 100;
+ }
+ }
+ return dataset;
+ }
+
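+ // Note: unlike TestMerge's Gaussian helper, these vectors are not
+ // unit-normalized; the uniform 0-100 range suits tests that do not depend
+ // on cosine geometry.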
+ public static float[] generateRandomVector(int dimensions, Random random) {
+ float[] vector = new float[dimensions];
+ for (int i = 0; i < dimensions; i++) {
+ vector[i] = random.nextFloat() * 100;
+ }
+ return vector;
+ }
+
+ public static float[][] generateQueries(Random random, int dimensions, int numQueries) {
+ // Generate random query vectors
+ float[][] queries = new float[numQueries][dimensions];
+ for (int i = 0; i < numQueries; i++) {
+ for (int j = 0; j < dimensions; j++) {
+ queries[i][j] = random.nextFloat() * 100;
+ }
+ }
+ return queries;
+ }
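+
+ /**
+ * Normalizes a vector in place to unit length. A convenience sketch added
+ * for cosine-similarity tests (not part of the original utilities); callers
+ * that want raw-magnitude vectors simply skip it.
+ */
+ public static void normalize(float[] vector) {
+ float norm = 0.0f;
+ for (float v : vector) {
+ norm += v * v;
+ }
+ norm = (float) Math.sqrt(norm);
+ if (norm > 0) {
+ for (int i = 0; i < vector.length; i++) {
+ vector[i] /= norm;
+ }
+ }
+ }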
+}