Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 118 additions & 98 deletions qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat;
import org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat;
import org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat;
import org.elasticsearch.index.codec.vectors.es93.ES93FlatVectorFormat;
import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat;
import org.elasticsearch.index.codec.vectors.es93.ES93HnswScalarQuantizedVectorsFormat;
import org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat;
Expand All @@ -52,15 +53,16 @@

import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadInfo;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

Expand Down Expand Up @@ -89,21 +91,23 @@ public class KnnIndexTester {
static final String INDEX_DIR = "target/knn_index";

enum IndexType {
HNSW,
FLAT,
HNSW,
IVF,
GPU_HNSW
}

enum VectorEncoding {
BYTE(org.apache.lucene.index.VectorEncoding.BYTE),
FLOAT32(org.apache.lucene.index.VectorEncoding.FLOAT32),
BFLOAT16(org.apache.lucene.index.VectorEncoding.FLOAT32);
BYTE(org.apache.lucene.index.VectorEncoding.BYTE, DenseVectorFieldMapper.ElementType.BYTE),
FLOAT32(org.apache.lucene.index.VectorEncoding.FLOAT32, DenseVectorFieldMapper.ElementType.FLOAT),
BFLOAT16(org.apache.lucene.index.VectorEncoding.FLOAT32, DenseVectorFieldMapper.ElementType.BFLOAT16);

private final org.apache.lucene.index.VectorEncoding luceneEncoding;
private final DenseVectorFieldMapper.ElementType elementType;

VectorEncoding(org.apache.lucene.index.VectorEncoding luceneEncoding) {
VectorEncoding(org.apache.lucene.index.VectorEncoding luceneEncoding, DenseVectorFieldMapper.ElementType elementType) {
this.luceneEncoding = luceneEncoding;
this.elementType = elementType;
}

public org.apache.lucene.index.VectorEncoding luceneEncoding() {
Expand All @@ -120,109 +124,106 @@ enum MergePolicyType {

private static String formatIndexPath(TestConfiguration args) {
List<String> suffix = new ArrayList<>();
if (args.indexType() == IndexType.FLAT) {
suffix.add("flat");
} else if (args.indexType() == IndexType.GPU_HNSW) {
suffix.add("gpu_hnsw");
} else if (args.indexType() == IndexType.IVF) {
suffix.add("ivf");
suffix.add(Integer.toString(args.ivfClusterSize()));
suffix.add(
Integer.toString(
args.secondaryClusterSize() == -1
? ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER
: args.secondaryClusterSize()
)
);
suffix.add(Integer.toString(args.quantizeBits()));
} else {
suffix.add(Integer.toString(args.hnswM()));
suffix.add(Integer.toString(args.hnswEfConstruction()));
if (args.quantizeBits() < 32) {
switch (args.indexType()) {
case FLAT -> suffix.add("flat");
case GPU_HNSW -> suffix.add("gpu_hnsw");
case IVF -> {
suffix.add("ivf");
suffix.add(Integer.toString(args.ivfClusterSize()));
suffix.add(
Integer.toString(
args.secondaryClusterSize() == -1
? ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER
: args.secondaryClusterSize()
)
);
suffix.add(Integer.toString(args.quantizeBits()));
}
case HNSW -> {
suffix.add(Integer.toString(args.hnswM()));
suffix.add(Integer.toString(args.hnswEfConstruction()));
if (args.quantizeBits() != null) {
suffix.add(Integer.toString(args.quantizeBits()));
}
}
}
return INDEX_DIR + "/" + args.docVectors().get(0).getFileName() + "-" + String.join("-", suffix) + ".index";

return INDEX_DIR + "/" + args.docVectors().getFirst().getFileName() + "-" + String.join("-", suffix) + ".index";
}

static Codec createCodec(TestConfiguration args, @Nullable ExecutorService exec) {
final KnnVectorsFormat format;
int quantizeBits = args.quantizeBits();
DenseVectorFieldMapper.ElementType elementType = switch (args.vectorEncoding()) {
case BYTE -> DenseVectorFieldMapper.ElementType.BYTE;
case FLOAT32 -> DenseVectorFieldMapper.ElementType.FLOAT;
case BFLOAT16 -> DenseVectorFieldMapper.ElementType.BFLOAT16;
};
if (args.indexType() == IndexType.IVF) {
ESNextDiskBBQVectorsFormat.QuantEncoding encoding = switch (quantizeBits) {
case (1) -> ESNextDiskBBQVectorsFormat.QuantEncoding.ONE_BIT_4BIT_QUERY;
case (2) -> ESNextDiskBBQVectorsFormat.QuantEncoding.TWO_BIT_4BIT_QUERY;
case (4) -> ESNextDiskBBQVectorsFormat.QuantEncoding.FOUR_BIT_SYMMETRIC;
Integer quantizeBits = args.quantizeBits();
DenseVectorFieldMapper.ElementType elementType = args.vectorEncoding().elementType;

format = switch (args.indexType()) {
case IVF -> {
ESNextDiskBBQVectorsFormat.QuantEncoding encoding = switch (quantizeBits) {
case 1 -> ESNextDiskBBQVectorsFormat.QuantEncoding.ONE_BIT_4BIT_QUERY;
case 2 -> ESNextDiskBBQVectorsFormat.QuantEncoding.TWO_BIT_4BIT_QUERY;
case 4 -> ESNextDiskBBQVectorsFormat.QuantEncoding.FOUR_BIT_SYMMETRIC;
default -> throw new IllegalArgumentException(
"IVF index type only supports 1, 2 or 4 bits quantization, but got: " + quantizeBits
);
};
yield new ESNextDiskBBQVectorsFormat(
encoding,
args.ivfClusterSize(),
args.secondaryClusterSize() == -1
? ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER
: args.secondaryClusterSize(),
elementType,
args.onDiskRescore(),
exec,
exec != null ? args.numMergeWorkers() : 1,
args.doPrecondition(),
args.preconditioningBlockDims()
);
}
case GPU_HNSW -> switch (quantizeBits) {
case null -> new ES92GpuHnswVectorsFormat();
case 7 -> new ES92GpuHnswSQVectorsFormat();
default -> throw new IllegalArgumentException(
"IVF index type only supports 1, 2 or 4 bits quantization, but got: " + quantizeBits
"GPU HNSW index type only supports 7 bits quantization, but got: " + quantizeBits
);
};
format = new ESNextDiskBBQVectorsFormat(
encoding,
args.ivfClusterSize(),
args.secondaryClusterSize() == -1
? ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER
: args.secondaryClusterSize(),
elementType,
args.onDiskRescore(),
exec,
exec != null ? args.numMergeWorkers() : 1,
args.doPrecondition(),
args.preconditioningBlockDims()
);
} else if (args.indexType() == IndexType.GPU_HNSW) {
if (quantizeBits == 32) {
format = new ES92GpuHnswVectorsFormat();
} else if (quantizeBits == 7) {
format = new ES92GpuHnswSQVectorsFormat();
} else {
throw new IllegalArgumentException("GPU HNSW index type only supports 7 or 32 bits quantization, but got: " + quantizeBits);
}
} else {
if (quantizeBits == 1) {
if (args.indexType() == IndexType.FLAT) {
format = new ES93BinaryQuantizedVectorsFormat(elementType, false);
} else {
format = new ES93HnswBinaryQuantizedVectorsFormat(
args.hnswM(),
args.hnswEfConstruction(),
elementType,
false,
exec != null ? args.numMergeWorkers() : 1,
exec
);
}
} else if (quantizeBits < 32) {
if (args.indexType() == IndexType.FLAT) {
format = new ES93ScalarQuantizedVectorsFormat(elementType, null, quantizeBits, true, false);
} else {
format = new ES93HnswScalarQuantizedVectorsFormat(
args.hnswM(),
args.hnswEfConstruction(),
elementType,
null,
quantizeBits,
true,
false,
exec != null ? args.numMergeWorkers() : 1,
exec
);
}
} else {
format = new ES93HnswVectorsFormat(
case HNSW -> switch (quantizeBits) {
case null -> new ES93HnswVectorsFormat(
args.hnswM(),
args.hnswEfConstruction(),
elementType,
exec != null ? args.numMergeWorkers() : 1,
exec
);
}
}
case 1 -> new ES93HnswBinaryQuantizedVectorsFormat(
args.hnswM(),
args.hnswEfConstruction(),
elementType,
false,
exec != null ? args.numMergeWorkers() : 1,
exec
);
default -> new ES93HnswScalarQuantizedVectorsFormat(
args.hnswM(),
args.hnswEfConstruction(),
elementType,
null,
quantizeBits,
true,
false,
exec != null ? args.numMergeWorkers() : 1,
exec
);
};
case FLAT -> switch (quantizeBits) {
case null -> new ES93FlatVectorFormat(elementType);
case 1 -> new ES93BinaryQuantizedVectorsFormat(elementType, false);
default -> new ES93ScalarQuantizedVectorsFormat(elementType, null, quantizeBits, true, false);
};
};

logger.info("Using format {}", format.getName());

return new Lucene103Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
Expand Down Expand Up @@ -336,13 +337,13 @@ public static void main(String[] args) throws Exception {
FormattedResults formattedResults = new FormattedResults();

for (TestConfiguration testConfiguration : testConfigurationList) {
// check this here so IVF/GPUHNSW can guarantee quantizeBits is set properly
checkQuantizeBits(testConfiguration);
String indexPathName = formatIndexPath(testConfiguration);
String indexType = testConfiguration.indexType().name().toLowerCase(Locale.ROOT);
Results indexResults = new Results(indexPathName, indexType, testConfiguration.numDocs());
Results[] results = new Results[testConfiguration.numberOfSearchRuns()];
for (int i = 0; i < results.length; i++) {
results[i] = new Results(indexPathName, indexType, testConfiguration.numDocs());
}
Arrays.setAll(results, i -> new Results(indexPathName, indexType, testConfiguration.numDocs()));
logger.info("Running with Java: " + Runtime.version());
logger.info("Running KNN index tester with arguments: " + testConfiguration);
final ExecutorService exec;
Expand Down Expand Up @@ -409,6 +410,25 @@ public static void main(String[] args) throws Exception {
logger.info("Results: \n" + formattedResults);
}

private static void checkQuantizeBits(TestConfiguration args) {
switch (args.indexType()) {
case IVF:
if (args.quantizeBits() == null || !Set.of(1, 2, 4).contains(args.quantizeBits())) {
throw new IllegalArgumentException(
"IVF index type only supports 1, 2 or 4 bits quantization, but got: " + args.quantizeBits()
);
}
break;
case GPU_HNSW: {
if (args.quantizeBits() != null && args.quantizeBits() != 7) {
throw new IllegalArgumentException(
"GPU HNSW index type only supports 7 bits quantization, but got: " + args.quantizeBits()
);
}
}
}
}

private static MergePolicy getMergePolicy(TestConfiguration args) {
return switch (args.mergePolicy()) {
case null -> null;
Expand All @@ -419,11 +439,11 @@ private static MergePolicy getMergePolicy(TestConfiguration args) {
};
}

static void numSegments(Path indexPath, Results result) {
static void numSegments(Path indexPath, Results result) throws IOException {
try (FSDirectory dir = FSDirectory.open(indexPath); IndexReader reader = DirectoryReader.open(dir)) {
result.numSegments = reader.leaves().size();
} catch (IOException e) {
throw new UncheckedIOException("Failed to get segment count for index at " + indexPath, e);
throw new IOException("Failed to get segment count for index at " + indexPath, e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ record TestConfiguration(
boolean reindex,
boolean forceMerge,
VectorSimilarityFunction vectorSpace,
int quantizeBits,
Integer quantizeBits,
KnnIndexTester.VectorEncoding vectorEncoding,
int dimensions,
KnnIndexTester.MergePolicyType mergePolicy,
Expand Down Expand Up @@ -128,7 +128,12 @@ static TestConfiguration fromXContent(XContentParser parser) throws IOException
PARSER.declareBoolean(Builder::setReindex, REINDEX_FIELD);
PARSER.declareBoolean(Builder::setForceMerge, FORCE_MERGE_FIELD);
PARSER.declareString(Builder::setVectorSpace, VECTOR_SPACE_FIELD);
PARSER.declareInt(Builder::setQuantizeBits, QUANTIZE_BITS_FIELD);
PARSER.declareField(
Builder::setQuantizeBits,
p -> p.currentToken() == XContentParser.Token.VALUE_NULL ? null : p.intValue(),
QUANTIZE_BITS_FIELD,
ObjectParser.ValueType.INT_OR_NULL
);
PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD);
PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD);
PARSER.declareFieldArray(
Expand Down Expand Up @@ -261,7 +266,7 @@ static class Builder implements ToXContentObject {
private boolean forceMerge = false;
private int forceMergeMaxNumSegments = 1;
private VectorSimilarityFunction vectorSpace = VectorSimilarityFunction.EUCLIDEAN;
private int quantizeBits = 8;
private Integer quantizeBits = null;
private KnnIndexTester.VectorEncoding vectorEncoding = KnnIndexTester.VectorEncoding.FLOAT32;
private int dimensions;
private List<Boolean> earlyTermination = List.of(Boolean.FALSE);
Expand Down Expand Up @@ -382,7 +387,7 @@ public Builder setVectorSpace(String vectorSpace) {
return this;
}

public Builder setQuantizeBits(int quantizeBits) {
public Builder setQuantizeBits(Integer quantizeBits) {
this.quantizeBits = quantizeBits;
return this;
}
Expand Down Expand Up @@ -563,7 +568,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(REINDEX_FIELD.getPreferredName(), reindex);
builder.field(FORCE_MERGE_FIELD.getPreferredName(), forceMerge);
builder.field(VECTOR_SPACE_FIELD.getPreferredName(), vectorSpace.name().toLowerCase(Locale.ROOT));
builder.field(QUANTIZE_BITS_FIELD.getPreferredName(), quantizeBits);
if (quantizeBits != null) {
builder.field(QUANTIZE_BITS_FIELD.getPreferredName(), quantizeBits);
}
builder.field(VECTOR_ENCODING_FIELD.getPreferredName(), vectorEncoding.name().toLowerCase(Locale.ROOT));
builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions);
builder.field(EARLY_TERMINATION_FIELD.getPreferredName(), earlyTermination);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
import java.io.IOException;
import java.util.Map;

/**
* A generic flat format that can use several different underlying vector storage formats.
* <p>
* This format is not meant to be used directly; it should be used as part of another vector format implementation.
*/
public class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFormat {

static final String NAME = "ES93GenericFlatVectorsFormat";
Expand Down