diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java index 9768a099c9160..77cf12bdf3459 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java @@ -38,6 +38,7 @@ import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat; import org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93FlatVectorFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswScalarQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat; @@ -52,15 +53,16 @@ import java.io.IOException; import java.io.InputStream; -import java.io.UncheckedIOException; import java.lang.management.ManagementFactory; import java.lang.management.ThreadInfo; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -89,21 +91,23 @@ public class KnnIndexTester { static final String INDEX_DIR = "target/knn_index"; enum IndexType { - HNSW, FLAT, + HNSW, IVF, GPU_HNSW } enum VectorEncoding { - BYTE(org.apache.lucene.index.VectorEncoding.BYTE), - FLOAT32(org.apache.lucene.index.VectorEncoding.FLOAT32), - BFLOAT16(org.apache.lucene.index.VectorEncoding.FLOAT32); + BYTE(org.apache.lucene.index.VectorEncoding.BYTE, DenseVectorFieldMapper.ElementType.BYTE), + FLOAT32(org.apache.lucene.index.VectorEncoding.FLOAT32, DenseVectorFieldMapper.ElementType.FLOAT), + 
BFLOAT16(org.apache.lucene.index.VectorEncoding.FLOAT32, DenseVectorFieldMapper.ElementType.BFLOAT16); private final org.apache.lucene.index.VectorEncoding luceneEncoding; + private final DenseVectorFieldMapper.ElementType elementType; - VectorEncoding(org.apache.lucene.index.VectorEncoding luceneEncoding) { + VectorEncoding(org.apache.lucene.index.VectorEncoding luceneEncoding, DenseVectorFieldMapper.ElementType elementType) { this.luceneEncoding = luceneEncoding; + this.elementType = elementType; } public org.apache.lucene.index.VectorEncoding luceneEncoding() { @@ -120,109 +124,106 @@ enum MergePolicyType { private static String formatIndexPath(TestConfiguration args) { List suffix = new ArrayList<>(); - if (args.indexType() == IndexType.FLAT) { - suffix.add("flat"); - } else if (args.indexType() == IndexType.GPU_HNSW) { - suffix.add("gpu_hnsw"); - } else if (args.indexType() == IndexType.IVF) { - suffix.add("ivf"); - suffix.add(Integer.toString(args.ivfClusterSize())); - suffix.add( - Integer.toString( - args.secondaryClusterSize() == -1 - ? ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER - : args.secondaryClusterSize() - ) - ); - suffix.add(Integer.toString(args.quantizeBits())); - } else { - suffix.add(Integer.toString(args.hnswM())); - suffix.add(Integer.toString(args.hnswEfConstruction())); - if (args.quantizeBits() < 32) { + switch (args.indexType()) { + case FLAT -> suffix.add("flat"); + case GPU_HNSW -> suffix.add("gpu_hnsw"); + case IVF -> { + suffix.add("ivf"); + suffix.add(Integer.toString(args.ivfClusterSize())); + suffix.add( + Integer.toString( + args.secondaryClusterSize() == -1 + ? 
ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER + : args.secondaryClusterSize() + ) + ); suffix.add(Integer.toString(args.quantizeBits())); } + case HNSW -> { + suffix.add(Integer.toString(args.hnswM())); + suffix.add(Integer.toString(args.hnswEfConstruction())); + if (args.quantizeBits() != null) { + suffix.add(Integer.toString(args.quantizeBits())); + } + } } - return INDEX_DIR + "/" + args.docVectors().get(0).getFileName() + "-" + String.join("-", suffix) + ".index"; + + return INDEX_DIR + "/" + args.docVectors().getFirst().getFileName() + "-" + String.join("-", suffix) + ".index"; } static Codec createCodec(TestConfiguration args, @Nullable ExecutorService exec) { final KnnVectorsFormat format; - int quantizeBits = args.quantizeBits(); - DenseVectorFieldMapper.ElementType elementType = switch (args.vectorEncoding()) { - case BYTE -> DenseVectorFieldMapper.ElementType.BYTE; - case FLOAT32 -> DenseVectorFieldMapper.ElementType.FLOAT; - case BFLOAT16 -> DenseVectorFieldMapper.ElementType.BFLOAT16; - }; - if (args.indexType() == IndexType.IVF) { - ESNextDiskBBQVectorsFormat.QuantEncoding encoding = switch (quantizeBits) { - case (1) -> ESNextDiskBBQVectorsFormat.QuantEncoding.ONE_BIT_4BIT_QUERY; - case (2) -> ESNextDiskBBQVectorsFormat.QuantEncoding.TWO_BIT_4BIT_QUERY; - case (4) -> ESNextDiskBBQVectorsFormat.QuantEncoding.FOUR_BIT_SYMMETRIC; + Integer quantizeBits = args.quantizeBits(); + DenseVectorFieldMapper.ElementType elementType = args.vectorEncoding().elementType; + + format = switch (args.indexType()) { + case IVF -> { + ESNextDiskBBQVectorsFormat.QuantEncoding encoding = switch (quantizeBits) { + case 1 -> ESNextDiskBBQVectorsFormat.QuantEncoding.ONE_BIT_4BIT_QUERY; + case 2 -> ESNextDiskBBQVectorsFormat.QuantEncoding.TWO_BIT_4BIT_QUERY; + case 4 -> ESNextDiskBBQVectorsFormat.QuantEncoding.FOUR_BIT_SYMMETRIC; + default -> throw new IllegalArgumentException( + "IVF index type only supports 1, 2 or 4 bits quantization, but got: " + 
quantizeBits + ); + }; + yield new ESNextDiskBBQVectorsFormat( + encoding, + args.ivfClusterSize(), + args.secondaryClusterSize() == -1 + ? ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER + : args.secondaryClusterSize(), + elementType, + args.onDiskRescore(), + exec, + exec != null ? args.numMergeWorkers() : 1, + args.doPrecondition(), + args.preconditioningBlockDims() + ); + } + case GPU_HNSW -> switch (quantizeBits) { + case null -> new ES92GpuHnswVectorsFormat(); + case 7 -> new ES92GpuHnswSQVectorsFormat(); default -> throw new IllegalArgumentException( - "IVF index type only supports 1, 2 or 4 bits quantization, but got: " + quantizeBits + "GPU HNSW index type only supports 7 bits quantization, but got: " + quantizeBits ); }; - format = new ESNextDiskBBQVectorsFormat( - encoding, - args.ivfClusterSize(), - args.secondaryClusterSize() == -1 - ? ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER - : args.secondaryClusterSize(), - elementType, - args.onDiskRescore(), - exec, - exec != null ? args.numMergeWorkers() : 1, - args.doPrecondition(), - args.preconditioningBlockDims() - ); - } else if (args.indexType() == IndexType.GPU_HNSW) { - if (quantizeBits == 32) { - format = new ES92GpuHnswVectorsFormat(); - } else if (quantizeBits == 7) { - format = new ES92GpuHnswSQVectorsFormat(); - } else { - throw new IllegalArgumentException("GPU HNSW index type only supports 7 or 32 bits quantization, but got: " + quantizeBits); - } - } else { - if (quantizeBits == 1) { - if (args.indexType() == IndexType.FLAT) { - format = new ES93BinaryQuantizedVectorsFormat(elementType, false); - } else { - format = new ES93HnswBinaryQuantizedVectorsFormat( - args.hnswM(), - args.hnswEfConstruction(), - elementType, - false, - exec != null ? 
args.numMergeWorkers() : 1, - exec - ); - } - } else if (quantizeBits < 32) { - if (args.indexType() == IndexType.FLAT) { - format = new ES93ScalarQuantizedVectorsFormat(elementType, null, quantizeBits, true, false); - } else { - format = new ES93HnswScalarQuantizedVectorsFormat( - args.hnswM(), - args.hnswEfConstruction(), - elementType, - null, - quantizeBits, - true, - false, - exec != null ? args.numMergeWorkers() : 1, - exec - ); - } - } else { - format = new ES93HnswVectorsFormat( + case HNSW -> switch (quantizeBits) { + case null -> new ES93HnswVectorsFormat( args.hnswM(), args.hnswEfConstruction(), elementType, exec != null ? args.numMergeWorkers() : 1, exec ); - } - } + case 1 -> new ES93HnswBinaryQuantizedVectorsFormat( + args.hnswM(), + args.hnswEfConstruction(), + elementType, + false, + exec != null ? args.numMergeWorkers() : 1, + exec + ); + default -> new ES93HnswScalarQuantizedVectorsFormat( + args.hnswM(), + args.hnswEfConstruction(), + elementType, + null, + quantizeBits, + true, + false, + exec != null ? 
args.numMergeWorkers() : 1, + exec + ); + }; + case FLAT -> switch (quantizeBits) { + case null -> new ES93FlatVectorFormat(elementType); + case 1 -> new ES93BinaryQuantizedVectorsFormat(elementType, false); + default -> new ES93ScalarQuantizedVectorsFormat(elementType, null, quantizeBits, true, false); + }; + }; + + logger.info("Using format {}", format.getName()); + return new Lucene103Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { @@ -336,13 +337,13 @@ public static void main(String[] args) throws Exception { FormattedResults formattedResults = new FormattedResults(); for (TestConfiguration testConfiguration : testConfigurationList) { + // check this here so IVF/GPUHNSW can guarantee quantizeBits is set properly + checkQuantizeBits(testConfiguration); String indexPathName = formatIndexPath(testConfiguration); String indexType = testConfiguration.indexType().name().toLowerCase(Locale.ROOT); Results indexResults = new Results(indexPathName, indexType, testConfiguration.numDocs()); Results[] results = new Results[testConfiguration.numberOfSearchRuns()]; - for (int i = 0; i < results.length; i++) { - results[i] = new Results(indexPathName, indexType, testConfiguration.numDocs()); - } + Arrays.setAll(results, i -> new Results(indexPathName, indexType, testConfiguration.numDocs())); logger.info("Running with Java: " + Runtime.version()); logger.info("Running KNN index tester with arguments: " + testConfiguration); final ExecutorService exec; @@ -409,6 +410,25 @@ public static void main(String[] args) throws Exception { logger.info("Results: \n" + formattedResults); } + private static void checkQuantizeBits(TestConfiguration args) { + switch (args.indexType()) { + case IVF: + if (args.quantizeBits() == null || !Set.of(1, 2, 4).contains(args.quantizeBits())) { + throw new IllegalArgumentException( + "IVF index type only supports 1, 2 or 4 bits quantization, but got: " + args.quantizeBits() + ); + } + break; + case GPU_HNSW: { + 
if (args.quantizeBits() != null && args.quantizeBits() != 7) { + throw new IllegalArgumentException( + "GPU HNSW index type only supports 7 bits quantization, but got: " + args.quantizeBits() + ); + } + } + } + } + private static MergePolicy getMergePolicy(TestConfiguration args) { return switch (args.mergePolicy()) { case null -> null; @@ -419,11 +439,11 @@ private static MergePolicy getMergePolicy(TestConfiguration args) { }; } - static void numSegments(Path indexPath, Results result) { + static void numSegments(Path indexPath, Results result) throws IOException { try (FSDirectory dir = FSDirectory.open(indexPath); IndexReader reader = DirectoryReader.open(dir)) { result.numSegments = reader.leaves().size(); } catch (IOException e) { - throw new UncheckedIOException("Failed to get segment count for index at " + indexPath, e); + throw new IOException("Failed to get segment count for index at " + indexPath, e); } } diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/TestConfiguration.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/TestConfiguration.java index 0d7ed6db21cc4..76519417b6bc5 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/TestConfiguration.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/TestConfiguration.java @@ -44,7 +44,7 @@ record TestConfiguration( boolean reindex, boolean forceMerge, VectorSimilarityFunction vectorSpace, - int quantizeBits, + Integer quantizeBits, KnnIndexTester.VectorEncoding vectorEncoding, int dimensions, KnnIndexTester.MergePolicyType mergePolicy, @@ -128,7 +128,12 @@ static TestConfiguration fromXContent(XContentParser parser) throws IOException PARSER.declareBoolean(Builder::setReindex, REINDEX_FIELD); PARSER.declareBoolean(Builder::setForceMerge, FORCE_MERGE_FIELD); PARSER.declareString(Builder::setVectorSpace, VECTOR_SPACE_FIELD); - PARSER.declareInt(Builder::setQuantizeBits, QUANTIZE_BITS_FIELD); + PARSER.declareField( + Builder::setQuantizeBits, + p -> p.currentToken() 
== XContentParser.Token.VALUE_NULL ? null : p.intValue(), + QUANTIZE_BITS_FIELD, + ObjectParser.ValueType.INT_OR_NULL + ); PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD); PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD); PARSER.declareFieldArray( @@ -261,7 +266,7 @@ static class Builder implements ToXContentObject { private boolean forceMerge = false; private int forceMergeMaxNumSegments = 1; private VectorSimilarityFunction vectorSpace = VectorSimilarityFunction.EUCLIDEAN; - private int quantizeBits = 8; + private Integer quantizeBits = null; private KnnIndexTester.VectorEncoding vectorEncoding = KnnIndexTester.VectorEncoding.FLOAT32; private int dimensions; private List earlyTermination = List.of(Boolean.FALSE); @@ -382,7 +387,7 @@ public Builder setVectorSpace(String vectorSpace) { return this; } - public Builder setQuantizeBits(int quantizeBits) { + public Builder setQuantizeBits(Integer quantizeBits) { this.quantizeBits = quantizeBits; return this; } @@ -563,7 +568,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(REINDEX_FIELD.getPreferredName(), reindex); builder.field(FORCE_MERGE_FIELD.getPreferredName(), forceMerge); builder.field(VECTOR_SPACE_FIELD.getPreferredName(), vectorSpace.name().toLowerCase(Locale.ROOT)); - builder.field(QUANTIZE_BITS_FIELD.getPreferredName(), quantizeBits); + if (quantizeBits != null) { + builder.field(QUANTIZE_BITS_FIELD.getPreferredName(), quantizeBits); + } builder.field(VECTOR_ENCODING_FIELD.getPreferredName(), vectorEncoding.name().toLowerCase(Locale.ROOT)); builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); builder.field(EARLY_TERMINATION_FIELD.getPreferredName(), earlyTermination); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java index 2335fb2d676f7..95805f4e71840 
100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java @@ -22,6 +22,11 @@ import java.io.IOException; import java.util.Map; +/** + * A generic flat format that can use several different underlying vector storage formats. + * <p>
+ * This format is not meant to be used directly; it should be used as part of another vector format implementation. + */ public class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFormat { static final String NAME = "ES93GenericFlatVectorsFormat";