diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/index/codec/tsdb/TSDBDocValuesMergeBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/index/codec/tsdb/TSDBDocValuesMergeBenchmark.java index e19778a783989..99d2acc6607c0 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/index/codec/tsdb/TSDBDocValuesMergeBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/index/codec/tsdb/TSDBDocValuesMergeBenchmark.java @@ -26,7 +26,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.benchmark.Utils; import org.elasticsearch.cluster.metadata.DataStream; -import org.elasticsearch.index.codec.Elasticsearch92Lucene103Codec; +import org.elasticsearch.index.codec.Elasticsearch93Lucene104Codec; import org.elasticsearch.index.codec.tsdb.BinaryDVCompressionMode; import org.elasticsearch.index.codec.tsdb.es819.ES819Version3TSDBDocValuesFormat; import org.openjdk.jmh.annotations.Benchmark; @@ -266,7 +266,7 @@ private static IndexWriterConfig createIndexWriterConfig(boolean optimizedMergeE true, NUMERIC_LARGE_BLOCK_SHIFT ); - config.setCodec(new Elasticsearch92Lucene103Codec() { + config.setCodec(new Elasticsearch93Lucene104Codec() { @Override public DocValuesFormat getDocValuesFormatForField(String field) { return docValuesFormat; diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/BenchmarkUtils.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/BenchmarkUtils.java index 4efb8d6fbb2cc..91cd12e2d87b8 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/BenchmarkUtils.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/BenchmarkUtils.java @@ -9,11 +9,11 @@ package org.elasticsearch.benchmark.vector.scorer; +import org.apache.lucene.backward_codecs.lucene99.OffHeapQuantizedByteVectorValues; import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import 
org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer; -import org.apache.lucene.codecs.lucene99.OffHeapQuantizedByteVectorValues; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; diff --git a/build-tools-internal/version.properties b/build-tools-internal/version.properties index 237904b0444db..8810e7216359c 100644 --- a/build-tools-internal/version.properties +++ b/build-tools-internal/version.properties @@ -1,5 +1,5 @@ elasticsearch = 9.4.0 -lucene = 10.3.2 +lucene = 10.4.0 bundled_jdk_vendor = openjdk bundled_jdk = 25.0.2+10@b1e0dfa218384cb9959bdcb897162d4e diff --git a/docs/Versions.asciidoc b/docs/Versions.asciidoc index 4e28c3eee45fc..738069856228f 100644 --- a/docs/Versions.asciidoc +++ b/docs/Versions.asciidoc @@ -1,8 +1,8 @@ include::{docs-root}/shared/versions/stack/{source_branch}.asciidoc[] -:lucene_version: 10.3.2 -:lucene_version_path: 10_3_2 +:lucene_version: 10.4.0 +:lucene_version_path: 10_4_0 :jdk: 11.0.2 :jdk_major: 11 :build_type: tar diff --git a/docs/changelog/141074.yaml b/docs/changelog/141074.yaml new file mode 100644 index 0000000000000..a99427a5d51d7 --- /dev/null +++ b/docs/changelog/141074.yaml @@ -0,0 +1,5 @@ +pr: 141074 +summary: Add flat_index_threshold parameter for hnsw dense_vector fields +area: Vector Search +type: enhancement +issues: [] diff --git a/docs/changelog/141882.yaml b/docs/changelog/141882.yaml new file mode 100644 index 0000000000000..ef9ff11c227ec --- /dev/null +++ b/docs/changelog/141882.yaml @@ -0,0 +1,5 @@ +area: Search +issues: [] +pr: 141882 +summary: Upgrade Elasticsearch to Apache Lucene 10.4 +type: upgrade diff --git a/docs/reference/elasticsearch/mapping-reference/dense-vector.md b/docs/reference/elasticsearch/mapping-reference/dense-vector.md index 
8b49013cd6c61..0ad6eebb5da29 100644 --- a/docs/reference/elasticsearch/mapping-reference/dense-vector.md +++ b/docs/reference/elasticsearch/mapping-reference/dense-vector.md @@ -344,7 +344,7 @@ This configuration is appropriate when full source fidelity is required, such as ## Automatically quantize vectors for kNN search [dense-vector-quantization] The `dense_vector` field type supports quantization to reduce the memory footprint required when [searching](docs-content://solutions/search/vector/knn.md#approximate-knn) `float` vectors. The supported vector quantization strategies for `dense_vector` kNN indexing are: -- [`int8`](#dense-vector-quantization-int8) +- [`int8`](#dense-vector-quantization-int8) - [`int4`](#dense-vector-quantization-int4) - [`bbq`](#dense-vector-quantization-bbq), available as: - [`bbq_hnsw`](/reference/elasticsearch/mapping-reference/bbq.md#bbq-hnsw) @@ -573,9 +573,6 @@ $$$dense-vector-index-options$$$ `ef_construction` : (Optional, integer) The number of candidates to track while assembling the list of nearest neighbors for each new node. Defaults to `100`. Only applicable to `hnsw`, `int8_hnsw`, `int4_hnsw` and `bbq_hnsw` index types. -`confidence_interval` -: (Optional, float) Only applicable to `int8_hnsw`, `int4_hnsw`, `int8_flat`, and `int4_flat` index types. The confidence interval to use when quantizing the vectors. Can be any value between and including `0.90` and `1.0` or exactly `0`. When the value is `0`, this indicates that dynamic quantiles should be calculated for optimized quantization. When between `0.90` and `1.0`, this value restricts the values used when calculating the quantization thresholds. For example, a value of `0.95` will only use the middle 95% of the values when calculating the quantization thresholds (e.g. the highest and lowest 2.5% of values will be ignored). Defaults to `1/(dims + 1)` for `int8` quantized vectors and `0` for `int4` for dynamic quantile calculation. 
- `default_visit_percentage` {applies_to}`stack: ga 9.2` : (Optional, integer) Only applicable to `bbq_disk`. Must be between 0 and 100. 0 will default to using `num_candidates` for calculating the percent visited. Increasing `default_visit_percentage` tends to improve the accuracy of the final results. Defaults to ~1% per shard for every 1 million vectors. @@ -734,7 +731,7 @@ flat --> int8_flat --> int4_flat --> hnsw --> int8_hnsw --> int4_hnsw ::: :::: -For updating all HNSW types (`hnsw`, `int8_hnsw`, `int4_hnsw`, `bbq_hnsw`) the number of connections `m` must either stay the same or increase. For the scalar quantized formats `int8_flat`, `int4_flat`, `int8_hnsw` and `int4_hnsw` the `confidence_interval` must always be consistent (once defined, it cannot change). +For updating all HNSW types (`hnsw`, `int8_hnsw`, `int4_hnsw`, `bbq_hnsw`) the number of connections `m` must either stay the same or increase. Updating `type` in `index_options` will fail in all other scenarios. diff --git a/docs/reference/elasticsearch/mapping-reference/semantic-text-setup-configuration.md b/docs/reference/elasticsearch/mapping-reference/semantic-text-setup-configuration.md index ae5ae2ef5c928..421fd487fc4aa 100644 --- a/docs/reference/elasticsearch/mapping-reference/semantic-text-setup-configuration.md +++ b/docs/reference/elasticsearch/mapping-reference/semantic-text-setup-configuration.md @@ -13,14 +13,14 @@ This page provides instructions for setting up and configuring `semantic_text` f ## Configure {{infer}} endpoints [configure-inference-endpoints] -You can configure {{infer}} endpoints for `semantic_text` fields in the following ways: +You can configure {{infer}} endpoints for `semantic_text` fields in the following ways: - [Use ELSER on EIS](#using-elser-on-eis) - [Use default and preconfigured endpoints](#default-and-preconfigured-endpoints) - [Use a custom {{infer}} endpoint](#using-custom-endpoint) :::{note} -If you use a [custom {{infer}} endpoint](#using-custom-endpoint) 
through your ML node and not through Elastic {{infer-cap}} Service (EIS), the recommended method is to [use dedicated endpoints for ingestion and search](#dedicated-endpoints-for-ingestion-and-search). +If you use a [custom {{infer}} endpoint](#using-custom-endpoint) through your ML node and not through Elastic {{infer-cap}} Service (EIS), the recommended method is to [use dedicated endpoints for ingestion and search](#dedicated-endpoints-for-ingestion-and-search). {applies_to}`stack: ga 9.1.0` If you use EIS, you don't have to set up dedicated endpoints. ::: @@ -195,7 +195,7 @@ PUT my-index-000002 ### Use dedicated endpoints for ingestion and search [dedicated-endpoints-for-ingestion-and-search] -If you use a [custom {{infer}} endpoint](#using-custom-endpoint) through your ML node and not through Elastic {{infer-cap}} Service, the recommended way to use `semantic_text` is by having dedicated {{infer}} endpoints for ingestion and search. +If you use a [custom {{infer}} endpoint](#using-custom-endpoint) through your ML node and not through Elastic {{infer-cap}} Service, the recommended way to use `semantic_text` is by having dedicated {{infer}} endpoints for ingestion and search. This ensures that search speed remains unaffected by ingestion workloads, and vice versa. After creating dedicated {{infer}} endpoints for both, you can reference them using the `inference_id` and `search_inference_id` parameters when setting up the index mapping for an index that uses the `semantic_text` field. @@ -232,7 +232,7 @@ PUT semantic-embeddings "mappings": { "properties": { "content": { - "type": "semantic_text", + "type": "semantic_text", "index_options": { "sparse_vector": { "prune": true, <1> @@ -271,8 +271,7 @@ PUT semantic-embeddings "dense_vector": { "type": "int8_hnsw", <1> "m": 15, <2> - "ef_construction": 90, <3> - "confidence_interval": 0.95 <4> + "ef_construction": 90 <3> } } } @@ -283,5 +282,4 @@ PUT semantic-embeddings 1. 
(Optional) Selects the `int8_hnsw` vector quantization strategy. Learn about [default quantization types](/reference/elasticsearch/mapping-reference/dense-vector.md#default-quantization-types). 2. (Optional) Sets `m` to 15 to control how many neighbors each node connects to in the HNSW graph. Default is `16`. 3. (Optional) Sets `ef_construction` to 90 to control how many candidate neighbors are considered during graph construction. Default is `100`. -4. (Optional) Sets `confidence_interval` to 0.95 to limit the value range used during quantization and balance accuracy with memory efficiency. diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml index 60ae3be25a219..9f2ae4d1f0c89 100644 --- a/gradle/verification-metadata.xml +++ b/gradle/verification-metadata.xml @@ -3705,129 +3705,129 @@ - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + diff --git a/libs/gpu-codec/src/main/java/org/elasticsearch/gpu/codec/ES92GpuHnswVectorsWriter.java b/libs/gpu-codec/src/main/java/org/elasticsearch/gpu/codec/ES92GpuHnswVectorsWriter.java index 3e930024faae4..9626a6e463b93 100644 --- a/libs/gpu-codec/src/main/java/org/elasticsearch/gpu/codec/ES92GpuHnswVectorsWriter.java +++ b/libs/gpu-codec/src/main/java/org/elasticsearch/gpu/codec/ES92GpuHnswVectorsWriter.java @@ -54,7 +54,6 @@ import java.util.Objects; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS; -import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter.mergeAndRecalculateQuantiles; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.elasticsearch.gpu.codec.ES92GpuHnswVectorsFormat.LUCENE99_HNSW_META_CODEC_NAME; import static 
org.elasticsearch.gpu.codec.ES92GpuHnswVectorsFormat.LUCENE99_HNSW_META_EXTENSION; @@ -64,6 +63,7 @@ import static org.elasticsearch.gpu.codec.ES92GpuHnswVectorsFormat.MIN_NUM_VECTORS_FOR_GPU_BUILD; import static org.elasticsearch.gpu.codec.MemorySegmentUtils.getContiguousMemorySegment; import static org.elasticsearch.gpu.codec.MemorySegmentUtils.getContiguousPackedMemorySegment; +import static org.elasticsearch.index.codec.vectors.Lucene99ScalarQuantizedVectorsWriter.mergeAndRecalculateQuantiles; /** * Writer that builds an Nvidia Carga Graph on GPU and then writes it into the Lucene99 HNSW format, @@ -530,7 +530,7 @@ public int neighborCount() { @Override public NodesIterator getNodesOnLevel(int level) { - return new ArrayNodesIterator(size()); + return new DenseNodesIterator(size()); } }; } diff --git a/libs/gpu-codec/src/test/java/org/elasticsearch/gpu/codec/ES92GpuHnswSQVectorsFormatTests.java b/libs/gpu-codec/src/test/java/org/elasticsearch/gpu/codec/ES92GpuHnswSQVectorsFormatTests.java index 947074b870181..a0f66f69ce13d 100644 --- a/libs/gpu-codec/src/test/java/org/elasticsearch/gpu/codec/ES92GpuHnswSQVectorsFormatTests.java +++ b/libs/gpu-codec/src/test/java/org/elasticsearch/gpu/codec/ES92GpuHnswSQVectorsFormatTests.java @@ -94,4 +94,9 @@ public void testMergingWithDifferentByteKnnFields() { public void testMismatchedFields() { // No bytes support } + + @Override + protected boolean supportsFloatVectorFallback() { + return false; + } } diff --git a/libs/gpu-codec/src/test/java/org/elasticsearch/gpu/codec/ES92GpuHnswVectorsFormatTests.java b/libs/gpu-codec/src/test/java/org/elasticsearch/gpu/codec/ES92GpuHnswVectorsFormatTests.java index f91bba076158c..d21d99944a768 100644 --- a/libs/gpu-codec/src/test/java/org/elasticsearch/gpu/codec/ES92GpuHnswVectorsFormatTests.java +++ b/libs/gpu-codec/src/test/java/org/elasticsearch/gpu/codec/ES92GpuHnswVectorsFormatTests.java @@ -87,4 +87,8 @@ public void testMismatchedFields() throws Exception { // No bytes support } 
+ @Override + protected boolean supportsFloatVectorFallback() { + return false; + } } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java index 1a4b6d6ed4ebf..68b52874c9cb1 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java @@ -234,7 +234,8 @@ public static float calculateOSQLoss( float lambda, int[] quantize ) { - assert upperInterval >= lowerInterval; + assert upperInterval >= lowerInterval + : "upperInterval must be greater than or equal to lowerInterval, but was: " + upperInterval + " < " + lowerInterval; float step = ((upperInterval - lowerInterval) / (points - 1.0F)); float invStep = 1f / step; return IMPL.calculateOSQLoss(target, lowerInterval, upperInterval, step, invStep, norm2, lambda, quantize); diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java index 3ac29daf3889c..1a848eb659084 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java @@ -111,4 +111,37 @@ Optional getInt7SQVectorScorerSupplier( * @return an optional containing the vector scorer, or empty */ Optional getInt7SQVectorScorer(VectorSimilarityFunction sim, QuantizedByteVectorValues values, float[] queryVector); + + /** + * Returns an optional containing an int7 optimal scalar quantized vector score supplier + * for the given parameters, or an empty optional if a scorer is not supported. 
+ * + * @param similarityType the similarity type + * @param input the index input containing the vector data + * @param values the random access vector values + * @return an optional containing the vector scorer supplier, or empty + */ + Optional getInt7uOSQVectorScorerSupplier( + VectorSimilarityType similarityType, + IndexInput input, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values + ); + + /** + * Returns an optional containing an int7 optimal scalar quantized vector scorer for + * the given parameters, or an empty optional if a scorer is not supported. + * + * @param sim the similarity type + * @param values the random access vector values + * @return an optional containing the vector scorer, or empty + */ + Optional getInt7uOSQVectorScorer( + VectorSimilarityFunction sim, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values, + byte[] quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ); } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java index c285acf8f9c2b..4a237c8ffbae5 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java @@ -75,4 +75,26 @@ public Optional getInt7SQVectorScorer( ) { throw new UnsupportedOperationException("should not reach here"); } + + @Override + public Optional getInt7uOSQVectorScorerSupplier( + VectorSimilarityType similarityType, + IndexInput input, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values + ) { + throw new UnsupportedOperationException("should not reach here"); + } + + @Override + public Optional getInt7uOSQVectorScorer( + VectorSimilarityFunction sim, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values, + byte[] 
quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + throw new UnsupportedOperationException("should not reach here"); + } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java index 82b1a1659871b..6450811ff82d1 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java @@ -25,6 +25,8 @@ import org.elasticsearch.simdvec.internal.FloatVectorScorerSupplier; import org.elasticsearch.simdvec.internal.Int7SQVectorScorer; import org.elasticsearch.simdvec.internal.Int7SQVectorScorerSupplier; +import org.elasticsearch.simdvec.internal.Int7uOSQVectorScorer; +import org.elasticsearch.simdvec.internal.Int7uOSQVectorScorerSupplier; import java.util.Optional; @@ -120,6 +122,46 @@ public Optional getInt7SQVectorScorer( return Int7SQVectorScorer.create(sim, values, queryVector); } + @Override + public Optional getInt7uOSQVectorScorerSupplier( + VectorSimilarityType similarityType, + IndexInput input, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values + ) { + input = FilterIndexInput.unwrapOnlyTest(input); + input = MemorySegmentAccessInputAccess.unwrap(input); + if (input instanceof MemorySegmentAccessInput msInput) { + checkInvariants(values.size(), values.dimension(), input); + return switch (similarityType) { + case COSINE, DOT_PRODUCT -> Optional.of(new Int7uOSQVectorScorerSupplier.DotProductSupplier(msInput, values)); + case EUCLIDEAN -> Optional.of(new Int7uOSQVectorScorerSupplier.EuclideanSupplier(msInput, values)); + case MAXIMUM_INNER_PRODUCT -> Optional.of(new Int7uOSQVectorScorerSupplier.MaxInnerProductSupplier(msInput, values)); + }; + } + return Optional.empty(); + } + + @Override + public Optional 
getInt7uOSQVectorScorer( + VectorSimilarityFunction sim, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values, + byte[] quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + return Int7uOSQVectorScorer.create( + sim, + values, + quantizedQuery, + lowerInterval, + upperInterval, + additionalCorrection, + quantizedComponentSum + ); + } + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { if (input.length() < (long) vectorByteLength * maxOrd) { throw new IllegalArgumentException("input length is less than expected vector data"); diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/ByteVectorScorerSupplier.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/ByteVectorScorerSupplier.java index e417b33042aed..ad6792d4a4292 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/ByteVectorScorerSupplier.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/ByteVectorScorerSupplier.java @@ -50,41 +50,51 @@ protected final void checkOrdinal(int ord) { } } - final void bulkScoreFromOrds(int firstOrd, int[] ordinals, float[] scores, int numNodes) throws IOException { + final float bulkScoreFromOrds(int firstOrd, int[] ordinals, float[] scores, int numNodes) throws IOException { MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length()); if (vectorsSeg == null) { // we might be able to get segments for individual vectors, so try separately - scoreSeparately(firstOrd, ordinals, scores, numNodes); + return scoreSeparately(firstOrd, ordinals, scores, numNodes); } else { final int vectorPitch = dims; if (SUPPORTS_HEAP_SEGMENTS) { var ordinalsSeg = MemorySegment.ofArray(ordinals); var scoresSeg = MemorySegment.ofArray(scores); - bulkScoreFromSegment(vectorsSeg, dims, vectorPitch, firstOrd, ordinalsSeg, scoresSeg, numNodes); + return bulkScoreFromSegment(vectorsSeg, 
dims, vectorPitch, firstOrd, ordinalsSeg, scoresSeg, numNodes); } else { try (var arena = Arena.ofConfined()) { var ordinalsMemorySegment = arena.allocate((long) numNodes * Integer.BYTES, 32); var scoresMemorySegment = arena.allocate((long) numNodes * Float.BYTES, 32); MemorySegment.copy(ordinals, 0, ordinalsMemorySegment, ValueLayout.JAVA_INT, 0, numNodes); - bulkScoreFromSegment(vectorsSeg, dims, vectorPitch, firstOrd, ordinalsMemorySegment, scoresMemorySegment, numNodes); + float max = bulkScoreFromSegment( + vectorsSeg, + dims, + vectorPitch, + firstOrd, + ordinalsMemorySegment, + scoresMemorySegment, + numNodes + ); MemorySegment.copy(scoresMemorySegment, ValueLayout.JAVA_FLOAT, 0, scores, 0, numNodes); + return max; } } } } - private void scoreSeparately(int firstOrd, int[] ordinals, float[] scores, int numNodes) throws IOException { + private float scoreSeparately(int firstOrd, int[] ordinals, float[] scores, int numNodes) throws IOException { long firstByteOffset = (long) firstOrd * dims; byte[] firstVector = null; - + float max = Float.NEGATIVE_INFINITY; MemorySegment firstSeg = input.segmentSliceOrNull(firstByteOffset, dims); if (firstSeg == null) { firstVector = values.vectorValue(firstOrd).clone(); for (int i = 0; i < numNodes; i++) { scores[i] = fallbackScorer.compare(firstVector, values.vectorValue(ordinals[i])); + max = Math.max(max, scores[i]); } } else { for (int i = 0; i < numNodes; i++) { @@ -98,8 +108,10 @@ private void scoreSeparately(int firstOrd, int[] ordinals, float[] scores, int n } else { scores[i] = scoreFromSegments(firstSeg, secondSeg); } + max = Math.max(max, scores[i]); } } + return max; } final float scoreFromOrds(int firstOrd, int secondOrd) throws IOException { @@ -121,7 +133,7 @@ final float scoreFromOrds(int firstOrd, int secondOrd) throws IOException { abstract float scoreFromSegments(MemorySegment a, MemorySegment b); - abstract void bulkScoreFromSegment( + abstract float bulkScoreFromSegment( MemorySegment vectors, int 
vectorLength, int vectorPitch, @@ -149,8 +161,8 @@ public float score(int node) throws IOException { } @Override - public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { - bulkScoreFromOrds(ord, nodes, scores, numNodes); + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + return bulkScoreFromOrds(ord, nodes, scores, numNodes); } @Override @@ -177,7 +189,7 @@ float scoreFromSegments(MemorySegment a, MemorySegment b) { } @Override - void bulkScoreFromSegment( + float bulkScoreFromSegment( MemorySegment vectors, int vectorLength, int vectorPitch, @@ -190,10 +202,14 @@ void bulkScoreFromSegment( var firstVector = vectors.asSlice(firstByteOffset, vectorPitch); Similarities.cosineI8BulkWithOffsets(vectors, firstVector, dims, vectorPitch, ordinals, numNodes, scores); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { float squareDistance = scores.getAtIndex(ValueLayout.JAVA_FLOAT, i); - scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalize(squareDistance)); + float normalized = normalize(squareDistance); + scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalized); + max = Math.max(max, normalized); } + return max; } @Override @@ -214,7 +230,7 @@ float scoreFromSegments(MemorySegment a, MemorySegment b) { } @Override - void bulkScoreFromSegment( + float bulkScoreFromSegment( MemorySegment vectors, int vectorLength, int vectorPitch, @@ -227,10 +243,14 @@ void bulkScoreFromSegment( var firstVector = vectors.asSlice(firstByteOffset, vectorPitch); Similarities.squareDistanceI8BulkWithOffsets(vectors, firstVector, dims, vectorPitch, ordinals, numNodes, scores); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { float squareDistance = scores.getAtIndex(ValueLayout.JAVA_FLOAT, i); - scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, VectorUtil.normalizeDistanceToUnitInterval(squareDistance)); + float normalized = 
VectorUtil.normalizeDistanceToUnitInterval(squareDistance); + scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalized); + max = Math.max(max, normalized); } + return max; } @Override @@ -257,7 +277,7 @@ float scoreFromSegments(MemorySegment a, MemorySegment b) { } @Override - void bulkScoreFromSegment( + float bulkScoreFromSegment( MemorySegment vectors, int vectorLength, int vectorPitch, @@ -270,10 +290,14 @@ void bulkScoreFromSegment( var firstVector = vectors.asSlice(firstByteOffset, vectorPitch); Similarities.dotProductI8BulkWithOffsets(vectors, firstVector, dims, vectorPitch, ordinals, numNodes, scores); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { float dotProduct = scores.getAtIndex(ValueLayout.JAVA_FLOAT, i); - scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalize(dotProduct)); + float normalized = normalize(dotProduct); + scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalized); + max = Math.max(max, normalized); } + return max; } @Override @@ -294,7 +318,7 @@ float scoreFromSegments(MemorySegment a, MemorySegment b) { } @Override - void bulkScoreFromSegment( + float bulkScoreFromSegment( MemorySegment vectors, int vectorLength, int vectorPitch, @@ -307,10 +331,14 @@ void bulkScoreFromSegment( var firstVector = vectors.asSlice(firstByteOffset, vectorPitch); Similarities.dotProductI8BulkWithOffsets(vectors, firstVector, dims, vectorPitch, ordinals, numNodes, scores); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { float dotProduct = scores.getAtIndex(ValueLayout.JAVA_FLOAT, i); - scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, VectorUtil.scaleMaxInnerProductScore(dotProduct)); + float normalized = VectorUtil.scaleMaxInnerProductScore(dotProduct); + scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalized); + max = Math.max(max, normalized); } + return max; } @Override diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/FloatVectorScorerSupplier.java 
b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/FloatVectorScorerSupplier.java index d4afe960b606c..82293cd0d92f9 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/FloatVectorScorerSupplier.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/FloatVectorScorerSupplier.java @@ -49,25 +49,25 @@ protected final void checkOrdinal(int ord) { } } - final void bulkScoreFromOrds(int firstOrd, int[] ordinals, float[] scores, int numNodes) throws IOException { + final float bulkScoreFromOrds(int firstOrd, int[] ordinals, float[] scores, int numNodes) throws IOException { MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length()); if (vectorsSeg == null) { // we might be able to get segments for individual vectors, so try separately - scoreSeparately(firstOrd, ordinals, scores, numNodes); + return scoreSeparately(firstOrd, ordinals, scores, numNodes); } else { final int vectorLength = dims * Float.BYTES; final int vectorPitch = vectorLength; if (SUPPORTS_HEAP_SEGMENTS) { var ordinalsSeg = MemorySegment.ofArray(ordinals); var scoresSeg = MemorySegment.ofArray(scores); - bulkScoreFromSegment(vectorsSeg, vectorLength, vectorPitch, firstOrd, ordinalsSeg, scoresSeg, numNodes); + return bulkScoreFromSegment(vectorsSeg, vectorLength, vectorPitch, firstOrd, ordinalsSeg, scoresSeg, numNodes); } else { try (var arena = Arena.ofConfined()) { var ordinalsMemorySegment = arena.allocate((long) numNodes * Integer.BYTES, 32); var scoresMemorySegment = arena.allocate((long) numNodes * Float.BYTES, 32); MemorySegment.copy(ordinals, 0, ordinalsMemorySegment, ValueLayout.JAVA_INT, 0, numNodes); - bulkScoreFromSegment( + float max = bulkScoreFromSegment( vectorsSeg, vectorLength, vectorPitch, @@ -78,21 +78,24 @@ final void bulkScoreFromOrds(int firstOrd, int[] ordinals, float[] scores, int n ); MemorySegment.copy(scoresMemorySegment, ValueLayout.JAVA_FLOAT, 0, scores, 0, numNodes); + return max; } } } } - private void 
scoreSeparately(int firstOrd, int[] ordinals, float[] scores, int numNodes) throws IOException { + private float scoreSeparately(int firstOrd, int[] ordinals, float[] scores, int numNodes) throws IOException { final int length = dims * Float.BYTES; long firstByteOffset = (long) firstOrd * length; float[] firstVector = null; + float max = Float.NEGATIVE_INFINITY; MemorySegment firstSeg = input.segmentSliceOrNull(firstByteOffset, length); if (firstSeg == null) { firstVector = values.vectorValue(firstOrd).clone(); for (int i = 0; i < numNodes; i++) { scores[i] = fallbackScorer.compare(firstVector, values.vectorValue(ordinals[i])); + max = Math.max(max, scores[i]); } } else { for (int i = 0; i < numNodes; i++) { @@ -106,8 +109,10 @@ private void scoreSeparately(int firstOrd, int[] ordinals, float[] scores, int n } else { scores[i] = scoreFromSegments(firstSeg, secondSeg); } + max = Math.max(max, scores[i]); } } + return max; } final float scoreFromOrds(int firstOrd, int secondOrd) throws IOException { @@ -130,7 +135,7 @@ final float scoreFromOrds(int firstOrd, int secondOrd) throws IOException { abstract float scoreFromSegments(MemorySegment a, MemorySegment b); - abstract void bulkScoreFromSegment( + abstract float bulkScoreFromSegment( MemorySegment vectors, int vectorLength, int vectorPitch, @@ -158,8 +163,8 @@ public float score(int node) throws IOException { } @Override - public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { - bulkScoreFromOrds(ord, nodes, scores, numNodes); + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + return bulkScoreFromOrds(ord, nodes, scores, numNodes); } @Override @@ -182,7 +187,7 @@ float scoreFromSegments(MemorySegment a, MemorySegment b) { } @Override - protected void bulkScoreFromSegment( + protected float bulkScoreFromSegment( MemorySegment vectors, int vectorLength, int vectorPitch, @@ -195,10 +200,14 @@ protected void bulkScoreFromSegment( var firstVector = 
vectors.asSlice(firstByteOffset, vectorPitch); Similarities.squareDistanceF32BulkWithOffsets(vectors, firstVector, dims, vectorPitch, ordinals, numNodes, scores); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { float squareDistance = scores.getAtIndex(ValueLayout.JAVA_FLOAT, i); - scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, VectorUtil.normalizeDistanceToUnitInterval(squareDistance)); + float normalizedScore = VectorUtil.normalizeDistanceToUnitInterval(squareDistance); + scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalizedScore); + max = Math.max(max, normalizedScore); } + return max; } @Override @@ -219,7 +228,7 @@ float scoreFromSegments(MemorySegment a, MemorySegment b) { } @Override - protected void bulkScoreFromSegment( + protected float bulkScoreFromSegment( MemorySegment vectors, int vectorLength, int vectorPitch, @@ -232,10 +241,14 @@ protected void bulkScoreFromSegment( var firstVector = vectors.asSlice(firstByteOffset, vectorPitch); Similarities.dotProductF32BulkWithOffsets(vectors, firstVector, dims, vectorPitch, ordinals, numNodes, scores); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { float dotProduct = scores.getAtIndex(ValueLayout.JAVA_FLOAT, i); - scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, VectorUtil.normalizeToUnitInterval(dotProduct)); + float normalizedScore = VectorUtil.normalizeToUnitInterval(dotProduct); + scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalizedScore); + max = Math.max(max, normalizedScore); } + return max; } @Override @@ -256,7 +269,7 @@ float scoreFromSegments(MemorySegment a, MemorySegment b) { } @Override - protected void bulkScoreFromSegment( + protected float bulkScoreFromSegment( MemorySegment vectors, int vectorLength, int vectorPitch, @@ -269,10 +282,14 @@ protected void bulkScoreFromSegment( var firstVector = vectors.asSlice(firstByteOffset, vectorPitch); Similarities.dotProductF32BulkWithOffsets(vectors, firstVector, dims, vectorPitch, ordinals, numNodes, 
scores); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { float dotProduct = scores.getAtIndex(ValueLayout.JAVA_FLOAT, i); - scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, VectorUtil.scaleMaxInnerProductScore(dotProduct)); + float scaledScore = VectorUtil.scaleMaxInnerProductScore(dotProduct); + scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, scaledScore); + max = Math.max(max, scaledScore); } + return max; } @Override diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java index 32c75677e06b4..5c5f3e9810284 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java @@ -59,26 +59,29 @@ protected final void checkOrdinal(int ord) { } } - final void bulkScoreFromOrds(int firstOrd, int[] ordinals, float[] scores, int numNodes) throws IOException { + final float bulkScoreFromOrds(int firstOrd, int[] ordinals, float[] scores, int numNodes) throws IOException { MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length()); if (vectorsSeg == null) { + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; i++) { scores[i] = scoreFromOrds(firstOrd, ordinals[i]); + max = Math.max(max, scores[i]); } + return max; } else { final int vectorLength = dims; final int vectorPitch = vectorLength + Float.BYTES; if (SUPPORTS_HEAP_SEGMENTS) { var ordinalsSeg = MemorySegment.ofArray(ordinals); var scoresSeg = MemorySegment.ofArray(scores); - bulkScoreFromSegment(vectorsSeg, vectorLength, vectorPitch, firstOrd, ordinalsSeg, scoresSeg, numNodes); + return bulkScoreFromSegment(vectorsSeg, vectorLength, vectorPitch, firstOrd, ordinalsSeg, scoresSeg, numNodes); } else { try (var arena = Arena.ofConfined()) { var ordinalsMemorySegment = 
arena.allocate((long) numNodes * Integer.BYTES, 32); var scoresMemorySegment = arena.allocate((long) numNodes * Float.BYTES, 32); MemorySegment.copy(ordinals, 0, ordinalsMemorySegment, ValueLayout.JAVA_INT, 0, numNodes); - bulkScoreFromSegment( + float max = bulkScoreFromSegment( vectorsSeg, vectorLength, vectorPitch, @@ -89,6 +92,7 @@ final void bulkScoreFromOrds(int firstOrd, int[] ordinals, float[] scores, int n ); MemorySegment.copy(scoresMemorySegment, ValueLayout.JAVA_FLOAT, 0, scores, 0, numNodes); + return max; } } } @@ -116,7 +120,7 @@ final float scoreFromOrds(int firstOrd, int secondOrd) throws IOException { abstract float scoreFromSegments(MemorySegment a, float aOffset, MemorySegment b, float bOffset); - protected abstract void bulkScoreFromSegment( + protected abstract float bulkScoreFromSegment( MemorySegment vectors, int vectorLength, int vectorPitch, @@ -150,8 +154,8 @@ public float score(int node) throws IOException { } @Override - public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { - bulkScoreFromOrds(ord, nodes, scores, numNodes); + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + return bulkScoreFromOrds(ord, nodes, scores, numNodes); } @Override @@ -181,7 +185,7 @@ float scoreFromSegments(MemorySegment a, float aOffset, MemorySegment b, float b } @Override - protected void bulkScoreFromSegment( + protected float bulkScoreFromSegment( MemorySegment vectors, int vectorLength, int vectorPitch, @@ -194,11 +198,15 @@ protected void bulkScoreFromSegment( var firstVector = vectors.asSlice(firstByteOffset, vectorPitch); Similarities.squareDistanceI7uBulkWithOffsets(vectors, firstVector, dims, vectorPitch, ordinals, numNodes, scores); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { var squareDistance = scores.getAtIndex(ValueLayout.JAVA_FLOAT, i); float adjustedDistance = squareDistance * scoreCorrectionConstant; - 
scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, 1 / (1f + adjustedDistance)); + float adjustedScore = 1 / (1f + adjustedDistance); + scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, adjustedScore); + max = Math.max(max, adjustedScore); } + return max; } @Override @@ -222,7 +230,7 @@ float scoreFromSegments(MemorySegment a, float aOffset, MemorySegment b, float b } @Override - protected void bulkScoreFromSegment( + protected float bulkScoreFromSegment( MemorySegment vectors, int vectorLength, int vectorPitch, @@ -239,6 +247,7 @@ protected void bulkScoreFromSegment( var aOffset = Float.intBitsToFloat( vectors.asSlice(firstByteOffset + vectorLength, Float.BYTES).getAtIndex(ValueLayout.JAVA_INT_UNALIGNED, 0) ); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { var dotProduct = scores.getAtIndex(ValueLayout.JAVA_FLOAT, i); var secondOrd = ordinals.getAtIndex(ValueLayout.JAVA_INT, i); @@ -247,8 +256,11 @@ protected void bulkScoreFromSegment( vectors.asSlice(secondByteOffset + vectorLength, Float.BYTES).getAtIndex(ValueLayout.JAVA_INT_UNALIGNED, 0) ); float adjustedDistance = dotProduct * scoreCorrectionConstant + aOffset + bOffset; - scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, Math.max((1 + adjustedDistance) / 2, 0f)); + float adjustedScore = Math.max((1 + adjustedDistance) / 2, 0f); + scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, adjustedScore); + max = Math.max(max, adjustedScore); } + return max; } @Override @@ -275,7 +287,7 @@ float scoreFromSegments(MemorySegment a, float aOffset, MemorySegment b, float b } @Override - protected void bulkScoreFromSegment( + protected float bulkScoreFromSegment( MemorySegment vectors, int vectorLength, int vectorPitch, @@ -292,6 +304,7 @@ protected void bulkScoreFromSegment( var aOffset = Float.intBitsToFloat( vectors.asSlice(firstByteOffset + vectorLength, Float.BYTES).getAtIndex(ValueLayout.JAVA_INT_UNALIGNED, 0) ); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { var dotProduct = 
scores.getAtIndex(ValueLayout.JAVA_FLOAT, i); var secondOrd = ordinals.getAtIndex(ValueLayout.JAVA_INT, i); @@ -302,7 +315,9 @@ protected void bulkScoreFromSegment( float adjustedDistance = dotProduct * scoreCorrectionConstant + aOffset + bOffset; adjustedDistance = adjustedDistance < 0 ? 1 / (1 + -1 * adjustedDistance) : adjustedDistance + 1; scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, adjustedDistance); + max = Math.max(max, adjustedDistance); } + return max; } @Override diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorer.java new file mode 100644 index 0000000000000..17cc0c808ae21 --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorer.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal; + +import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +import java.util.Optional; + +/** + * Outlines the Int7 OSQ query-time scorer. The concrete implementation will + * connect to the native OSQ routines and apply the similarity-specific + * corrections. 
+ */ +public final class Int7uOSQVectorScorer { + + public static Optional create( + VectorSimilarityFunction sim, + QuantizedByteVectorValues values, + byte[] quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + // TODO add JDK21 fallback logic and native scorer dispatch + return Optional.empty(); + } + + private Int7uOSQVectorScorer() {} +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorerSupplier.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorerSupplier.java new file mode 100644 index 0000000000000..194a3d88bcf71 --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorerSupplier.java @@ -0,0 +1,334 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.simdvec.internal; + +import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; + +import static org.elasticsearch.simdvec.internal.Int7SQVectorScorerSupplier.SUPPORTS_HEAP_SEGMENTS; +import static org.elasticsearch.simdvec.internal.Similarities.dotProductI7u; +import static org.elasticsearch.simdvec.internal.Similarities.dotProductI7uBulkWithOffsets; + +/** + * Int7 OSQ scorer supplier backed by {@link MemorySegmentAccessInput} storage. + */ +public abstract sealed class Int7uOSQVectorScorerSupplier implements RandomVectorScorerSupplier permits + Int7uOSQVectorScorerSupplier.DotProductSupplier, Int7uOSQVectorScorerSupplier.EuclideanSupplier, + Int7uOSQVectorScorerSupplier.MaxInnerProductSupplier { + + private static final float LIMIT_SCALE = 1f / ((1 << 7) - 1); + + protected final MemorySegmentAccessInput input; + protected final QuantizedByteVectorValues values; + protected final int dims; + protected final int maxOrd; + + Int7uOSQVectorScorerSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values) { + this.input = input; + this.values = values; + this.dims = values.dimension(); + this.maxOrd = values.size(); + } + + protected abstract float applyCorrections(float rawScore, int ord, QueryContext query) throws IOException; + + protected abstract float applyCorrectionsBulk(MemorySegment scores, MemorySegment ordinals, int numNodes, QueryContext query) + throws IOException; + + protected record QueryContext( + int ord, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) {} + + 
protected QueryContext createQueryContext(int ord) throws IOException { + var correctiveTerms = values.getCorrectiveTerms(ord); + return new QueryContext( + ord, + correctiveTerms.lowerInterval(), + correctiveTerms.upperInterval(), + correctiveTerms.additionalCorrection(), + correctiveTerms.quantizedComponentSum() + ); + } + + protected final void checkOrdinal(int ord) { + if (ord < 0 || ord >= maxOrd) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + protected final float scoreFromOrds(QueryContext query, int secondOrd) throws IOException { + int firstOrd = query.ord; + checkOrdinal(firstOrd); + checkOrdinal(secondOrd); + long vectorPitch = getVectorPitch(); + long firstVectorOffset = firstOrd * vectorPitch; + long secondVectorOffset = secondOrd * vectorPitch; + + MemorySegment first = input.segmentSliceOrNull(firstVectorOffset, dims); + MemorySegment second = input.segmentSliceOrNull(secondVectorOffset, dims); + if (first == null || second == null) { + return scoreViaFallback(query, secondOrd, firstVectorOffset, secondVectorOffset); + } + int rawScore = dotProductI7u(first, second, dims); + return applyCorrections(rawScore, secondOrd, query); + } + + protected final float bulkScoreFromOrds(QueryContext query, int[] ordinals, float[] scores, int numNodes) throws IOException { + checkOrdinal(query.ord); + MemorySegment vectors = input.segmentSliceOrNull(0, input.length()); + if (vectors == null) { + float max = Float.NEGATIVE_INFINITY; + for (int i = 0; i < numNodes; i++) { + scores[i] = scoreFromOrds(query, ordinals[i]); + max = Math.max(max, scores[i]); + } + return max; + } + if (SUPPORTS_HEAP_SEGMENTS) { + var ordinalsSeg = MemorySegment.ofArray(ordinals); + var scoresSeg = MemorySegment.ofArray(scores); + computeBulkForQuery(query, vectors, ordinalsSeg, scoresSeg, numNodes); + return applyCorrectionsBulk(scoresSeg, ordinalsSeg, numNodes, query); + } else { + try (Arena arena = Arena.ofConfined()) { + MemorySegment ordinalsSeg = 
arena.allocate((long) numNodes * Integer.BYTES, Integer.BYTES); + MemorySegment scoresSeg = arena.allocate((long) numNodes * Float.BYTES, Float.BYTES); + MemorySegment.copy(ordinals, 0, ordinalsSeg, ValueLayout.JAVA_INT, 0, numNodes); + computeBulkForQuery(query, vectors, ordinalsSeg, scoresSeg, numNodes); + float max = applyCorrectionsBulk(scoresSeg, ordinalsSeg, numNodes, query); + MemorySegment.copy(scoresSeg, ValueLayout.JAVA_FLOAT, 0, scores, 0, numNodes); + return max; + } + } + } + + private void computeBulkForQuery(QueryContext query, MemorySegment vectors, MemorySegment ordinals, MemorySegment scores, int numNodes) + throws IOException { + long firstByteOffset = query.ord * getVectorPitch(); + MemorySegment firstVector = vectors.asSlice(firstByteOffset, getVectorPitch()); + computeBulk(firstVector, vectors, ordinals, scores, numNodes); + } + + private float scoreViaFallback(QueryContext query, int secondOrd, long firstVectorOffset, long secondVectorOffset) throws IOException { + byte[] a = new byte[dims]; + byte[] b = new byte[dims]; + input.readBytes(firstVectorOffset, a, 0, dims); + input.readBytes(secondVectorOffset, b, 0, dims); + // Just fall back to regular dot-product and apply corrections + int raw = VectorUtil.dotProduct(a, b); + return applyCorrections(raw, secondOrd, query); + } + + protected final void computeBulk( + MemorySegment firstVector, + MemorySegment vectors, + MemorySegment ordinals, + MemorySegment scores, + int numNodes + ) throws IOException { + dotProductI7uBulkWithOffsets(vectors, firstVector, dims, (int) getVectorPitch(), ordinals, numNodes, scores); + } + + protected final long getVectorPitch() { + return dims + 3L * Float.BYTES + Integer.BYTES; + } + + @Override + public UpdateableRandomVectorScorer scorer() { + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { + private int ord = -1; + private QueryContext query; + + @Override + public float score(int node) throws IOException { + if (query 
== null) { + throw new IllegalStateException("scoring ordinal is not set"); + } + return scoreFromOrds(query, node); + } + + @Override + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + if (query == null) { + throw new IllegalStateException("scoring ordinal is not set"); + } + return bulkScoreFromOrds(query, nodes, scores, numNodes); + } + + @Override + public void setScoringOrdinal(int node) throws IOException { + checkOrdinal(node); + ord = node; + query = createQueryContext(node); + } + }; + } + + public QuantizedByteVectorValues get() { + return values; + } + + public static final class DotProductSupplier extends Int7uOSQVectorScorerSupplier { + public DotProductSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values) { + super(input, values); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new DotProductSupplier(input.clone(), values.copy()); + } + + @Override + protected float applyCorrections(float rawScore, int ord, QueryContext query) throws IOException { + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float ay = query.lowerInterval; + float ly = (query.upperInterval - ay) * LIMIT_SCALE; + float y1 = query.quantizedComponentSum; + float x1 = correctiveTerms.quantizedComponentSum(); + float score = ax * ay * dims + ay * lx * x1 + ax * ly * y1 + lx * ly * rawScore; + score += query.additionalCorrection + correctiveTerms.additionalCorrection() - values.getCentroidDP(); + return VectorUtil.normalizeToUnitInterval(Math.clamp(score, -1, 1)); + } + + @Override + protected float applyCorrectionsBulk(MemorySegment scoreSeg, MemorySegment ordinalsSeg, int numNodes, QueryContext query) + throws IOException { + float ay = query.lowerInterval; + float ly = (query.upperInterval - ay) * LIMIT_SCALE; + float y1 = query.quantizedComponentSum; + float 
max = Float.NEGATIVE_INFINITY; + for (int i = 0; i < numNodes; i++) { + float raw = scoreSeg.getAtIndex(ValueLayout.JAVA_FLOAT, i); + int ord = ordinalsSeg.getAtIndex(ValueLayout.JAVA_INT, i); + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float x1 = correctiveTerms.quantizedComponentSum(); + float adjustedScore = ax * ay * dims + ay * lx * x1 + ax * ly * y1 + lx * ly * raw; + adjustedScore += query.additionalCorrection + correctiveTerms.additionalCorrection() - values.getCentroidDP(); + float normalized = VectorUtil.normalizeToUnitInterval(Math.clamp(adjustedScore, -1, 1)); + scoreSeg.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalized); + max = Math.max(max, normalized); + } + return max; + } + + } + + public static final class EuclideanSupplier extends Int7uOSQVectorScorerSupplier { + public EuclideanSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values) { + super(input, values); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new EuclideanSupplier(input.clone(), values.copy()); + } + + @Override + protected float applyCorrections(float rawScore, int ord, QueryContext query) throws IOException { + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float ay = query.lowerInterval; + float ly = (query.upperInterval - ay) * LIMIT_SCALE; + float y1 = query.quantizedComponentSum; + float x1 = correctiveTerms.quantizedComponentSum(); + float score = ax * ay * dims + ay * lx * x1 + ax * ly * y1 + lx * ly * rawScore; + score = query.additionalCorrection + correctiveTerms.additionalCorrection() - 2 * score; + return VectorUtil.normalizeDistanceToUnitInterval(Math.max(score, 0f)); + } + + @Override + protected float applyCorrectionsBulk(MemorySegment scoreSeg, MemorySegment 
ordinalsSeg, int numNodes, QueryContext query) + throws IOException { + float ay = query.lowerInterval; + float ly = (query.upperInterval - ay) * LIMIT_SCALE; + float y1 = query.quantizedComponentSum; + float max = Float.NEGATIVE_INFINITY; + for (int i = 0; i < numNodes; i++) { + float raw = scoreSeg.getAtIndex(ValueLayout.JAVA_FLOAT, i); + int ord = ordinalsSeg.getAtIndex(ValueLayout.JAVA_INT, i); + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float x1 = correctiveTerms.quantizedComponentSum(); + float score = ax * ay * dims + ay * lx * x1 + ax * ly * y1 + lx * ly * raw; + score = query.additionalCorrection + correctiveTerms.additionalCorrection() - 2 * score; + float normalized = VectorUtil.normalizeDistanceToUnitInterval(Math.max(score, 0f)); + scoreSeg.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalized); + max = Math.max(max, normalized); + } + return max; + } + } + + public static final class MaxInnerProductSupplier extends Int7uOSQVectorScorerSupplier { + public MaxInnerProductSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values) { + super(input, values); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new MaxInnerProductSupplier(input.clone(), values.copy()); + } + + @Override + protected float applyCorrections(float rawScore, int ord, QueryContext query) throws IOException { + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float ay = query.lowerInterval; + float ly = (query.upperInterval - ay) * LIMIT_SCALE; + float y1 = query.quantizedComponentSum; + float x1 = correctiveTerms.quantizedComponentSum(); + float score = ax * ay * dims + ay * lx * x1 + ax * ly * y1 + lx * ly * rawScore; + score += query.additionalCorrection + 
correctiveTerms.additionalCorrection() - values.getCentroidDP(); + return VectorUtil.scaleMaxInnerProductScore(score); + } + + @Override + protected float applyCorrectionsBulk(MemorySegment scoreSeg, MemorySegment ordinalsSeg, int numNodes, QueryContext query) + throws IOException { + float ay = query.lowerInterval; + float ly = (query.upperInterval - ay) * LIMIT_SCALE; + float y1 = query.quantizedComponentSum; + float max = Float.NEGATIVE_INFINITY; + for (int i = 0; i < numNodes; i++) { + float raw = scoreSeg.getAtIndex(ValueLayout.JAVA_FLOAT, i); + int ord = ordinalsSeg.getAtIndex(ValueLayout.JAVA_INT, i); + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float x1 = correctiveTerms.quantizedComponentSum(); + float score = ax * ay * dims + ay * lx * x1 + ax * ly * y1 + lx * ly * raw; + score += query.additionalCorrection + correctiveTerms.additionalCorrection() - values.getCentroidDP(); + float normalizedScore = VectorUtil.scaleMaxInnerProductScore(score); + scoreSeg.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalizedScore); + max = Math.max(max, normalizedScore); + } + return max; + } + } +} diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/ByteVectorScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/ByteVectorScorer.java index e5d7718327a8e..c81e37b869f27 100644 --- a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/ByteVectorScorer.java +++ b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/ByteVectorScorer.java @@ -113,19 +113,22 @@ public float score(int node) throws IOException { } @Override - public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length()); if (vectorsSeg == null) 
{ - super.bulkScore(nodes, scores, numNodes); + return super.bulkScore(nodes, scores, numNodes); } else { var ordinalsSeg = MemorySegment.ofArray(nodes); var scoresSeg = MemorySegment.ofArray(scores); dotProductI8BulkWithOffsets(vectorsSeg, query, dimensions, vectorByteSize, ordinalsSeg, numNodes, scoresSeg); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { scores[i] = normalize(scores[i]); + max = Math.max(max, scores[i]); } + return max; } } } @@ -147,19 +150,22 @@ public float score(int node) throws IOException { } @Override - public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length()); if (vectorsSeg == null) { - super.bulkScore(nodes, scores, numNodes); + return super.bulkScore(nodes, scores, numNodes); } else { var ordinalsSeg = MemorySegment.ofArray(nodes); var scoresSeg = MemorySegment.ofArray(scores); cosineI8BulkWithOffsets(vectorsSeg, query, dimensions, vectorByteSize, ordinalsSeg, numNodes, scoresSeg); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { scores[i] = normalize(scores[i]); + max = Math.max(max, scores[i]); } + return max; } } } @@ -177,19 +183,22 @@ public float score(int node) throws IOException { } @Override - public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length()); if (vectorsSeg == null) { - super.bulkScore(nodes, scores, numNodes); + return super.bulkScore(nodes, scores, numNodes); } else { var ordinalsSeg = MemorySegment.ofArray(nodes); var scoresSeg = MemorySegment.ofArray(scores); squareDistanceI8BulkWithOffsets(vectorsSeg, query, dimensions, vectorByteSize, ordinalsSeg, numNodes, scoresSeg); + float max = 
Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { scores[i] = VectorUtil.normalizeDistanceToUnitInterval(scores[i]); + max = Math.max(max, scores[i]); } + return max; } } } @@ -207,19 +216,22 @@ public float score(int node) throws IOException { } @Override - public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length()); if (vectorsSeg == null) { - super.bulkScore(nodes, scores, numNodes); + return super.bulkScore(nodes, scores, numNodes); } else { var ordinalsSeg = MemorySegment.ofArray(nodes); var scoresSeg = MemorySegment.ofArray(scores); dotProductI8BulkWithOffsets(vectorsSeg, query, dimensions, vectorByteSize, ordinalsSeg, numNodes, scoresSeg); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { scores[i] = VectorUtil.scaleMaxInnerProductScore(scores[i]); + max = Math.max(max, scores[i]); } + return max; } } } diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/FloatVectorScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/FloatVectorScorer.java index 7dd48c7e06696..f374995128916 100644 --- a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/FloatVectorScorer.java +++ b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/FloatVectorScorer.java @@ -104,19 +104,22 @@ public float score(int node) throws IOException { } @Override - public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length()); if (vectorsSeg == null) { - super.bulkScore(nodes, scores, numNodes); + return super.bulkScore(nodes, scores, numNodes); } else { var ordinalsSeg = MemorySegment.ofArray(nodes); var scoresSeg = 
MemorySegment.ofArray(scores); dotProductF32BulkWithOffsets(vectorsSeg, query, dimensions, vectorByteSize, ordinalsSeg, numNodes, scoresSeg); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { scores[i] = VectorUtil.normalizeToUnitInterval(scores[i]); + max = Math.max(max, scores[i]); } + return max; } } } @@ -134,19 +137,22 @@ public float score(int node) throws IOException { } @Override - public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length()); if (vectorsSeg == null) { - super.bulkScore(nodes, scores, numNodes); + return super.bulkScore(nodes, scores, numNodes); } else { var ordinalsSeg = MemorySegment.ofArray(nodes); var scoresSeg = MemorySegment.ofArray(scores); squareDistanceF32BulkWithOffsets(vectorsSeg, query, dimensions, vectorByteSize, ordinalsSeg, numNodes, scoresSeg); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { scores[i] = VectorUtil.normalizeDistanceToUnitInterval(scores[i]); + max = Math.max(max, scores[i]); } + return max; } } } @@ -164,19 +170,22 @@ public float score(int node) throws IOException { } @Override - public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length()); if (vectorsSeg == null) { - super.bulkScore(nodes, scores, numNodes); + return super.bulkScore(nodes, scores, numNodes); } else { var ordinalsSeg = MemorySegment.ofArray(nodes); var scoresSeg = MemorySegment.ofArray(scores); dotProductF32BulkWithOffsets(vectorsSeg, query, dimensions, vectorByteSize, ordinalsSeg, numNodes, scoresSeg); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { scores[i] = 
VectorUtil.scaleMaxInnerProductScore(scores[i]); + max = Math.max(max, scores[i]); } + return max; } } } diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorer.java index 52cb8f692db1c..baddacd0b8aa7 100644 --- a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorer.java +++ b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorer.java @@ -117,10 +117,10 @@ public float score(int node) throws IOException { } @Override - public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length()); if (vectorsSeg == null) { - super.bulkScore(nodes, scores, numNodes); + return super.bulkScore(nodes, scores, numNodes); } else { var ordinalsSeg = MemorySegment.ofArray(nodes); var scoresSeg = MemorySegment.ofArray(scores); @@ -128,6 +128,7 @@ public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOExcept var vectorPitch = vectorByteSize + Float.BYTES; dotProductI7uBulkWithOffsets(vectorsSeg, query, vectorByteSize, vectorPitch, ordinalsSeg, numNodes, scoresSeg); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { var dotProduct = scores[i]; var secondOrd = nodes[i]; @@ -135,7 +136,9 @@ public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOExcept var nodeCorrection = Float.intBitsToFloat(input.readInt(secondByteOffset + vectorByteSize)); float adjustedDistance = dotProduct * scoreCorrectionConstant + queryCorrection + nodeCorrection; scores[i] = VectorUtil.normalizeToUnitInterval(adjustedDistance); + max = Math.max(max, scores[i]); } + return max; } } } @@ -154,10 +157,10 @@ public float score(int node) throws IOException { } @Override - public void 
bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length()); if (vectorsSeg == null) { - super.bulkScore(nodes, scores, numNodes); + return super.bulkScore(nodes, scores, numNodes); } else { var ordinalsSeg = MemorySegment.ofArray(nodes); var scoresSeg = MemorySegment.ofArray(scores); @@ -165,11 +168,14 @@ public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOExcept var vectorPitch = vectorByteSize + Float.BYTES; squareDistanceI7uBulkWithOffsets(vectorsSeg, query, vectorByteSize, vectorPitch, ordinalsSeg, numNodes, scoresSeg); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { var squareDistance = scores[i]; float adjustedDistance = squareDistance * scoreCorrectionConstant; scores[i] = VectorUtil.normalizeDistanceToUnitInterval(adjustedDistance); + max = Math.max(max, scores[i]); } + return max; } } } @@ -191,10 +197,10 @@ public float score(int node) throws IOException { } @Override - public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length()); if (vectorsSeg == null) { - super.bulkScore(nodes, scores, numNodes); + return super.bulkScore(nodes, scores, numNodes); } else { var ordinalsSeg = MemorySegment.ofArray(nodes); var scoresSeg = MemorySegment.ofArray(scores); @@ -202,6 +208,7 @@ public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOExcept var vectorPitch = vectorByteSize + Float.BYTES; dotProductI7uBulkWithOffsets(vectorsSeg, query, vectorByteSize, vectorPitch, ordinalsSeg, numNodes, scoresSeg); + float max = Float.NEGATIVE_INFINITY; for (int i = 0; i < numNodes; ++i) { var dotProduct = scores[i]; var secondOrd = nodes[i]; @@ -209,7 +216,9 
@@ public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOExcept var nodeCorrection = Float.intBitsToFloat(input.readInt(secondByteOffset + vectorByteSize)); float adjustedDistance = dotProduct * scoreCorrectionConstant + queryCorrection + nodeCorrection; scores[i] = VectorUtil.scaleMaxInnerProductScore(adjustedDistance); + max = Math.max(max, scores[i]); } + return max; } } } diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorer.java new file mode 100644 index 0000000000000..835fd623ec235 --- /dev/null +++ b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorer.java @@ -0,0 +1,334 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.simdvec.internal; + +import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.elasticsearch.simdvec.MemorySegmentAccessInputAccess; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.util.Optional; + +import static org.elasticsearch.simdvec.internal.Similarities.dotProductI7u; +import static org.elasticsearch.simdvec.internal.Similarities.dotProductI7uBulkWithOffsets; + +/** + * JDK-22+ implementation for Int7 OSQ query-time scorers. + */ +public abstract sealed class Int7uOSQVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer permits + Int7uOSQVectorScorer.DotProductScorer, Int7uOSQVectorScorer.EuclideanScorer, Int7uOSQVectorScorer.MaxInnerProductScorer { + + private static final float LIMIT_SCALE = 1f / ((1 << 7) - 1); + + public static Optional create( + VectorSimilarityFunction sim, + QuantizedByteVectorValues values, + byte[] quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + if (quantizedQuery.length != values.getVectorByteLength()) { + throw new IllegalArgumentException( + "quantized query length " + quantizedQuery.length + " differs from vector byte length " + values.getVectorByteLength() + ); + } + + var input = values.getSlice(); + if (input == null) { + return Optional.empty(); + } + input = FilterIndexInput.unwrapOnlyTest(input); + input = MemorySegmentAccessInputAccess.unwrap(input); + if ((input instanceof MemorySegmentAccessInput) == false) { + return Optional.empty(); + } + MemorySegmentAccessInput msInput = (MemorySegmentAccessInput) input; + 
checkInvariants(values.size(), values.getVectorByteLength(), input); + + return switch (sim) { + case COSINE, DOT_PRODUCT -> Optional.of( + new DotProductScorer( + msInput, + values, + quantizedQuery, + lowerInterval, + upperInterval, + additionalCorrection, + quantizedComponentSum + ) + ); + case EUCLIDEAN -> Optional.of( + new EuclideanScorer( + msInput, + values, + quantizedQuery, + lowerInterval, + upperInterval, + additionalCorrection, + quantizedComponentSum + ) + ); + case MAXIMUM_INNER_PRODUCT -> Optional.of( + new MaxInnerProductScorer( + msInput, + values, + quantizedQuery, + lowerInterval, + upperInterval, + additionalCorrection, + quantizedComponentSum + ) + ); + }; + } + + final QuantizedByteVectorValues values; + final MemorySegmentAccessInput input; + final int vectorByteSize; + final MemorySegment query; + final float lowerInterval; + final float upperInterval; + final float additionalCorrection; + final int quantizedComponentSum; + byte[] scratch; + + Int7uOSQVectorScorer( + MemorySegmentAccessInput input, + QuantizedByteVectorValues values, + byte[] quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + super(values); + this.values = values; + this.input = input; + this.vectorByteSize = values.getVectorByteLength(); + this.query = MemorySegment.ofArray(quantizedQuery); + this.lowerInterval = lowerInterval; + this.upperInterval = upperInterval; + this.additionalCorrection = additionalCorrection; + this.quantizedComponentSum = quantizedComponentSum; + } + + abstract float applyCorrections(float rawScore, int ord) throws IOException; + + abstract float applyCorrectionsBulk(float[] scores, int[] ords, int numNodes) throws IOException; + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + int dotProduct = dotProductI7u(query, getSegment(node), vectorByteSize); + return applyCorrections(dotProduct, node); + } + + @Override + public float 
bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length()); + if (vectorsSeg == null) { + return super.bulkScore(nodes, scores, numNodes); + } else { + var ordinalsSeg = MemorySegment.ofArray(nodes); + var scoresSeg = MemorySegment.ofArray(scores); + + var vectorPitch = vectorByteSize + 3 * Float.BYTES + Integer.BYTES; + dotProductI7uBulkWithOffsets(vectorsSeg, query, vectorByteSize, vectorPitch, ordinalsSeg, numNodes, scoresSeg); + return applyCorrectionsBulk(scores, nodes, numNodes); + } + } + + final MemorySegment getSegment(int ord) throws IOException { + checkOrdinal(ord); + long byteOffset = (long) ord * (vectorByteSize + 3 * Float.BYTES + Integer.BYTES); + MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); + if (seg == null) { + if (scratch == null) { + scratch = new byte[vectorByteSize]; + } + input.readBytes(byteOffset, scratch, 0, vectorByteSize); + seg = MemorySegment.ofArray(scratch); + } + return seg; + } + + final void checkOrdinal(int ord) { + if (ord < 0 || ord >= maxOrd()) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + public static final class DotProductScorer extends Int7uOSQVectorScorer { + public DotProductScorer( + MemorySegmentAccessInput input, + QuantizedByteVectorValues values, + byte[] quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + super(input, values, quantizedQuery, lowerInterval, upperInterval, additionalCorrection, quantizedComponentSum); + } + + @Override + float applyCorrections(float rawScore, int ord) throws IOException { + var correctiveTerms = values.getCorrectiveTerms(ord); + float x1 = correctiveTerms.quantizedComponentSum(); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float ay = lowerInterval; + float ly = (upperInterval - ay) * 
LIMIT_SCALE; + float y1 = quantizedComponentSum; + float score = ax * ay * values.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * rawScore; + score += additionalCorrection + correctiveTerms.additionalCorrection() - values.getCentroidDP(); + score = Math.clamp(score, -1, 1); + return VectorUtil.normalizeToUnitInterval(score); + } + + @Override + float applyCorrectionsBulk(float[] scores, int[] ords, int numNodes) throws IOException { + float ay = lowerInterval; + float ly = (upperInterval - ay) * LIMIT_SCALE; + float y1 = quantizedComponentSum; + float maxScore = Float.NEGATIVE_INFINITY; + for (int i = 0; i < numNodes; i++) { + int ord = ords[i]; + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float x1 = correctiveTerms.quantizedComponentSum(); + float score = ax * ay * values.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * scores[i]; + score += additionalCorrection + correctiveTerms.additionalCorrection() - values.getCentroidDP(); + score = Math.clamp(score, -1, 1); + scores[i] = VectorUtil.normalizeToUnitInterval(score); + if (scores[i] > maxScore) { + maxScore = scores[i]; + } + } + return maxScore; + } + } + + public static final class EuclideanScorer extends Int7uOSQVectorScorer { + public EuclideanScorer( + MemorySegmentAccessInput input, + QuantizedByteVectorValues values, + byte[] quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + super(input, values, quantizedQuery, lowerInterval, upperInterval, additionalCorrection, quantizedComponentSum); + } + + @Override + float applyCorrections(float rawScore, int ord) throws IOException { + var correctiveTerms = values.getCorrectiveTerms(ord); + float x1 = correctiveTerms.quantizedComponentSum(); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + 
float ay = lowerInterval; + float ly = (upperInterval - ay) * LIMIT_SCALE; + float y1 = quantizedComponentSum; + float score = ax * ay * values.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * rawScore; + score = additionalCorrection + correctiveTerms.additionalCorrection() - 2 * score; + return VectorUtil.normalizeDistanceToUnitInterval(Math.max(score, 0f)); + } + + @Override + float applyCorrectionsBulk(float[] scores, int[] ords, int numNodes) throws IOException { + float ay = lowerInterval; + float ly = (upperInterval - ay) * LIMIT_SCALE; + float y1 = quantizedComponentSum; + float maxScore = Float.NEGATIVE_INFINITY; + for (int i = 0; i < numNodes; i++) { + int ord = ords[i]; + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float x1 = correctiveTerms.quantizedComponentSum(); + float score = ax * ay * values.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * scores[i]; + score = additionalCorrection + correctiveTerms.additionalCorrection() - 2 * score; + scores[i] = VectorUtil.normalizeDistanceToUnitInterval(Math.max(score, 0f)); + if (scores[i] > maxScore) { + maxScore = scores[i]; + } + } + return maxScore; + } + } + + public static final class MaxInnerProductScorer extends Int7uOSQVectorScorer { + public MaxInnerProductScorer( + MemorySegmentAccessInput input, + QuantizedByteVectorValues values, + byte[] quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + super(input, values, quantizedQuery, lowerInterval, upperInterval, additionalCorrection, quantizedComponentSum); + } + + @Override + float applyCorrections(float rawScore, int ord) throws IOException { + var correctiveTerms = values.getCorrectiveTerms(ord); + float x1 = correctiveTerms.quantizedComponentSum(); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) 
* LIMIT_SCALE; + float ay = lowerInterval; + float ly = (upperInterval - ay) * LIMIT_SCALE; + float y1 = quantizedComponentSum; + float score = ax * ay * values.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * rawScore; + score += additionalCorrection + correctiveTerms.additionalCorrection() - values.getCentroidDP(); + return VectorUtil.scaleMaxInnerProductScore(score); + } + + @Override + float applyCorrectionsBulk(float[] scores, int[] ords, int numNodes) throws IOException { + float ay = lowerInterval; + float ly = (upperInterval - ay) * LIMIT_SCALE; + float y1 = quantizedComponentSum; + float maxScore = Float.NEGATIVE_INFINITY; + for (int i = 0; i < numNodes; i++) { + int ord = ords[i]; + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float x1 = correctiveTerms.quantizedComponentSum(); + float score = ax * ay * values.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * scores[i]; + score += additionalCorrection + correctiveTerms.additionalCorrection() - values.getCentroidDP(); + scores[i] = VectorUtil.scaleMaxInnerProductScore(score); + if (scores[i] > maxScore) { + maxScore = scores[i]; + } + } + return maxScore; + } + } + + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { + long vectorPitch = vectorByteLength + 3L * Float.BYTES + Integer.BYTES; + if (input.length() < vectorPitch * maxOrd) { + throw new IllegalArgumentException("input length is less than expected vector data"); + } + } +} diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/AbstractVectorTestCase.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/AbstractVectorTestCase.java index 3905c462d29ca..259e5bafb765e 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/AbstractVectorTestCase.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/AbstractVectorTestCase.java @@ -23,11 +23,11 @@ public 
abstract class AbstractVectorTestCase extends ESTestCase { - static Optional factory; + static Optional factory; @BeforeClass public static void getVectorScorerFactory() { - factory = VectorScorerFactory.instance(); + factory = org.elasticsearch.simdvec.VectorScorerFactory.instance(); } protected AbstractVectorTestCase() { diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7SQVectorScorerFactoryTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7SQVectorScorerFactoryTests.java index a486abb3a68c8..790657e7740c8 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7SQVectorScorerFactoryTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7SQVectorScorerFactoryTests.java @@ -12,7 +12,6 @@ import com.carrotsearch.randomizedtesting.generators.RandomNumbers; import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer; -import org.apache.lucene.codecs.lucene99.OffHeapQuantizedByteVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; @@ -24,6 +23,7 @@ import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity; import org.apache.lucene.util.quantization.ScalarQuantizer; +import org.elasticsearch.index.codec.vectors.OffHeapQuantizedByteVectorValues; import java.io.IOException; import java.util.Arrays; diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7uOSQVectorScorerFactoryTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7uOSQVectorScorerFactoryTests.java new file mode 100644 index 0000000000000..04fb45f1f22dd --- /dev/null +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7uOSQVectorScorerFactoryTests.java @@ -0,0 +1,893 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec; + +import com.carrotsearch.randomizedtesting.generators.RandomNumbers; + +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorScorer; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat; +import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.elasticsearch.core.SuppressForbidden; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.function.IntFunction; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.elasticsearch.simdvec.VectorSimilarityType.DOT_PRODUCT; +import static org.elasticsearch.simdvec.VectorSimilarityType.EUCLIDEAN; +import static 
org.elasticsearch.simdvec.VectorSimilarityType.MAXIMUM_INNER_PRODUCT; +import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; +import static org.hamcrest.Matchers.equalTo; + +public class Int7uOSQVectorScorerFactoryTests extends org.elasticsearch.simdvec.AbstractVectorTestCase { + private static final float LIMIT_SCALE = 1f / ((1 << 7) - 1); + + @SuppressForbidden(reason = "require usage of OptimizedScalarQuantizer") + private static OptimizedScalarQuantizer scalarQuantizer(VectorSimilarityFunction sim) { + return new OptimizedScalarQuantizer(sim); + } + + // bounds of the range of values that can be seen by int7 scalar quantized vectors + static final byte MIN_INT7_VALUE = 0; + static final byte MAX_INT7_VALUE = 127; + + // Tests that the provider instance is present or not on expected platforms/architectures + public void testSupport() { + supported(); + } + + public void testSimple() throws IOException { + testSimpleImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE); + } + + public void testSimpleMaxChunkSizeSmall() throws IOException { + long maxChunkSize = randomLongBetween(4, 16); + logger.info("maxChunkSize=" + maxChunkSize); + testSimpleImpl(maxChunkSize); + } + + void testSimpleImpl(long maxChunkSize) throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testSimpleImpl"), maxChunkSize)) { + for (var sim : List.of(DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var scalarQuantizer = scalarQuantizer(sim.function()); + for (int dims : List.of(31, 32, 33)) { + // dimensions that cross the scalar / native boundary (stride) + byte[] vec1 = new byte[dims]; + byte[] vec2 = new byte[dims]; + float[] query1 = new float[dims]; + float[] query2 = new float[dims]; + float[] centroid = new float[dims]; + float centroidDP = 0f; + OptimizedScalarQuantizer.QuantizationResult vec1Correction, vec2Correction; + 
String fileName = "testSimpleImpl-" + sim + "-" + dims + ".vex"; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < dims; i++) { + query1[i] = (float) i; + query2[i] = (float) (dims - i); + centroid[i] = (query1[i] + query2[i]) / 2f; + centroidDP += centroid[i] * centroid[i]; + } + vec1Correction = scalarQuantizer.scalarQuantize(query1, vec1, (byte) 7, centroid); + vec2Correction = scalarQuantizer.scalarQuantize(query2, vec2, (byte) 7, centroid); + out.writeBytes(vec1, 0, vec1.length); + out.writeInt(Float.floatToIntBits(vec1Correction.lowerInterval())); + out.writeInt(Float.floatToIntBits(vec1Correction.upperInterval())); + out.writeInt(Float.floatToIntBits(vec1Correction.additionalCorrection())); + out.writeInt(vec1Correction.quantizedComponentSum()); + out.writeBytes(vec2, 0, vec2.length); + out.writeInt(Float.floatToIntBits(vec2Correction.lowerInterval())); + out.writeInt(Float.floatToIntBits(vec2Correction.upperInterval())); + out.writeInt(Float.floatToIntBits(vec2Correction.additionalCorrection())); + out.writeInt(vec2Correction.quantizedComponentSum()); + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + var values = vectorValues(dims, 2, centroid, centroidDP, in, sim.function()); + float expected = luceneScore(sim, vec1, vec2, centroidDP, vec1Correction, vec2Correction); + + var luceneSupplier = luceneScoreSupplier(values, sim.function()).scorer(); + luceneSupplier.setScoringOrdinal(1); + assertFloatEquals(expected, luceneSupplier.score(0), 1e-6f); + var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(1); + assertFloatEquals(expected, scorer.score(0), 1e-6f); + + if (supportsHeapSegments()) { + var qScorer = factory.getInt7uOSQVectorScorer( + sim.function(), + values, + vec2, + vec2Correction.lowerInterval(), + vec2Correction.upperInterval(), + vec2Correction.additionalCorrection(), + 
vec2Correction.quantizedComponentSum() + ).get(); + assertFloatEquals(expected, qScorer.score(0), 1e-6f); + } + } + } + } + } + } + + public void testRandom() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + testRandomSupplier(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_RANDOM_INT7_FUNC); + } + + public void testRandomMaxChunkSizeSmall() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + long maxChunkSize = randomLongBetween(32, 128); + logger.info("maxChunkSize=" + maxChunkSize); + testRandomSupplier(maxChunkSize, BYTE_ARRAY_RANDOM_INT7_FUNC); + } + + public void testRandomMax() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + testRandomSupplier(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MAX_INT7_FUNC); + } + + public void testRandomMin() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + testRandomSupplier(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MIN_INT7_FUNC); + } + + void testRandomSupplier(long maxChunkSize, IntFunction byteArraySupplier) throws IOException { + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testRandom"), maxChunkSize)) { + final int dims = randomIntBetween(1, 4096); + final int size = randomIntBetween(2, 100); + final byte[][] vectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] quantizationResults = new OptimizedScalarQuantizer.QuantizationResult[size]; + final float[] centroid = new float[dims]; + + String fileName = "testRandom-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var vec = byteArraySupplier.apply(dims); + int componentSum = 0; + for (int d = 0; d < dims; d++) { + componentSum += Byte.toUnsignedInt(vec[d]); + } + float lowerInterval = randomFloat(); + float upperInterval = randomFloat() + lowerInterval; + 
quantizationResults[i] = new OptimizedScalarQuantizer.QuantizationResult( + lowerInterval, + upperInterval, + randomFloat(), + componentSum + ); + out.writeBytes(vec, 0, vec.length); + out.writeInt(Float.floatToIntBits(lowerInterval)); + out.writeInt(Float.floatToIntBits(upperInterval)); + out.writeInt(Float.floatToIntBits(quantizationResults[i].additionalCorrection())); + out.writeInt(componentSum); + vectors[i] = vec; + } + } + for (int i = 0; i < dims; i++) { + centroid[i] = randomFloat(); + } + float centroidDP = VectorUtil.dotProduct(centroid, centroid); + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok. + for (var sim : List.of(DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var values = vectorValues(dims, size, centroid, centroidDP, in, sim.function()); + float expected = luceneScore( + sim, + vectors[idx0], + vectors[idx1], + centroidDP, + quantizationResults[idx0], + quantizationResults[idx1] + ); + var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx1); + assertFloatEquals(expected, scorer.score(idx0), 1e-6f); + } + } + } + } + } + + public void testRandomScorer() throws IOException { + testRandomScorerImpl( + MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, + org.elasticsearch.simdvec.Int7SQVectorScorerFactoryTests.FLOAT_ARRAY_RANDOM_FUNC + ); + } + + public void testRandomScorerMax() throws IOException { + testRandomScorerImpl( + MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, + org.elasticsearch.simdvec.Int7SQVectorScorerFactoryTests.FLOAT_ARRAY_MAX_FUNC + ); + } + + public void testRandomScorerChunkSizeSmall() throws IOException { + long maxChunkSize = randomLongBetween(32, 128); + logger.info("maxChunkSize=" + maxChunkSize); + testRandomScorerImpl(maxChunkSize, 
FLOAT_ARRAY_RANDOM_FUNC); + } + + void testRandomScorerImpl(long maxChunkSize, IntFunction floatArraySupplier) throws IOException { + assumeTrue("scorer only supported on JDK 22+", Runtime.version().feature() >= 22); + assumeTrue(notSupportedMsg(), supported()); + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testRandom"), maxChunkSize)) { + for (var sim : List.of(DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var scalarQuantizer = new OptimizedScalarQuantizer(sim.function()); + final int dims = randomIntBetween(1, 4096); + final int size = randomIntBetween(2, 100); + final float[] centroid = new float[dims]; + for (int i = 0; i < dims; i++) { + centroid[i] = randomFloat(); + } + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final float[][] vectors = new float[size][]; + final byte[][] qVectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + + String fileName = "testRandom-" + sim + "-" + dims + ".vex"; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + vectors[i] = floatArraySupplier.apply(dims); + qVectors[i] = new byte[dims]; + corrections[i] = scalarQuantizer.scalarQuantize(vectors[i], qVectors[i], (byte) 7, centroid); + out.writeBytes(qVectors[i], 0, qVectors[i].length); + out.writeInt(Float.floatToIntBits(corrections[i].lowerInterval())); + out.writeInt(Float.floatToIntBits(corrections[i].upperInterval())); + out.writeInt(Float.floatToIntBits(corrections[i].additionalCorrection())); + out.writeInt(corrections[i].quantizedComponentSum()); + } + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = randomIntBetween(0, size - 1); + 
var values = vectorValues(dims, size, centroid, centroidDP, in, sim.function()); + + var expected = luceneScore(sim, qVectors[idx0], qVectors[idx1], centroidDP, corrections[idx0], corrections[idx1]); + var scorer = factory.getInt7uOSQVectorScorer( + sim.function(), + values, + qVectors[idx0], + corrections[idx0].lowerInterval(), + corrections[idx0].upperInterval(), + corrections[idx0].additionalCorrection(), + corrections[idx0].quantizedComponentSum() + ).get(); + assertFloatEquals(expected, scorer.score(idx1), 1e-6f); + } + } + } + } + } + + public void testRandomSlice() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + testRandomSliceImpl(30, 64, 1, BYTE_ARRAY_RANDOM_INT7_FUNC); + } + + void testRandomSliceImpl(int dims, long maxChunkSize, int initialPadding, IntFunction byteArraySupplier) throws IOException { + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testRandomSliceImpl"), maxChunkSize)) { + for (int times = 0; times < TIMES; times++) { + final int size = randomIntBetween(2, 100); + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final byte[][] vectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + + String fileName = "testRandomSliceImpl-" + times + "-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + byte[] ba = new byte[initialPadding]; + out.writeBytes(ba, 0, ba.length); + for (int i = 0; i < size; i++) { + var vec = byteArraySupplier.apply(dims); + var correction = randomCorrection(vec); + writeVectorWithCorrection(out, vec, correction); + vectors[i] = vec; + corrections[i] = correction; + } + } + try ( + var outter = dir.openInput(fileName, IOContext.DEFAULT); + var in = 
outter.slice("slice", initialPadding, outter.length() - initialPadding) + ) { + for (int itrs = 0; itrs < TIMES / 10; itrs++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok. + for (var sim : List.of(DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var values = vectorValues(dims, size, centroid, centroidDP, in, sim.function()); + float expected = luceneScore( + sim, + vectors[idx0], + vectors[idx1], + centroidDP, + corrections[idx0], + corrections[idx1] + ); + var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx1); + assertFloatEquals(expected, scorer.score(idx0), 1e-6f); + } + } + } + } + } + } + + // Tests with a large amount of data (> 2GB), which ensures that data offsets do not overflow + @Nightly + public void testLarge() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testLarge"))) { + final int dims = 8192; + final int size = 262144; + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + + String fileName = "testLarge-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var vec = vector(i, dims); + var correction = randomCorrection(vec); + writeVectorWithCorrection(out, vec, correction); + corrections[i] = correction; + } + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = size - 1; + for (var sim : 
List.of(DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var values = vectorValues(dims, size, centroid, centroidDP, in, sim.function()); + float expected = luceneScore( + sim, + vector(idx0, dims), + vector(idx1, dims), + centroidDP, + corrections[idx0], + corrections[idx1] + ); + var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx1); + assertFloatEquals(expected, scorer.score(idx0), 1e-6f); + } + } + } + } + } + + // Test that the scorer works well when the IndexInput is greater than the directory segment chunk size + public void testDatasetGreaterThanChunkSize() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testDatasetGreaterThanChunkSize"), 8192)) { + final int dims = 1024; + final int size = 128; + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final byte[][] vectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + + String fileName = "testDatasetGreaterThanChunkSize-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var vec = vector(i, dims); + var correction = randomCorrection(vec); + writeVectorWithCorrection(out, vec, correction); + vectors[i] = vec; + corrections[i] = correction; + } + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = size - 1; + for (var sim : List.of(DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var values = vectorValues(dims, size, centroid, centroidDP, in, 
sim.function()); + float expected = luceneScore(sim, vectors[idx0], vectors[idx1], centroidDP, corrections[idx0], corrections[idx1]); + var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx1); + assertFloatEquals(expected, scorer.score(idx0), 1e-6f); + } + } + } + } + } + + public void testBulk() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + final int dims = 1024; + final int size = randomIntBetween(1, 102); + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final byte[][] vectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + // Set maxChunkSize to be less than dims * size + try (Directory dir = new MMapDirectory(createTempDir("testBulk"))) { + String fileName = "testBulk-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var vec = vector(i, dims); + var correction = randomCorrection(vec); + writeVectorWithCorrection(out, vec, correction); + vectors[i] = vec; + corrections[i] = correction; + } + } + + List ids = IntStream.range(0, size).boxed().collect(Collectors.toList()); + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int[] nodes = shuffledList(ids).stream().mapToInt(i -> i).toArray(); + for (var sim : List.of(EUCLIDEAN)) { + QuantizedByteVectorValues values = vectorValues(dims, size, centroid, centroidDP, in, sim.function()); + float[] expected = new float[nodes.length]; + float[] scores = new float[nodes.length]; + var referenceScorer = luceneScoreSupplier(values, 
sim.function()).scorer(); + referenceScorer.setScoringOrdinal(idx0); + referenceScorer.bulkScore(nodes, expected, nodes.length); + var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).orElseThrow(); + var testScorer = supplier.scorer(); + testScorer.setScoringOrdinal(idx0); + testScorer.bulkScore(nodes, scores, nodes.length); + // applying the corrections in even a slightly different order can impact the score + // account for this during bulk scoring + assertFloatArrayEquals(expected, scores, 2e-5f); + } + } + } + } + } + + public void testBulkWithDatasetGreaterThanChunkSize() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + final int dims = 1024; + final int size = 128; + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final byte[][] vectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + // Set maxChunkSize to be less than dims * size + try (Directory dir = new MMapDirectory(createTempDir("testBulkWithDatasetGreaterThanChunkSize"), 8192)) { + String fileName = "testBulkWithDatasetGreaterThanChunkSize-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var vec = vector(i, dims); + var correction = randomCorrection(vec); + writeVectorWithCorrection(out, vec, correction); + vectors[i] = vec; + corrections[i] = correction; + } + } + + List ids = IntStream.range(0, size).boxed().collect(Collectors.toList()); + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int[] nodes = shuffledList(ids).stream().mapToInt(i -> i).toArray(); + for (var sim : 
List.of(DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + QuantizedByteVectorValues values = vectorValues(dims, size, centroid, centroidDP, in, sim.function()); + float[] expected = new float[nodes.length]; + float[] scores = new float[nodes.length]; + var referenceScorer = luceneScoreSupplier(values, sim.function()).scorer(); + referenceScorer.setScoringOrdinal(idx0); + referenceScorer.bulkScore(nodes, expected, nodes.length); + var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).orElseThrow(); + var testScorer = supplier.scorer(); + testScorer.setScoringOrdinal(idx0); + testScorer.bulkScore(nodes, scores, nodes.length); + assertFloatArrayEquals(expected, scores, 1e-6f); + } + } + } + } + } + + public void testRace() throws Exception { + testRaceImpl(DOT_PRODUCT); + testRaceImpl(EUCLIDEAN); + testRaceImpl(MAXIMUM_INNER_PRODUCT); + } + + // Tests that copies in threads do not interfere with each other + void testRaceImpl(org.elasticsearch.simdvec.VectorSimilarityType sim) throws Exception { + assumeTrue(notSupportedMsg(), supported()); + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + final long maxChunkSize = 32; + final int dims = 34; // dimensions that are larger than the chunk size, to force fallback + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + byte[] vec1 = new byte[dims]; + byte[] vec2 = new byte[dims]; + IntStream.range(0, dims).forEach(i -> vec1[i] = 1); + IntStream.range(0, dims).forEach(i -> vec2[i] = 2); + var correction1 = randomCorrection(vec1); + var correction2 = randomCorrection(vec2); + try (Directory dir = new MMapDirectory(createTempDir("testRace"), maxChunkSize)) { + String fileName = "testRace-" + dims; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + writeVectorWithCorrection(out, vec1, correction1); + writeVectorWithCorrection(out, vec1, correction1); + 
writeVectorWithCorrection(out, vec2, correction2); + writeVectorWithCorrection(out, vec2, correction2); + } + var expectedScore1 = luceneScore(sim, vec1, vec1, centroidDP, correction1, correction1); + var expectedScore2 = luceneScore(sim, vec2, vec2, centroidDP, correction2, correction2); + + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + var values = vectorValues(dims, 4, centroid, centroidDP, in, sim.function()); + var scoreSupplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).get(); + var tasks = List.>>of( + new ScoreCallable(scoreSupplier.copy().scorer(), 0, 1, expectedScore1), + new ScoreCallable(scoreSupplier.copy().scorer(), 2, 3, expectedScore2) + ); + var executor = Executors.newFixedThreadPool(2); + var results = executor.invokeAll(tasks); + executor.shutdown(); + assertTrue(executor.awaitTermination(60, TimeUnit.SECONDS)); + assertThat(results.stream().filter(Predicate.not(Future::isDone)).count(), equalTo(0L)); + for (var res : results) { + assertThat("Unexpected exception" + res.get(), res.get(), isEmpty()); + } + } + } + } + + static class ScoreCallable implements Callable> { + + final UpdateableRandomVectorScorer scorer; + final int ord; + final float expectedScore; + + ScoreCallable(UpdateableRandomVectorScorer scorer, int queryOrd, int ord, float expectedScore) { + try { + this.scorer = scorer; + this.scorer.setScoringOrdinal(queryOrd); + this.ord = ord; + this.expectedScore = expectedScore; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Optional call() { + try { + for (int i = 0; i < 100; i++) { + assertFloatEquals(expectedScore, scorer.score(ord), 1e-6f); + } + } catch (Throwable t) { + return Optional.of(t); + } + return Optional.empty(); + } + } + + private static OptimizedScalarQuantizer.QuantizationResult randomCorrection(byte[] vec) { + int componentSum = 0; + for (byte value : vec) { + componentSum += Byte.toUnsignedInt(value); + } + float lowerInterval = 
randomFloat(); + float upperInterval = lowerInterval + randomFloat(); + return new OptimizedScalarQuantizer.QuantizationResult(lowerInterval, upperInterval, randomFloat(), componentSum); + } + + private static void writeVectorWithCorrection(IndexOutput out, byte[] vec, OptimizedScalarQuantizer.QuantizationResult correction) + throws IOException { + out.writeBytes(vec, 0, vec.length); + out.writeInt(Float.floatToIntBits(correction.lowerInterval())); + out.writeInt(Float.floatToIntBits(correction.upperInterval())); + out.writeInt(Float.floatToIntBits(correction.additionalCorrection())); + out.writeInt(correction.quantizedComponentSum()); + } + + QuantizedByteVectorValues vectorValues( + int dims, + int size, + float[] centroid, + float centroidDP, + IndexInput in, + VectorSimilarityFunction sim + ) throws IOException { + var slice = in.slice("values", 0, in.length()); + return new DenseOffHeapScalarQuantizedVectorValues(dims, size, sim, slice, centroid, centroidDP); + } + + /** Computes the score using the Lucene implementation. 
*/ + public float luceneScore( + org.elasticsearch.simdvec.VectorSimilarityType similarityFunc, + byte[] a, + byte[] b, + float centroidDP, + OptimizedScalarQuantizer.QuantizationResult aCorrection, + OptimizedScalarQuantizer.QuantizationResult bCorrection + ) { + OSQScorer scorer = OSQScorer.fromSimilarity(similarityFunc); + return scorer.score(a, b, centroidDP, aCorrection, bCorrection); + } + + private abstract static class OSQScorer { + static OSQScorer fromSimilarity(org.elasticsearch.simdvec.VectorSimilarityType sim) { + return switch (sim) { + case DOT_PRODUCT -> new DotProductOSQScorer(); + case MAXIMUM_INNER_PRODUCT -> new MaxInnerProductOSQScorer(); + case EUCLIDEAN -> new EuclideanOSQScorer(); + default -> throw new IllegalArgumentException("Unsupported similarity: " + sim); + }; + } + + final float score( + byte[] a, + byte[] b, + float centroidDP, + OptimizedScalarQuantizer.QuantizationResult aCorrection, + OptimizedScalarQuantizer.QuantizationResult bCorrection + ) { + float ax = aCorrection.lowerInterval(); + float lx = (aCorrection.upperInterval() - ax) * LIMIT_SCALE; + float ay = bCorrection.lowerInterval(); + float ly = (bCorrection.upperInterval() - ay) * LIMIT_SCALE; + float y1 = bCorrection.quantizedComponentSum(); + float x1 = aCorrection.quantizedComponentSum(); + float score = ax * ay * a.length + ay * lx * x1 + ax * ly * y1 + lx * ly * VectorUtil.dotProduct(a, b); + return scaleScore(score, aCorrection.additionalCorrection(), bCorrection.additionalCorrection(), centroidDP); + } + + abstract float scaleScore(float score, float aCorrection, float bCorrection, float centroidDP); + + private static class DotProductOSQScorer extends OSQScorer { + @Override + float scaleScore(float score, float aCorrection, float bCorrection, float centroidDP) { + score += aCorrection + bCorrection - centroidDP; + score = Math.clamp(score, -1, 1); + return VectorUtil.normalizeToUnitInterval(score); + } + } + + private static class MaxInnerProductOSQScorer extends 
OSQScorer { + @Override + float scaleScore(float score, float aCorrection, float bCorrection, float centroidDP) { + score += aCorrection + bCorrection - centroidDP; + return VectorUtil.scaleMaxInnerProductScore(score); + } + } + + private static class EuclideanOSQScorer extends OSQScorer { + @Override + float scaleScore(float score, float aCorrection, float bCorrection, float centroidDP) { + score = aCorrection + bCorrection - 2 * score; + return VectorUtil.normalizeDistanceToUnitInterval(Math.max(score, 0f)); + } + } + } + + static void assertFloatArrayEquals(float[] expected, float[] actual, float delta) { + assertThat(actual.length, equalTo(expected.length)); + for (int i = 0; i < expected.length; i++) { + assertEquals("differed at element [" + i + "]", expected[i], actual[i], Math.abs(expected[i]) * delta + delta); + } + } + + static void assertFloatEquals(float expected, float actual, float delta) { + assertEquals(expected, actual, Math.abs(expected) * delta + delta); + } + + static RandomVectorScorerSupplier luceneScoreSupplier(QuantizedByteVectorValues values, VectorSimilarityFunction sim) + throws IOException { + return new Lucene104ScalarQuantizedVectorScorer(null).getRandomVectorScorerSupplier(sim, values); + } + + // creates the vector based on the given ordinal, which is reproducible given the ord and dims + static byte[] vector(int ord, int dims) { + var random = new Random(Objects.hash(ord, dims)); + byte[] ba = new byte[dims]; + for (int i = 0; i < dims; i++) { + ba[i] = (byte) RandomNumbers.randomIntBetween(random, MIN_INT7_VALUE, MAX_INT7_VALUE); + } + return ba; + } + + static IntFunction FLOAT_ARRAY_RANDOM_FUNC = size -> { + float[] fa = new float[size]; + for (int i = 0; i < size; i++) { + fa[i] = randomFloat(); + } + return fa; + }; + + static IntFunction BYTE_ARRAY_RANDOM_INT7_FUNC = size -> { + byte[] ba = new byte[size]; + randomBytesBetween(ba, MIN_INT7_VALUE, MAX_INT7_VALUE); + return ba; + }; + + static IntFunction 
BYTE_ARRAY_MAX_INT7_FUNC = size -> { + byte[] ba = new byte[size]; + Arrays.fill(ba, MAX_INT7_VALUE); + return ba; + }; + + static IntFunction BYTE_ARRAY_MIN_INT7_FUNC = size -> { + byte[] ba = new byte[size]; + Arrays.fill(ba, MIN_INT7_VALUE); + return ba; + }; + + static final int TIMES = 100; // a loop iteration times + + static class DenseOffHeapScalarQuantizedVectorValues extends QuantizedByteVectorValues { + final int dimension; + final int size; + final VectorSimilarityFunction similarityFunction; + + final IndexInput slice; + final byte[] vectorValue; + final ByteBuffer byteBuffer; + final int byteSize; + private int lastOrd = -1; + final float[] correctiveValues; + int quantizedComponentSum; + final float[] centroid; + final float centroidDp; + + DenseOffHeapScalarQuantizedVectorValues( + int dimension, + int size, + VectorSimilarityFunction similarityFunction, + IndexInput slice, + float[] centroid, + float centroidDp + ) { + this.dimension = dimension; + this.size = size; + this.similarityFunction = similarityFunction; + this.slice = slice; + this.centroid = centroid; + this.centroidDp = centroidDp; + this.correctiveValues = new float[3]; + this.byteSize = dimension + (Float.BYTES * 3) + Integer.BYTES; + this.byteBuffer = ByteBuffer.allocate(dimension); + this.vectorValue = byteBuffer.array(); + } + + @Override + public IndexInput getSlice() { + return slice; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int vectorOrd) throws IOException { + if (lastOrd != vectorOrd) { + slice.seek((long) vectorOrd * byteSize); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), vectorValue.length); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = slice.readInt(); + lastOrd = vectorOrd; + } + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], + correctiveValues[1], + correctiveValues[2], + quantizedComponentSum + ); + } + + @Override + public OptimizedScalarQuantizer 
getQuantizer() { + return scalarQuantizer(similarityFunction); + } + + @Override + public Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding getScalarEncoding() { + return Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SEVEN_BIT; + } + + @Override + public float[] getCentroid() throws IOException { + return centroid; + } + + @Override + public float getCentroidDP() throws IOException { + return centroidDp; + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + assert false; + return null; + } + + @Override + public byte[] vectorValue(int ord) throws IOException { + if (lastOrd == ord) { + return vectorValue; + } + slice.seek((long) ord * byteSize); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), vectorValue.length); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = slice.readInt(); + lastOrd = ord; + return vectorValue; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; + } + + @Override + public QuantizedByteVectorValues copy() throws IOException { + return new DenseOffHeapScalarQuantizedVectorValues(dimension, size, similarityFunction, slice.clone(), centroid, centroidDp); + } + } +} diff --git a/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/TransportDataStreamsStatsAction.java b/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/TransportDataStreamsStatsAction.java index 0ff562d47b4fe..7f69b5ea19353 100644 --- a/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/TransportDataStreamsStatsAction.java +++ b/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/TransportDataStreamsStatsAction.java @@ -142,7 +142,7 @@ private static long getMaxTimestamp(IndexShard indexShard) throws IOException { return LongPoint.decodeDimension(maxPackedValue, 0); } // DocValuesSkipper.globalMaxValue() can return a negative number - return Math.max(0, 
DocValuesSkipper.globalMaxValue(searcher, DataStream.TIMESTAMP_FIELD_NAME)); + return Math.max(0, DocValuesSkipper.globalMaxValue(indexReader, DataStream.TIMESTAMP_FIELD_NAME)); } } diff --git a/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/DenseVectorRollingUpgradeIT.java b/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/DenseVectorRollingUpgradeIT.java index ed758f3cacca5..8fccb916396cb 100644 --- a/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/DenseVectorRollingUpgradeIT.java +++ b/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/DenseVectorRollingUpgradeIT.java @@ -13,6 +13,7 @@ import org.apache.http.util.EntityUtils; import org.elasticsearch.client.Request; import org.elasticsearch.client.Response; +import org.elasticsearch.client.ResponseException; import org.elasticsearch.common.Strings; import org.elasticsearch.index.mapper.MapperFeatures; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; @@ -165,32 +166,14 @@ public void testDenseVectorIndexOverUpgrade() throws IOException { if (dims.isEmpty()) { continue; } - Request createIndex = new Request("PUT", "/" + indexName(i, elementType, directIO)); - - XContentBuilder payload = XContentBuilder.builder(XContentType.JSON.xContent()).startObject(); - if (useSyntheticSource) { - payload.startObject("settings").field("index.mapping.source.mode", "synthetic").endObject(); - } - payload.startObject("mappings"); - payload.startObject("properties") - .startObject("embedding") - .field("type", "dense_vector") - .field("element_type", elementType) - .field("index", i.index()) - .field("dims", elementType == ElementType.BIT ? 
dims.getAsInt() * 8 : dims.getAsInt()); - if (i.index()) { - payload.field("similarity", "l2_norm"); - } - if (i.type() != null) { - payload.startObject("index_options").field("type", i.type()); - if (directIO) { - payload.field("on_disk_rescore", true); - } - payload.endObject(); - } - payload.endObject().endObject().endObject().endObject(); - createIndex.setJsonEntity(Strings.toString(payload)); - client().performRequest(createIndex); + createDenseVectorIndex( + indexName(i, elementType, directIO), + i, + elementType, + dims.getAsInt(), + directIO, + useSyntheticSource + ); } } } @@ -204,13 +187,37 @@ public void testDenseVectorIndexOverUpgrade() throws IOException { continue; } String indexName = indexName(i, elementType, directIO); + // Allow reproducing a single upgraded phase in isolation. + boolean indexExisted = indexExists(indexName); + if (indexExisted == false) { + try { + assertBusy(() -> { + try { + createDenseVectorIndex(indexName, i, elementType, dims.getAsInt(), directIO, false); + } catch (ResponseException e) { + if (e.getResponse().getStatusLine().getStatusCode() == 400 + && EntityUtils.toString(e.getResponse().getEntity(), StandardCharsets.UTF_8) + .contains("resource_already_exists_exception")) { + return; + } + throw e; + } + }); + } catch (Exception e) { + if (e instanceof IOException ioException) { + throw ioException; + } + throw new RuntimeException(e); + } + } + int existingCount = indexExisted ? 
readCount(indexName) : 0; Request index = new Request("POST", "/" + indexName + "/_bulk/"); index.addParameter("refresh", "true"); index.setJsonEntity(generateBulkData(upgradedNodes, dims.getAsInt())); assertOK(client().performRequest(index)); - int count = (upgradedNodes + 1) * 10; + int count = existingCount + 10; assertCount(indexName, count); checkQuery(indexName, dims.getAsInt(), count); if (i.index()) { @@ -242,16 +249,53 @@ private static OptionalInt getDimensions(String type, ElementType elementType, b return OptionalInt.of(8); } + private void createDenseVectorIndex( + String indexName, + Index indexConfig, + ElementType elementType, + int dims, + boolean directIO, + boolean useSyntheticSource + ) throws IOException { + Request createIndex = new Request("PUT", "/" + indexName); + XContentBuilder payload = XContentBuilder.builder(XContentType.JSON.xContent()).startObject(); + if (useSyntheticSource) { + payload.startObject("settings").field("index.mapping.source.mode", "synthetic").endObject(); + } + payload.startObject("mappings"); + payload.startObject("properties") + .startObject("embedding") + .field("type", "dense_vector") + .field("element_type", elementType) + .field("index", indexConfig.index()) + .field("dims", elementType == ElementType.BIT ? 
dims * 8 : dims); + if (indexConfig.index()) { + payload.field("similarity", "l2_norm"); + } + if (indexConfig.type() != null) { + payload.startObject("index_options").field("type", indexConfig.type()); + if (directIO) { + payload.field("on_disk_rescore", true); + } + payload.endObject(); + } + payload.endObject().endObject().endObject().endObject(); + createIndex.setJsonEntity(Strings.toString(payload)); + assertOK(client().performRequest(createIndex)); + } + private void assertCount(String index, int count) throws IOException { + assertEquals("Failed on index " + index, count, readCount(index)); + } + + private int readCount(String index) throws IOException { Request request = new Request("POST", "/" + index + "/_search"); request.addParameter(TOTAL_HITS_AS_INT_PARAM, "true"); request.addParameter("filter_path", "hits.total"); Response searchTestIndexResponse = client().performRequest(request); - assertEquals( - "Failed on index " + index, - "{\"hits\":{\"total\":" + count + "}}", - EntityUtils.toString(searchTestIndexResponse.getEntity(), StandardCharsets.UTF_8) - ); + String body = EntityUtils.toString(searchTestIndexResponse.getEntity(), StandardCharsets.UTF_8); + int marker = body.lastIndexOf(':'); + return Integer.parseInt(body.substring(marker + 1).replaceAll("[^0-9]", "")); } private void checkQuery(String index, int dims, int expected) throws IOException { diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java index 7acf54a3f0b5a..fb4e9f4c897c7 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java @@ -15,7 +15,7 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; -import org.apache.lucene.codecs.lucene103.Lucene103Codec; +import 
org.apache.lucene.codecs.lucene104.Lucene104Codec; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LogByteSizeMergePolicy; @@ -40,9 +40,9 @@ import org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93FlatVectorFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; -import org.elasticsearch.index.codec.vectors.es93.ES93HnswScalarQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat; -import org.elasticsearch.index.codec.vectors.es93.ES93ScalarQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es94.ES94HnswScalarQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es94.ES94ScalarQuantizedVectorsFormat; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.logging.Level; import org.elasticsearch.logging.LogManager; @@ -246,37 +246,44 @@ yield new ESNextDiskBBQVectorsFormat( ); }; case HNSW -> switch (quantizeBits) { - case null -> new ES93HnswVectorsFormat(args.hnswM(), args.hnswEfConstruction(), elementType, mergeWorkers, exec); + case null -> new ES93HnswVectorsFormat( + args.hnswM(), + args.hnswEfConstruction(), + elementType, + mergeWorkers, + exec, + args.flatVectorThreshold() + ); case 1 -> new ES93HnswBinaryQuantizedVectorsFormat( args.hnswM(), args.hnswEfConstruction(), elementType, false, mergeWorkers, - exec + exec, + args.flatVectorThreshold() ); - default -> new ES93HnswScalarQuantizedVectorsFormat( + default -> new ES94HnswScalarQuantizedVectorsFormat( args.hnswM(), args.hnswEfConstruction(), elementType, - null, quantizeBits, - true, false, mergeWorkers, - exec + exec, + args.flatVectorThreshold() ); }; case FLAT -> switch (quantizeBits) { case null -> new ES93FlatVectorFormat(elementType); case 1 -> new 
ES93BinaryQuantizedVectorsFormat(elementType, false); - default -> new ES93ScalarQuantizedVectorsFormat(elementType, null, quantizeBits, true, false); + default -> new ES94ScalarQuantizedVectorsFormat(elementType, quantizeBits, false); }; }; logger.info("Using format {}", format.getName()); - return new Lucene103Codec() { + return new Lucene104Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { return new KnnVectorsFormat(format.getName()) { diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/TestConfiguration.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/TestConfiguration.java index 4fd97b379655a..24bdfdc68bb54 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/TestConfiguration.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/TestConfiguration.java @@ -311,6 +311,7 @@ static class Builder implements ToXContentObject { private int numMergeWorkers = 1; private int flatVectorThreshold = -1; // -1 mean use default (vectorPerCluster * 3) private int secondaryClusterSize = -1; + private int flatIndexThreshold = -1; // use format's default threshold private String directoryType = "default"; /** @@ -403,6 +404,11 @@ public Builder setHnswEfConstruction(int hnswEfConstruction) { return this; } + public Builder setFlatIndexThreshold(int flatIndexThreshold) { + this.flatIndexThreshold = flatIndexThreshold; + return this; + } + public Builder setSearchThreads(List searchThreads) { this.searchThreads = searchThreads; return this; diff --git a/rest-api-spec/build.gradle b/rest-api-spec/build.gradle index a9b439334be9f..cff9e957eb486 100644 --- a/rest-api-spec/build.gradle +++ b/rest-api-spec/build.gradle @@ -59,6 +59,7 @@ tasks.named("yamlRestCompatTestTransform").configure ({ task -> task.replaceValueInMatch("profile.shards.0.dfs.knn.0.query.0.description", "DocAndScoreQuery[0,...][0.009673266,...],0.009673266", "dfs knn vector profiling with vector_operations_count") 
task.replaceValueInMatch("profile.shards.0.dfs.knn.0.collector.0.name", "TopScoreDocCollector", "dfs knn vector profiling") task.replaceValueInMatch("profile.shards.0.dfs.knn.0.collector.0.name", "TopScoreDocCollector", "dfs knn vector profiling with vector_operations_count") + task.replaceValueInMatch("test_index.mappings.properties.embedding.index_options.confidence_interval", null) task.skipTest("cat.aliases/10_basic/Deprecated local parameter", "CAT APIs not covered by compatibility policy") task.skipTest("cat.shards/10_basic/Help", "sync_id is removed in 9.0") task.skipTest("search/330_fetch_fields/Test with subobjects: auto", "subobjects auto removed") @@ -113,10 +114,28 @@ tasks.named("yamlRestCompatTestTransform").configure ({ task -> task.skipTest("tsdb/10_settings/set start_time and end_time without timeseries mode", "we don't validate for index_mode=tsdb when setting start_date/end_date anymore") task.skipTest("tsdb/10_settings/set start_time, end_time and routing_path via put settings api without time_series mode", "we don't validate for index_mode=tsdb when setting start_date/end_date anymore") task.skipTest("search/140_pre_filter_search_shards/prefilter on non-indexed date fields", "prefiltering can now use skippers on dv-only fields") + task.skipTest("search.vectors/42_knn_search_bbq_flat/Vector rescoring has same scoring as exact search for kNN section", "rescoring is now bulk, floating point values are slightly different") // Expected deprecation warning to compat yaml tests: task.addAllowedWarningRegex("Use of the \\[max_size\\] rollover condition has been deprecated in favour of the \\[max_primary_shard_size\\] condition and will be removed in a later version") + task.addAllowedWarningRegex( + "Parameter \\[confidence_interval\\] in \\[index_options\\] for dense_vector field \\[[^\\]]+\\] is deprecated and will be removed in a future version" + ) task.skipTest("search.vectors/42_knn_search_bbq_flat/Vector rescoring has same scoring as exact 
search for kNN section", "scores have changed slightly with native implementations") task.skipTest("search.vectors/41_knn_search_bbq_hnsw/Vector rescoring has same scoring as exact search for kNN section", "scores have changed slightly with native implementations") + task.skipTest("search.vectors/41_knn_search_byte_quantized/KNN Vector similarity search only", "min similarity score changed with Lucene 10.4") + task.skipTest("search.vectors/42_knn_search_int8_flat/KNN Vector similarity search only", "min similarity score changed with Lucene 10.4") + task.skipTest( + "search.vectors/180_update_dense_vector_type/Allowed dense vector updates on same type but different other index_options, int8_hnsw", + "confidence_interval handling differs between old/new versions during compat runs" + ) + task.skipTest( + "search.vectors/180_update_dense_vector_type/Allowed dense vector updates on same type but different other index_options, int8_flat", + "confidence_interval handling differs between old/new versions during compat runs" + ) + task.skipTest( + "search.vectors/180_update_dense_vector_type/Allowed dense vector updates on same type but different other index_options, int4_flat", + "confidence_interval handling differs between old/new versions during compat runs" + ) task.skipTest( "get/100_synthetic_source/fields with ignore_malformed", "Malformed values are now stored in binary doc values which sort differently than stored fields" diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/result-diversification/10_mmr_result_diversification_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/result-diversification/10_mmr_result_diversification_retriever.yml index 5022ca8bcd7ee..663290a3e27e3 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/result-diversification/10_mmr_result_diversification_retriever.yml +++ 
b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/result-diversification/10_mmr_result_diversification_retriever.yml @@ -2,6 +2,9 @@ setup: - requires: cluster_features: [ "retriever.result_diversification_mmr" ] reason: "Added retriever for result diversification using MMR" + - requires: + cluster_features: [ "lucene_10_4_upgrade" ] + reason: "MMR result ordering changed with Lucene 10.4" - requires: test_runner_features: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/180_update_dense_vector_type.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/180_update_dense_vector_type.yml index db4cdc5457612..f2e7b2562a8bf 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/180_update_dense_vector_type.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/180_update_dense_vector_type.yml @@ -1442,7 +1442,6 @@ setup: type: int8_hnsw m: 32 ef_construction: 200 - confidence_interval: 0.3 - do: indices.get_mapping: @@ -1452,7 +1451,6 @@ setup: - match: { test_index.mappings.properties.embedding.index_options.type: int8_hnsw } - match: { test_index.mappings.properties.embedding.index_options.m: 32 } - match: { test_index.mappings.properties.embedding.index_options.ef_construction: 200 } - - match: { test_index.mappings.properties.embedding.index_options.confidence_interval: 0.3 } - do: catch: /illegal_argument_exception/ # fails because m = 10 is less than the current value of 20 @@ -1467,7 +1465,6 @@ setup: type: int8_hnsw ef_construction: 200 m: 10 - confidence_interval: 0.3 - do: catch: /illegal_argument_exception/ # fails because m = 16 by default, which is less than the current value of 20 @@ -1481,7 +1478,6 @@ setup: index_options: type: int8_hnsw ef_construction: 200 - confidence_interval: 0.3 --- "Allowed dense vector updates on same type but different other index_options, int4_hnsw": @@ -1614,7 +1610,6 @@ 
setup: dims: 4 index_options: type: int8_flat - confidence_interval: 0.3 - do: indices.get_mapping: @@ -1622,7 +1617,6 @@ setup: - match: { test_index.mappings.properties.embedding.type: dense_vector } - match: { test_index.mappings.properties.embedding.index_options.type: int8_flat } - - match: { test_index.mappings.properties.embedding.index_options.confidence_interval: 0.3 } --- "Allowed dense vector updates on same type but different other index_options, int4_flat": @@ -1663,7 +1657,6 @@ setup: dims: 4 index_options: type: int4_flat - confidence_interval: 0.3 - do: indices.get_mapping: @@ -1671,7 +1664,6 @@ setup: - match: { test_index.mappings.properties.embedding.type: dense_vector } - match: { test_index.mappings.properties.embedding.index_options.type: int4_flat } - - match: { test_index.mappings.properties.embedding.index_options.confidence_interval: 0.3 } --- "Test create and update dense vector mapping to int4 with per-doc indexing and flush": diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/220_dense_vector_node_index_stats.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/220_dense_vector_node_index_stats.yml index 27e7b3ae55e25..20a3a42ceeeef 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/220_dense_vector_node_index_stats.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/220_dense_vector_node_index_stats.yml @@ -6,6 +6,9 @@ capabilities: [ dense_vector_off_heap_stats ] test_runner_features: [ capabilities ] reason: Capability required to run test + - requires: + cluster_features: [ "search.vectors.flat_index_threshold" ] + reason: "Uses index_options.flat_index_threshold" - do: indices.create: @@ -25,6 +28,7 @@ similarity: l2_norm index_options: type: hnsw + flat_index_threshold: 0 vector2: type: dense_vector "element_type": "float" @@ -41,6 +45,7 @@ similarity: l2_norm index_options: type: hnsw + 
flat_index_threshold: 0 vector4: type: dense_vector "element_type": "byte" @@ -167,6 +172,9 @@ capabilities: [ dense_vector_off_heap_stats ] test_runner_features: [ capabilities ] reason: Capability required to run test + - requires: + cluster_features: [ "search.vectors.flat_index_threshold" ] + reason: "Uses index_options.flat_index_threshold" - do: indices.create: @@ -185,6 +193,7 @@ similarity: l2_norm index_options: type: int8_hnsw + flat_index_threshold: 0 vector2: type: dense_vector dims: 4 @@ -199,6 +208,7 @@ similarity: l2_norm index_options: type: int4_hnsw + flat_index_threshold: 0 vector4: type: dense_vector dims: 4 @@ -322,6 +332,9 @@ capabilities: [ dense_vector_off_heap_stats ] test_runner_features: [ capabilities ] reason: Capability required to run test + - requires: + cluster_features: [ "search.vectors.flat_index_threshold" ] + reason: "Uses index_options.flat_index_threshold" - requires: capabilities: - method: POST @@ -347,6 +360,7 @@ similarity: l2_norm index_options: type: bbq_hnsw + flat_index_threshold: 0 vector2: type: dense_vector dims: 64 @@ -467,6 +481,9 @@ capabilities: [ dense_vector_off_heap_stats ] test_runner_features: [ capabilities ] reason: Capability required to run test + - requires: + cluster_features: [ "search.vectors.flat_index_threshold" ] + reason: "Uses index_options.flat_index_threshold" - do: indices.create: @@ -486,6 +503,7 @@ similarity: l2_norm index_options: type: hnsw + flat_index_threshold: 0 vector2: type: dense_vector "element_type": "bit" @@ -581,6 +599,9 @@ capabilities: [ dense_vector_off_heap_stats ] test_runner_features: [ capabilities ] reason: Capability required to run test + - requires: + cluster_features: [ "search.vectors.flat_index_threshold" ] + reason: "Uses index_options.flat_index_threshold" - do: indices.create: @@ -599,6 +620,7 @@ similarity: l2_norm index_options: type: int8_hnsw + flat_index_threshold: 0 vector2: type: dense_vector dims: 5 @@ -606,6 +628,7 @@ similarity: l2_norm 
index_options: type: int8_hnsw + flat_index_threshold: 0 - do: nodes.stats: @@ -653,6 +676,9 @@ capabilities: [ dense_vector_off_heap_stats ] test_runner_features: [ capabilities ] reason: Capability required to run test + - requires: + cluster_features: [ "search.vectors.flat_index_threshold" ] + reason: "Uses index_options.flat_index_threshold" - requires: capabilities: - method: POST @@ -678,6 +704,7 @@ similarity: l2_norm index_options: type: bbq_hnsw + flat_index_threshold: 0 - do: indices.create: @@ -713,6 +740,7 @@ similarity: l2_norm index_options: type: hnsw + flat_index_threshold: 0 - do: index: @@ -837,6 +865,9 @@ capabilities: [ dense_vector_off_heap_stats ] test_runner_features: [ capabilities ] reason: Capability required to run test + - requires: + cluster_features: [ "search.vectors.flat_index_threshold" ] + reason: "Uses index_options.flat_index_threshold" - do: indices.create: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml index b297770eb9104..0b805d50a741a 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml @@ -194,6 +194,9 @@ setup: --- "KNN Vector similarity search only": + - requires: + cluster_features: [ "lucene_10_4_upgrade" ] + reason: "min similarity score changed with Lucene 10.4" - do: search: index: hnsw_byte_quantized @@ -203,7 +206,7 @@ setup: num_candidates: 3 k: 3 field: vector - similarity: 10.3 + similarity: 10.51 query_vector: [-0.5, 90.0, -10, 14.8, -156.0] - length: {hits.hits: 1} diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized_bfloat16.yml 
b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized_bfloat16.yml index 80e619a9ad318..bf6436b5bb7d7 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized_bfloat16.yml @@ -205,7 +205,7 @@ setup: num_candidates: 3 k: 3 field: vector - similarity: 10.3 + similarity: 10.51 query_vector: [-0.5, 90.0, -10, 14.8, -156.0] - length: {hits.hits: 1} diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml index 344852df0221b..cc7eb606b7952 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml @@ -113,7 +113,9 @@ setup: path: /_search capabilities: [knn_quantized_vector_rescore_oversample] - skip: - features: "headers" + features: + - headers + - close_to # Rescore - do: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml index 4e0c9426c80b7..810d29b0e0c42 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml @@ -20,7 +20,6 @@ setup: similarity: l2_norm index_options: type: int8_flat - confidence_interval: 0.9 another_vector: type: dense_vector dims: 5 @@ -28,7 +27,6 @@ setup: similarity: l2_norm index_options: type: int8_flat - confidence_interval: 0.9 - do: index: @@ -180,6 +178,9 
@@ setup: --- "KNN Vector similarity search only": + - requires: + cluster_features: [ "lucene_10_4_upgrade" ] + reason: "min similarity score changed with Lucene 10.4" - do: search: index: int8_flat @@ -189,7 +190,7 @@ setup: num_candidates: 3 k: 3 field: vector - similarity: 10.3 + similarity: 10.51 query_vector: [-0.5, 90.0, -10, 14.8, -156.0] - length: {hits.hits: 1} diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat_bfloat16.yml index 2eb00213944f2..7c77fc6114bc9 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat_bfloat16.yml @@ -21,7 +21,6 @@ setup: similarity: l2_norm index_options: type: int8_flat - confidence_interval: 0.9 another_vector: type: dense_vector element_type: bfloat16 @@ -30,7 +29,6 @@ setup: similarity: l2_norm index_options: type: int8_flat - confidence_interval: 0.9 - do: index: @@ -182,6 +180,9 @@ setup: --- "KNN Vector similarity search only": + - requires: + cluster_features: [ "lucene_10_4_upgrade" ] + reason: "min similarity score changed with Lucene 10.4" - do: search: index: int8_flat @@ -191,7 +192,7 @@ setup: num_candidates: 3 k: 3 field: vector - similarity: 10.3 + similarity: 10.51 query_vector: [-0.5, 90.0, -10, 14.8, -156.0] - length: {hits.hits: 1} diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/370_profile.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/370_profile.yml index c1fdb8adc8ee9..1317ca4c770d7 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/370_profile.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/370_profile.yml @@ -236,8 +236,8 @@ dfs knn vector 
profiling: --- dfs knn vector profiling description: - requires: - cluster_features: ["lucene_10_upgrade"] - reason: "the profile description changed with Lucene 10" + cluster_features: ["lucene_10_4_upgrade"] + reason: "the profile description changed with Lucene 10.4" - do: indices.create: index: images @@ -272,7 +272,7 @@ dfs knn vector profiling description: num_candidates: 100 - match: { hits.total.value: 1 } - - match: { profile.shards.0.dfs.knn.0.query.0.description: "DocAndScoreQuery[0,...][0.009673266,...],0.009673266" } + - match: { profile.shards.0.dfs.knn.0.query.0.description: "DocAndScoreQuery[0,...][0.008547009,...],0.008547009" } --- dfs knn vector profiling collector name: diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/mapper/vectors/HnswGraphThresholdIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/mapper/vectors/HnswGraphThresholdIT.java new file mode 100644 index 0000000000000..3b9d65b39bf45 --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/mapper/vectors/HnswGraphThresholdIT.java @@ -0,0 +1,360 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.mapper.vectors; + +import org.apache.lucene.tests.util.LuceneTestCase; +import org.elasticsearch.action.admin.indices.stats.IndicesStatsResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.shard.DenseVectorStats; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat.BBQ_HNSW_GRAPH_THRESHOLD; +import static org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.notNullValue; + +/** + * Integration tests for HNSW graph threshold setting. + * Tests that the graph is conditionally built based on the expected search cost threshold. + * The threshold represents the minimum expected search cost before building an HNSW graph becomes worthwhile. 
+ */ +@LuceneTestCase.SuppressCodecs("*") // only use our own codecs to ensure HNSW threshold is applied +public class HnswGraphThresholdIT extends ESIntegTestCase { + + private static final String INDEX_NAME = "hnsw_threshold_test"; + private static final String VECTOR_FIELD = "vector"; + private static final int DIMENSIONS = 64; + + // Number of vectors needed to exceed the threshold (based on search power calculation: + // graph is built when numVectors > log(numVectors) * threshold) + private static final int BBQ_HNSW_VECTORS_FOR_GRAPH = 2327; + private static final int HNSW_VECTORS_FOR_GRAPH = 1045; + + /** + * Tests that with default threshold, graph is NOT built for small vector counts, + * but IS built when vectors exceed the search power threshold. + */ + public void testGraphThresholdWithDefaultSettings() throws Exception { + IndexTypeConfig config = randomIndexTypeConfig(); + logger.info("Testing with index type: {}, element type: {}", config.indexType, config.elementType); + + // Create index with default threshold + assertAcked( + prepareCreate(INDEX_NAME).setSettings( + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + ).setMapping(createMapping(config.indexType, config.elementType)) + ); + ensureGreen(INDEX_NAME); + + // Index vectors below the threshold + int smallDocCount = 100; + indexVectors(smallDocCount, 0, config.elementType); + flushAndRefresh(INDEX_NAME); + + // Verify: with small vector count, graph should NOT be built + DenseVectorStats statsBeforeThreshold = getDenseVectorStats(); + assertThat("Should have indexed vectors", statsBeforeThreshold.getValueCount(), equalTo((long) smallDocCount)); + Long vexSizeSmall = getVexSize(statsBeforeThreshold); + assertTrue( + "Graph should NOT be built with " + smallDocCount + " vectors (below threshold for " + config.indexType + ")", + vexSizeSmall == null || vexSizeSmall == 0L + ); + + // Index more vectors to exceed the threshold + 
int additionalDocs = config.vectorsForGraph; + indexVectors(additionalDocs, smallDocCount, config.elementType); + flushAndRefresh(INDEX_NAME); + + // Force merge to ensure all vectors are in one segment + forceMergeIndex(); + + // Verify: with vectors above threshold, graph SHOULD be built + DenseVectorStats statsAfterThreshold = getDenseVectorStats(); + int totalDocs = smallDocCount + additionalDocs; + assertThat("Should have indexed all vectors", statsAfterThreshold.getValueCount(), equalTo((long) totalDocs)); + Long vexSizeLarge = getVexSize(statsAfterThreshold); + assertThat( + "Graph SHOULD be built with " + totalDocs + " vectors (above threshold for " + config.indexType + ")", + vexSizeLarge, + notNullValue() + ); + assertThat("Graph size should be positive", vexSizeLarge, greaterThan(0L)); + } + + /** + * Tests that setting flat_index_threshold=0 forces graph to always be built, even with few vectors. + */ + public void testGraphAlwaysBuiltWithThresholdZero() throws Exception { + IndexTypeConfig config = randomIndexTypeConfig(); + logger.info("Testing flat_index_threshold=0 with index type: {}, element type: {}", config.indexType, config.elementType); + + // Create index with flat_index_threshold=0 in mapping (always build graph) + assertAcked( + prepareCreate(INDEX_NAME).setSettings( + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + ).setMapping(createMappingWithThreshold(config.indexType, config.elementType, 0)) + ); + ensureGreen(INDEX_NAME); + + // Index small number of vectors + int smallDocCount = 10; + indexVectors(smallDocCount, 0, config.elementType); + flushAndRefresh(INDEX_NAME); + + // Verify: with flat_index_threshold=0, graph SHOULD be built even with few vectors + DenseVectorStats stats = getDenseVectorStats(); + assertThat("Should have indexed vectors", stats.getValueCount(), equalTo((long) smallDocCount)); + Long vexSize = getVexSize(stats); + assertThat("Graph SHOULD be 
built with flat_index_threshold=0", vexSize, notNullValue()); + assertThat("Graph size should be positive", vexSize, greaterThan(0L)); + } + + /** + * Tests that updating flat_index_threshold from default to 0 causes graph to be built on new segments. + */ + public void testMappingUpdateToThresholdZero() throws Exception { + IndexTypeConfig config = randomIndexTypeConfig(); + logger.info("Testing mapping update with index type: {}, element type: {}", config.indexType, config.elementType); + + // Create index with default threshold (no flat_index_threshold specified) + assertAcked( + prepareCreate(INDEX_NAME).setSettings( + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + ).setMapping(createMapping(config.indexType, config.elementType)) + ); + ensureGreen(INDEX_NAME); + + // Index small number of vectors - graph should NOT be built with default threshold + int initialDocCount = 50; + indexVectors(initialDocCount, 0, config.elementType); + flushAndRefresh(INDEX_NAME); + + DenseVectorStats statsBeforeUpdate = getDenseVectorStats(); + assertThat("Should have indexed vectors", statsBeforeUpdate.getValueCount(), equalTo((long) initialDocCount)); + Long vexSizeBeforeUpdate = getVexSize(statsBeforeUpdate); + assertTrue( + "Graph should NOT be built with default threshold and few vectors", + vexSizeBeforeUpdate == null || vexSizeBeforeUpdate == 0L + ); + + // Update mapping to set flat_index_threshold=0 (always build graph) + assertAcked( + indicesAdmin().preparePutMapping(INDEX_NAME).setSource(createMappingWithThreshold(config.indexType, config.elementType, 0)) + ); + + // Index more vectors - these will go to a new segment with updated threshold + int additionalDocCount = 50; + indexVectors(additionalDocCount, initialDocCount, config.elementType); + flushAndRefresh(INDEX_NAME); + + // Verify: new segment should have graph built with threshold=0 + DenseVectorStats statsAfterUpdate = getDenseVectorStats(); 
+ int totalDocs = initialDocCount + additionalDocCount; + assertThat("Should have indexed all vectors", statsAfterUpdate.getValueCount(), equalTo((long) totalDocs)); + Long vexSizeAfterUpdate = getVexSize(statsAfterUpdate); + assertThat("Graph SHOULD be built after updating flat_index_threshold to 0", vexSizeAfterUpdate, notNullValue()); + assertThat("Graph size should be positive", vexSizeAfterUpdate, greaterThan(0L)); + } + + /** + * Tests that multiple fields can have different flat_index_threshold values. + * One field with threshold=0 should have graph built, another with default threshold should not. + */ + public void testMultipleFields() throws Exception { + IndexTypeConfig config = randomIndexTypeConfig(); + logger.info("Testing multiple fields with index type: {}, element type: {}", config.indexType, config.elementType); + + String fieldWithThresholdZero = "vector_with_graph"; + String fieldWithDefaultThreshold = "vector_without_graph"; + + // Create index with two fields: one with threshold=0, one with default + XContentBuilder mapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldWithThresholdZero) + .field("type", "dense_vector") + .field("dims", DIMENSIONS) + .field("index", true) + .field("similarity", "l2_norm") + .field("element_type", config.elementType) + .startObject("index_options") + .field("type", config.indexType) + .field("flat_index_threshold", 0) + .endObject() + .endObject() + .startObject(fieldWithDefaultThreshold) + .field("type", "dense_vector") + .field("dims", DIMENSIONS) + .field("index", true) + .field("similarity", "l2_norm") + .field("element_type", config.elementType) + .startObject("index_options") + .field("type", config.indexType) + .endObject() + .endObject() + .endObject() + .endObject(); + + assertAcked( + prepareCreate(INDEX_NAME).setSettings( + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + 
).setMapping(mapping) + ); + ensureGreen(INDEX_NAME); + + // Index small number of vectors to both fields + int docCount = 50; + for (int i = 0; i < docCount; i++) { + Object vector = generateVector(config.elementType); + client().index( + new IndexRequest(INDEX_NAME).id(String.valueOf(i)) + .source(Map.of(fieldWithThresholdZero, vector, fieldWithDefaultThreshold, vector)) + ).actionGet(); + } + flushAndRefresh(INDEX_NAME); + + // Verify: field with threshold=0 should have graph, field with default should not + DenseVectorStats stats = getDenseVectorStats(); + + Long vexSizeWithThresholdZero = getVexSize(stats, fieldWithThresholdZero); + assertThat("Graph SHOULD be built for field with flat_index_threshold=0", vexSizeWithThresholdZero, notNullValue()); + assertThat("Graph size should be positive", vexSizeWithThresholdZero, greaterThan(0L)); + + Long vexSizeWithDefault = getVexSize(stats, fieldWithDefaultThreshold); + assertTrue( + "Graph should NOT be built for field with default threshold and few vectors", + vexSizeWithDefault == null || vexSizeWithDefault == 0L + ); + } + + private record IndexTypeConfig(String indexType, String elementType, int threshold, int vectorsForGraph) {} + + private IndexTypeConfig randomIndexTypeConfig() { + String indexType = randomFrom("hnsw", "int8_hnsw", "int4_hnsw", "bbq_hnsw"); + String elementType; + int threshold; + int vectorsForGraph; + + switch (indexType) { + case "bbq_hnsw" -> { + elementType = randomFrom("float", "bfloat16"); + threshold = BBQ_HNSW_GRAPH_THRESHOLD; + vectorsForGraph = BBQ_HNSW_VECTORS_FOR_GRAPH; + } + case "int8_hnsw", "int4_hnsw" -> { + elementType = randomFrom("float", "bfloat16"); + threshold = HNSW_GRAPH_THRESHOLD; + vectorsForGraph = HNSW_VECTORS_FOR_GRAPH; + } + default -> { // hnsw + elementType = randomFrom("float", "byte", "bfloat16"); + threshold = HNSW_GRAPH_THRESHOLD; + vectorsForGraph = HNSW_VECTORS_FOR_GRAPH; + } + } + return new IndexTypeConfig(indexType, elementType, threshold, 
vectorsForGraph); + } + + private XContentBuilder createMapping(String indexType, String elementType) throws IOException { + return createMappingWithThreshold(indexType, elementType, null); + } + + private XContentBuilder createMappingWithThreshold(String indexType, String elementType, Integer graphThreshold) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(VECTOR_FIELD) + .field("type", "dense_vector") + .field("dims", DIMENSIONS) + .field("index", true) + .field("similarity", "l2_norm") + .field("element_type", elementType) + .startObject("index_options") + .field("type", indexType); + if (graphThreshold != null) { + builder.field("flat_index_threshold", graphThreshold); + } + builder.endObject().endObject().endObject().endObject(); + return builder; + } + + private void indexVectors(int count, int startId, String elementType) { + for (int i = 0; i < count; i++) { + Object vector = generateVector(elementType); + client().index(new IndexRequest(INDEX_NAME).id(String.valueOf(startId + i)).source(Map.of(VECTOR_FIELD, vector))).actionGet(); + } + } + + private Object generateVector(String elementType) { + if ("byte".equals(elementType)) { + List vector = new ArrayList<>(DIMENSIONS); + for (int j = 0; j < DIMENSIONS; j++) { + vector.add(randomIntBetween(-128, 127)); + } + return vector; + } else { + // float, bfloat16 - all use float values in the API + List vector = new ArrayList<>(DIMENSIONS); + for (int j = 0; j < DIMENSIONS; j++) { + vector.add(randomFloat()); + } + return vector; + } + } + + private void flushAndRefresh(String indexName) { + indicesAdmin().prepareFlush(indexName).get(); + indicesAdmin().prepareRefresh(indexName).get(); + } + + private void forceMergeIndex() { + indicesAdmin().prepareForceMerge(INDEX_NAME).setMaxNumSegments(1).get(); + flushAndRefresh(INDEX_NAME); + } + + private DenseVectorStats getDenseVectorStats() { + IndicesStatsResponse statsResponse = 
indicesAdmin().prepareStats(INDEX_NAME).setDenseVector(true).get(); + return statsResponse.getIndex(INDEX_NAME).getTotal().getDenseVectorStats(); + } + + /** + * Gets the vex (HNSW graph) size for the default vector field. + * Returns null if no graph exists. + */ + private Long getVexSize(DenseVectorStats stats) { + return getVexSize(stats, VECTOR_FIELD); + } + + /** + * Gets the vex (HNSW graph) size for a specific field. + * Returns null if no graph exists. + */ + private Long getVexSize(DenseVectorStats stats, String fieldName) { + if (stats.offHeapStats() == null) { + return null; + } + Map fieldStats = stats.offHeapStats().get(fieldName); + if (fieldStats == null) { + return null; + } + return fieldStats.get("vex"); + } +} diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index f97b0ba45d386..9720958bac81f 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -478,7 +478,9 @@ org.elasticsearch.index.codec.vectors.es93.ES93ScalarQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es93.ES93HnswScalarQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat, - org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; + org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.es94.ES94HnswScalarQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.es94.ES94ScalarQuantizedVectorsFormat; provides org.apache.lucene.codecs.Codec with @@ -487,6 +489,7 @@ org.elasticsearch.index.codec.Elasticsearch900Codec, org.elasticsearch.index.codec.Elasticsearch900Lucene101Codec, org.elasticsearch.index.codec.Elasticsearch92Lucene103Codec, + org.elasticsearch.index.codec.Elasticsearch93Lucene104Codec, org.elasticsearch.index.codec.tsdb.ES93TSDBDefaultCompressionLucene103Codec; provides org.apache.logging.log4j.core.util.ContextDataProvider with 
org.elasticsearch.common.logging.DynamicContextDataProvider; @@ -512,6 +515,7 @@ exports org.elasticsearch.index.codec.vectors.diskbbq.next to org.elasticsearch.test.knn, org.elasticsearch.xpack.diskbbq; exports org.elasticsearch.index.codec.vectors.cluster to org.elasticsearch.test.knn; exports org.elasticsearch.index.codec.vectors.es93 to org.elasticsearch.test.knn; + exports org.elasticsearch.index.codec.vectors.es94 to org.elasticsearch.test.knn; exports org.elasticsearch.search.crossproject; exports org.elasticsearch.index.mapper.blockloader; exports org.elasticsearch.index.mapper.blockloader.docvalues; diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzer.java b/server/src/main/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzer.java index 03c806e59b278..7938521cc31e1 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzer.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzer.java @@ -23,7 +23,7 @@ import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; -import org.apache.lucene.codecs.lucene103.Lucene103PostingsFormat; +import org.apache.lucene.codecs.lucene104.Lucene104PostingsFormat; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.DocValuesType; @@ -326,7 +326,7 @@ private static void readProximity(Terms terms, PostingsEnum postings) throws IOE private static BlockTermState getBlockTermState(TermsEnum termsEnum, BytesRef term) throws IOException { if (term != null && termsEnum.seekExact(term)) { final TermState termState = termsEnum.termState(); - if (termState instanceof final Lucene103PostingsFormat.IntBlockTermState blockTermState) { + if (termState instanceof final Lucene104PostingsFormat.IntBlockTermState 
blockTermState) { return new BlockTermState(blockTermState.docStartFP, blockTermState.posStartFP, blockTermState.payStartFP); } if (termState instanceof final Lucene101PostingsFormat.IntBlockTermState blockTermState) { diff --git a/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java b/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java index c740d2747cf91..0e87486d60bdd 100644 --- a/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java +++ b/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java @@ -96,7 +96,7 @@ public class Lucene { - public static final String LATEST_CODEC = "Lucene103"; + public static final String LATEST_CODEC = "Lucene104"; public static final String SOFT_DELETES_FIELD = "__soft_deletes"; diff --git a/server/src/main/java/org/elasticsearch/index/IndexVersions.java b/server/src/main/java/org/elasticsearch/index/IndexVersions.java index 6fcd3bcfb8cfa..84dba83ca1a1e 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexVersions.java +++ b/server/src/main/java/org/elasticsearch/index/IndexVersions.java @@ -225,6 +225,7 @@ private static Version parseUnchecked(String version) { public static final IndexVersion TIME_SERIES_DOC_VALUES_FORMAT_VERSION_3 = def(9_072_0_00, Version.LUCENE_10_3_2); public static final IndexVersion STORE_IGNORED_MALFORMED_IN_BINARY_DOC_VALUES = def(9_073_0_00, Version.LUCENE_10_3_2); public static final IndexVersion DISABLE_SEQUENCE_NUMBERS = def(9_074_0_00, Version.LUCENE_10_3_2); + public static final IndexVersion UPGRADE_TO_LUCENE_10_4_0 = def(9_075_0_00, Version.LUCENE_10_4_0); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/server/src/main/java/org/elasticsearch/index/codec/CodecService.java b/server/src/main/java/org/elasticsearch/index/codec/CodecService.java index 25c5106b47916..7bf77b77fc5c1 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/CodecService.java +++ b/server/src/main/java/org/elasticsearch/index/codec/CodecService.java @@ -12,7 +12,7 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.FieldInfosFormat; import org.apache.lucene.codecs.FilterCodec; -import org.apache.lucene.codecs.lucene103.Lucene103Codec; +import org.apache.lucene.codecs.lucene104.Lucene104Codec; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.FeatureFlag; import org.elasticsearch.core.Nullable; @@ -50,7 +50,7 @@ public CodecService(@Nullable MapperService mapperService, BigArrays bigArrays, boolean useSyntheticId = mapperService != null && mapperService.getIndexSettings().useTimeSeriesSyntheticId(); - var legacyBestSpeedCodec = new LegacyPerFieldMapperCodec(Lucene103Codec.Mode.BEST_SPEED, mapperService, bigArrays, threadPool); + var legacyBestSpeedCodec = new LegacyPerFieldMapperCodec(Lucene104Codec.Mode.BEST_SPEED, mapperService, bigArrays, threadPool); if (useSyntheticId) { // Use the default Lucene compression when the synthetic id is used even if the ZSTD feature flag is enabled codecs.put(DEFAULT_CODEC, new ES93TSDBDefaultCompressionLucene103Codec(legacyBestSpeedCodec)); @@ -69,7 +69,7 @@ public CodecService(@Nullable MapperService mapperService, BigArrays bigArrays, new PerFieldMapperCodec(Zstd814StoredFieldsFormat.Mode.BEST_COMPRESSION, mapperService, bigArrays, threadPool) ); Codec legacyBestCompressionCodec = new LegacyPerFieldMapperCodec( - Lucene103Codec.Mode.BEST_COMPRESSION, + Lucene104Codec.Mode.BEST_COMPRESSION, mapperService, bigArrays, threadPool diff --git a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch92Lucene103Codec.java 
b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch92Lucene103Codec.java index c26d485fc8c99..f9b47462e2a2c 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch92Lucene103Codec.java +++ b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch92Lucene103Codec.java @@ -9,12 +9,12 @@ package org.elasticsearch.index.codec; +import org.apache.lucene.backward_codecs.lucene103.Lucene103Codec; +import org.apache.lucene.backward_codecs.lucene103.Lucene103PostingsFormat; import org.apache.lucene.codecs.DocValuesFormat; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.PostingsFormat; import org.apache.lucene.codecs.StoredFieldsFormat; -import org.apache.lucene.codecs.lucene103.Lucene103Codec; -import org.apache.lucene.codecs.lucene103.Lucene103PostingsFormat; import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; diff --git a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch93Lucene104Codec.java b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch93Lucene104Codec.java new file mode 100644 index 0000000000000..5b39b6d7db3b6 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch93Lucene104Codec.java @@ -0,0 +1,133 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec; + +import org.apache.lucene.codecs.DocValuesFormat; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.PostingsFormat; +import org.apache.lucene.codecs.StoredFieldsFormat; +import org.apache.lucene.codecs.lucene104.Lucene104Codec; +import org.apache.lucene.codecs.lucene104.Lucene104PostingsFormat; +import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.codecs.perfield.PerFieldPostingsFormat; +import org.elasticsearch.index.codec.perfield.XPerFieldDocValuesFormat; +import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat; + +/** + * Elasticsearch codec as of 9.3 relying on Lucene 10.4. This extends the Lucene 10.4 codec to compressed + * stored fields with ZSTD instead of LZ4/DEFLATE. See {@link Zstd814StoredFieldsFormat}. + */ +public class Elasticsearch93Lucene104Codec extends CodecService.DeduplicateFieldInfosCodec { + + static final PostingsFormat DEFAULT_POSTINGS_FORMAT = new Lucene104PostingsFormat(); + + private final StoredFieldsFormat storedFieldsFormat; + + private final PostingsFormat defaultPostingsFormat; + private final PostingsFormat postingsFormat = new PerFieldPostingsFormat() { + @Override + public PostingsFormat getPostingsFormatForField(String field) { + return Elasticsearch93Lucene104Codec.this.getPostingsFormatForField(field); + } + }; + + private final DocValuesFormat defaultDVFormat; + private final DocValuesFormat docValuesFormat = new XPerFieldDocValuesFormat() { + @Override + public DocValuesFormat getDocValuesFormatForField(String field) { + return Elasticsearch93Lucene104Codec.this.getDocValuesFormatForField(field); + } + }; + + private final KnnVectorsFormat defaultKnnVectorsFormat; + private final KnnVectorsFormat knnVectorsFormat = new PerFieldKnnVectorsFormat() { + 
@Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return Elasticsearch93Lucene104Codec.this.getKnnVectorsFormatForField(field); + } + }; + + /** Public no-arg constructor, needed for SPI loading at read-time. */ + public Elasticsearch93Lucene104Codec() { + this(Zstd814StoredFieldsFormat.Mode.BEST_SPEED); + } + + /** + * Constructor. Takes a {@link Zstd814StoredFieldsFormat.Mode} that describes whether to optimize for retrieval speed at the expense of + * worse space-efficiency or vice-versa. + */ + public Elasticsearch93Lucene104Codec(Zstd814StoredFieldsFormat.Mode mode) { + super("Elasticsearch93Lucene104", new Lucene104Codec()); + this.storedFieldsFormat = mode.getFormat(); + this.defaultPostingsFormat = DEFAULT_POSTINGS_FORMAT; + this.defaultDVFormat = new Lucene90DocValuesFormat(); + this.defaultKnnVectorsFormat = new Lucene99HnswVectorsFormat(); + } + + @Override + public StoredFieldsFormat storedFieldsFormat() { + return storedFieldsFormat; + } + + @Override + public final PostingsFormat postingsFormat() { + return postingsFormat; + } + + @Override + public final DocValuesFormat docValuesFormat() { + return docValuesFormat; + } + + @Override + public final KnnVectorsFormat knnVectorsFormat() { + return knnVectorsFormat; + } + + /** + * Returns the postings format that should be used for writing new segments of field. + * + *

The default implementation always returns "Lucene104". + + *

WARNING: if you subclass, you are responsible for index backwards compatibility: + * future version of Lucene are only guaranteed to be able to read the default implementation, + */ + public PostingsFormat getPostingsFormatForField(String field) { + return defaultPostingsFormat; + } + + /** + * Returns the docvalues format that should be used for writing new segments of field + * . + * + *

The default implementation always returns "Lucene90". + + *

WARNING: if you subclass, you are responsible for index backwards compatibility: + * future version of Lucene are only guaranteed to be able to read the default implementation. + */ + public DocValuesFormat getDocValuesFormatForField(String field) { + return defaultDVFormat; + } + + /** + * Returns the vectors format that should be used for writing new segments of field + * + *

The default implementation always returns "Lucene99HnswVectorsFormat". + + *

WARNING: if you subclass, you are responsible for index backwards compatibility: + * future version of Lucene are only guaranteed to be able to read the default implementation. + */ + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return defaultKnnVectorsFormat; + } + +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/LegacyPerFieldMapperCodec.java b/server/src/main/java/org/elasticsearch/index/codec/LegacyPerFieldMapperCodec.java index 77d5fc540cfca..cf3451c61f9ec 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/LegacyPerFieldMapperCodec.java +++ b/server/src/main/java/org/elasticsearch/index/codec/LegacyPerFieldMapperCodec.java @@ -13,7 +13,7 @@ import org.apache.lucene.codecs.DocValuesFormat; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.PostingsFormat; -import org.apache.lucene.codecs.lucene103.Lucene103Codec; +import org.apache.lucene.codecs.lucene104.Lucene104Codec; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.index.mapper.MapperService; @@ -23,16 +23,11 @@ * Legacy version of {@link PerFieldMapperCodec}. This codec is preserved to give an escape hatch in case we encounter issues with new * changes in {@link PerFieldMapperCodec}. */ -public final class LegacyPerFieldMapperCodec extends Lucene103Codec { +public final class LegacyPerFieldMapperCodec extends Lucene104Codec { private final PerFieldFormatSupplier formatSupplier; - public LegacyPerFieldMapperCodec( - Lucene103Codec.Mode compressionMode, - MapperService mapperService, - BigArrays bigArrays, - ThreadPool threadPool - ) { + public LegacyPerFieldMapperCodec(Mode compressionMode, MapperService mapperService, BigArrays bigArrays, ThreadPool threadPool) { super(compressionMode); this.formatSupplier = new PerFieldFormatSupplier(mapperService, bigArrays, threadPool); // If the below assertion fails, it is a sign that Lucene released a new codec. 
You must create a copy of the current Elasticsearch diff --git a/server/src/main/java/org/elasticsearch/index/codec/PerFieldFormatSupplier.java b/server/src/main/java/org/elasticsearch/index/codec/PerFieldFormatSupplier.java index 642956959425e..28e7728f97688 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/PerFieldFormatSupplier.java +++ b/server/src/main/java/org/elasticsearch/index/codec/PerFieldFormatSupplier.java @@ -68,7 +68,7 @@ public class PerFieldFormatSupplier { private static final DocValuesFormat docValuesFormat = new Lucene90DocValuesFormat(); private final KnnVectorsFormat knnVectorsFormat; private static final ES812PostingsFormat es812PostingsFormat = new ES812PostingsFormat(); - private static final PostingsFormat completionPostingsFormat = PostingsFormat.forName("Completion101"); + private static final PostingsFormat completionPostingsFormat = PostingsFormat.forName("Completion104"); private final ES87BloomFilterPostingsFormat bloomFilterPostingsFormat; private final MapperService mapperService; @@ -95,7 +95,7 @@ private static PostingsFormat getDefaultPostingsFormat(final MapperService mappe if (IndexSettings.USE_ES_812_POSTINGS_FORMAT.get(mapperService.getIndexSettings().getSettings())) { return es812PostingsFormat; } else { - return Elasticsearch92Lucene103Codec.DEFAULT_POSTINGS_FORMAT; + return Elasticsearch93Lucene104Codec.DEFAULT_POSTINGS_FORMAT; } } else { // our own posting format using PFOR, used for logsdb and tsdb indices by default @@ -120,7 +120,8 @@ private static KnnVectorsFormat getDefaultKnnVectorsFormat(final MapperService m DEFAULT_BEAM_WIDTH, DenseVectorFieldMapper.ElementType.FLOAT, maxMergingWorkers, - mergingExecutorService + mergingExecutorService, + -1 ); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/PerFieldMapperCodec.java b/server/src/main/java/org/elasticsearch/index/codec/PerFieldMapperCodec.java index e4d1b2ebd18d4..4c18e3926a91e 100644 --- 
a/server/src/main/java/org/elasticsearch/index/codec/PerFieldMapperCodec.java +++ b/server/src/main/java/org/elasticsearch/index/codec/PerFieldMapperCodec.java @@ -27,7 +27,7 @@ * per index in real time via the mapping API. If no specific postings format or vector format is * configured for a specific field the default postings or vector format is used. */ -public final class PerFieldMapperCodec extends Elasticsearch92Lucene103Codec { +public final class PerFieldMapperCodec extends Elasticsearch93Lucene104Codec { private final PerFieldFormatSupplier formatSupplier; diff --git a/server/src/main/java/org/elasticsearch/index/codec/postings/ES812ScoreSkipReader.java b/server/src/main/java/org/elasticsearch/index/codec/postings/ES812ScoreSkipReader.java index f76e1026945e6..ee7421ce9eab7 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/postings/ES812ScoreSkipReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/postings/ES812ScoreSkipReader.java @@ -19,17 +19,14 @@ */ package org.elasticsearch.index.codec.postings; -import org.apache.lucene.index.Impact; +import org.apache.lucene.index.FreqAndNormBuffer; import org.apache.lucene.index.Impacts; import org.apache.lucene.store.ByteArrayDataInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.ArrayUtil; import java.io.IOException; -import java.util.AbstractList; import java.util.Arrays; -import java.util.List; -import java.util.RandomAccess; final class ES812ScoreSkipReader extends ES812SkipReader { @@ -38,16 +35,17 @@ final class ES812ScoreSkipReader extends ES812SkipReader { private final ByteArrayDataInput badi = new ByteArrayDataInput(); private final Impacts impacts; private int numLevels = 1; - private final MutableImpactList[] perLevelImpacts; + private final FreqAndNormBuffer[] perLevelImpacts; ES812ScoreSkipReader(IndexInput skipStream, int maxSkipLevels, boolean hasPos, boolean hasOffsets, boolean hasPayloads) { super(skipStream, maxSkipLevels, hasPos, 
hasOffsets, hasPayloads); this.impactData = new byte[maxSkipLevels][]; Arrays.fill(impactData, new byte[0]); this.impactDataLength = new int[maxSkipLevels]; - this.perLevelImpacts = new MutableImpactList[maxSkipLevels]; + this.perLevelImpacts = new FreqAndNormBuffer[maxSkipLevels]; for (int i = 0; i < perLevelImpacts.length; ++i) { - perLevelImpacts[i] = new MutableImpactList(); + perLevelImpacts[i] = new FreqAndNormBuffer(); + perLevelImpacts[i].add(Integer.MAX_VALUE, 1L); } impacts = new Impacts() { @@ -62,7 +60,7 @@ public int getDocIdUpTo(int level) { } @Override - public List getImpacts(int level) { + public FreqAndNormBuffer getImpacts(int level) { assert level < numLevels; if (impactDataLength[level] > 0) { badi.reset(impactData[level], 0, impactDataLength[level]); @@ -83,9 +81,9 @@ public int skipTo(int target) throws IOException { // End of postings don't have skip data anymore, so we fill with dummy data // like SlowImpactsEnum. numLevels = 1; - perLevelImpacts[0].length = 1; - perLevelImpacts[0].impacts[0].freq = Integer.MAX_VALUE; - perLevelImpacts[0].impacts[0].norm = 1L; + perLevelImpacts[0].size = 1; + perLevelImpacts[0].freqs[0] = Integer.MAX_VALUE; + perLevelImpacts[0].norms[0] = 1L; impactDataLength[0] = 0; } return result; @@ -105,19 +103,13 @@ protected void readImpacts(int level, IndexInput skipStream) throws IOException impactDataLength[level] = length; } - static MutableImpactList readImpacts(ByteArrayDataInput in, MutableImpactList reuse) { + static FreqAndNormBuffer readImpacts(ByteArrayDataInput in, FreqAndNormBuffer reuse) { int maxNumImpacts = in.length(); // at most one impact per byte - if (reuse.impacts.length < maxNumImpacts) { - int oldLength = reuse.impacts.length; - reuse.impacts = ArrayUtil.grow(reuse.impacts, maxNumImpacts); - for (int i = oldLength; i < reuse.impacts.length; ++i) { - reuse.impacts[i] = new Impact(Integer.MAX_VALUE, 1L); - } - } + reuse.growNoCopy(maxNumImpacts); int freq = 0; long norm = 0; - int length = 0; + 
int size = 0; while (in.getPosition() < in.length()) { int freqDelta = in.readVInt(); if ((freqDelta & 0x01) != 0) { @@ -131,27 +123,11 @@ static MutableImpactList readImpacts(ByteArrayDataInput in, MutableImpactList re freq += 1 + (freqDelta >>> 1); norm++; } - Impact impact = reuse.impacts[length]; - impact.freq = freq; - impact.norm = norm; - length++; + reuse.freqs[size] = freq; + reuse.norms[size] = norm; + size++; } - reuse.length = length; + reuse.size = size; return reuse; } - - static class MutableImpactList extends AbstractList implements RandomAccess { - int length = 1; - Impact[] impacts = new Impact[] { new Impact(Integer.MAX_VALUE, 1L) }; - - @Override - public Impact get(int index) { - return impacts[index]; - } - - @Override - public int size() { - return length; - } - } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/postings/ES812SkipWriter.java b/server/src/main/java/org/elasticsearch/index/codec/postings/ES812SkipWriter.java index 98c516fc890e8..5cf542fef36e4 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/postings/ES812SkipWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/postings/ES812SkipWriter.java @@ -21,8 +21,8 @@ package org.elasticsearch.index.codec.postings; import org.apache.lucene.codecs.CompetitiveImpactAccumulator; +import org.apache.lucene.codecs.Impact; import org.apache.lucene.codecs.MultiLevelSkipListWriter; -import org.apache.lucene.index.Impact; import org.apache.lucene.store.ByteBuffersDataOutput; import org.apache.lucene.store.DataOutput; import org.apache.lucene.store.IndexOutput; diff --git a/server/src/main/java/org/elasticsearch/index/codec/tsdb/ES93TSDBDefaultCompressionLucene103Codec.java b/server/src/main/java/org/elasticsearch/index/codec/tsdb/ES93TSDBDefaultCompressionLucene103Codec.java index 36683339d7137..77d559e93da49 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/tsdb/ES93TSDBDefaultCompressionLucene103Codec.java +++ 
b/server/src/main/java/org/elasticsearch/index/codec/tsdb/ES93TSDBDefaultCompressionLucene103Codec.java @@ -9,15 +9,15 @@ package org.elasticsearch.index.codec.tsdb; -import org.apache.lucene.codecs.lucene103.Lucene103Codec; +import org.apache.lucene.codecs.lucene104.Lucene104Codec; public class ES93TSDBDefaultCompressionLucene103Codec extends AbstractTSDBSyntheticIdCodec { /** Public no-arg constructor, needed for SPI loading at read-time. */ public ES93TSDBDefaultCompressionLucene103Codec() { - this(new Lucene103Codec()); + this(new Lucene104Codec()); } - public ES93TSDBDefaultCompressionLucene103Codec(Lucene103Codec delegate) { + public ES93TSDBDefaultCompressionLucene103Codec(Lucene104Codec delegate) { super("ES93TSDBDefaultCompressionLucene103Codec", delegate); } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/AbstractHnswVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/AbstractHnswVectorsFormat.java index 618ece60ae251..403f50199ab84 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/AbstractHnswVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/AbstractHnswVectorsFormat.java @@ -17,9 +17,6 @@ import java.util.concurrent.ExecutorService; -import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; -import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; -import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; @@ -42,23 +39,19 @@ public abstract class AbstractHnswVectorsFormat extends KnnVectorsFormat { protected final int numMergeWorkers; protected final 
TaskExecutor mergeExec; - /** Constructs a format using default graph construction parameters */ - protected AbstractHnswVectorsFormat(String name) { - this(name, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); - } + /** + * Always build HNSW graph regardless of segment size + */ + public static final int DEFAULT_HNSW_GRAPH_THRESHOLD = 0; /** - * Constructs a format using the given graph construction parameters. - * - * @param maxConn the maximum number of connections to a node in the HNSW graph - * @param beamWidth the size of the queue maintained during graph construction. + * The minimum expected search cost before building an HNSW graph becomes worthwhile. + * Below this threshold, brute-force search is efficient enough that graph construction overhead isn't worthwhile. */ - protected AbstractHnswVectorsFormat(String name, int maxConn, int beamWidth) { - this(name, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null); - } + protected final int hnswGraphThreshold; /** - * Constructs a format using the given graph construction parameters and scalar quantization. + * Constructs a format using the given graph construction parameters, merge settings, and HNSW graph threshold. * * @param maxConn the maximum number of connections to a node in the HNSW graph * @param beamWidth the size of the queue maintained during graph construction. 
@@ -66,8 +59,16 @@ protected AbstractHnswVectorsFormat(String name, int maxConn, int beamWidth) { * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are * generated by this format to do the merge + * @param hnswGraphThreshold the minimum expected search cost before building an HNSW graph; if negative, use format default */ - protected AbstractHnswVectorsFormat(String name, int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + protected AbstractHnswVectorsFormat( + String name, + int maxConn, + int beamWidth, + int numMergeWorkers, + ExecutorService mergeExec, + int hnswGraphThreshold + ) { super(name); if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { throw new IllegalArgumentException( @@ -90,6 +91,14 @@ protected AbstractHnswVectorsFormat(String name, int maxConn, int beamWidth, int } else { this.mergeExec = null; } + this.hnswGraphThreshold = hnswGraphThreshold; + } + + /** + * Resolves the HNSW graph threshold. If the given threshold is negative (not set), returns the provided default. + */ + protected static int resolveThreshold(int threshold, int defaultThreshold) { + return threshold >= 0 ? threshold : defaultThreshold; } protected abstract FlatVectorsFormat flatVectorsFormat(); @@ -108,6 +117,7 @@ public String toString() { + maxConn + ", beamWidth=" + beamWidth + + (hnswGraphThreshold > 0 ? 
", hnswGraphThreshold=" + hnswGraphThreshold : "") + ", flatVectorFormat=" + flatVectorsFormat() + ")"; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormat.java index 8f55b2dd2d603..63637e3dcc7bd 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormat.java @@ -46,7 +46,7 @@ public ES814HnswScalarQuantizedVectorsFormat( int numMergeWorkers, ExecutorService mergeExec ) { - super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); + super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec, DEFAULT_HNSW_GRAPH_THRESHOLD); this.flatVectorsFormat = new ES814ScalarQuantizedVectorsFormat(confidenceInterval, bits, compress); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java index 9abcb1b75ab6c..dd125b68d292a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java @@ -9,6 +9,7 @@ package org.elasticsearch.index.codec.vectors; +import org.apache.lucene.backward_codecs.lucene99.Lucene99ScalarQuantizedVectorsReader; import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; @@ -17,8 +18,6 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer; import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; -import 
org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsReader; -import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; @@ -40,7 +39,6 @@ import java.io.IOException; import java.util.Map; -import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.DYNAMIC_CONFIDENCE_INTERVAL; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; public class ES814ScalarQuantizedVectorsFormat extends FlatVectorsFormat { @@ -62,6 +60,9 @@ public class ES814ScalarQuantizedVectorsFormat extends FlatVectorsFormat { /** The maximum confidence interval */ private static final float MAXIMUM_CONFIDENCE_INTERVAL = 1f; + /** Dynamic confidence interval */ + public static final float DYNAMIC_CONFIDENCE_INTERVAL = 0f; + /** * Controls the confidence interval used to scalar quantize the vectors the default value is * calculated as `1-1/(vector_dimensions + 1)` diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormat.java index 3e3d963a70f31..40b67158542ee 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormat.java @@ -19,6 +19,10 @@ import java.io.IOException; import java.util.concurrent.ExecutorService; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; + public class ES815HnswBitVectorsFormat extends AbstractHnswVectorsFormat { static final String NAME = 
"ES815HnswBitVectorsFormat"; @@ -26,11 +30,11 @@ public class ES815HnswBitVectorsFormat extends AbstractHnswVectorsFormat { private static final FlatVectorsFormat flatVectorsFormat = new ES815BitFlatVectorsFormat(); public ES815HnswBitVectorsFormat() { - super(NAME); + super(NAME, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null, DEFAULT_HNSW_GRAPH_THRESHOLD); } public ES815HnswBitVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { - super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); + super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec, DEFAULT_HNSW_GRAPH_THRESHOLD); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/Lucene99ScalarQuantizedVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/Lucene99ScalarQuantizedVectorsWriter.java new file mode 100644 index 0000000000000..c230c81feaa93 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/Lucene99ScalarQuantizedVectorsWriter.java @@ -0,0 +1,1210 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.DocIDMerger; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntArrayList; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedVectorsReader; +import org.apache.lucene.util.quantization.ScalarQuantizer; +import org.elasticsearch.core.SuppressForbidden; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import 
java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import static org.apache.lucene.codecs.KnnVectorsWriter.MergedVectorValues.hasVectorValues; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; + +/** + * Copied from Lucene 10.3. + */ +@SuppressForbidden(reason = "Lucene classes") +public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter { + + private static final long SHALLOW_RAM_BYTES_USED = shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsWriter.class); + + static final String QUANTIZED_VECTOR_COMPONENT = "QVEC"; + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + + static final int VERSION_START = 0; + static final int VERSION_ADD_BITS = 1; + static final String META_CODEC_NAME = "Lucene99ScalarQuantizedVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "Lucene99ScalarQuantizedVectorsFormatData"; + static final String META_EXTENSION = "vemq"; + static final String VECTOR_DATA_EXTENSION = "veq"; + + private static final float MINIMUM_CONFIDENCE_INTERVAL = 0.9f; + + /** Dynamic confidence interval */ + public static final float DYNAMIC_CONFIDENCE_INTERVAL = 0f; + + static float calculateDefaultConfidenceInterval(int vectorDimension) { + return Math.max(MINIMUM_CONFIDENCE_INTERVAL, 1f - (1f / (vectorDimension + 1))); + } + + // Used for determining when merged quantiles shifted too far from individual segment quantiles. + // When merging quantiles from various segments, we need to ensure that the new quantiles + // are not exceptionally different from an individual segments quantiles. + // This would imply that the quantization buckets would shift too much + // for floating point values and justify recalculating the quantiles. This helps preserve + // accuracy of the calculated quantiles, even in adversarial cases such as vector clustering. 
+ // This number was determined via empirical testing + private static final float QUANTILE_RECOMPUTE_LIMIT = 32; + // Used for determining if a new quantization state requires a re-quantization + // for a given segment. + // This ensures that in expectation 4/5 of the vector would be unchanged by requantization. + // Furthermore, only those values where the value is within 1/5 of the centre of a quantization + // bin will be changed. In these cases the error introduced by snapping one way or another + // is small compared to the error introduced by quantization in the first place. Furthermore, + // empirical testing showed that the relative error by not requantizing is small (compared to + // the quantization error) and the condition is sensitive enough to detect all adversarial cases, + // such as merging clustered data. + private static final float REQUANTIZATION_LIMIT = 0.2f; + private final SegmentWriteState segmentWriteState; + + private final List fields = new ArrayList<>(); + private final IndexOutput meta, quantizedVectorData; + private final Float confidenceInterval; + private final FlatVectorsWriter rawVectorDelegate; + private final byte bits; + private final boolean compress; + private final int version; + private boolean finished; + + public Lucene99ScalarQuantizedVectorsWriter( + SegmentWriteState state, + Float confidenceInterval, + FlatVectorsWriter rawVectorDelegate, + FlatVectorsScorer scorer + ) throws IOException { + this(state, VERSION_START, confidenceInterval, (byte) 7, false, rawVectorDelegate, scorer); + if (confidenceInterval != null && confidenceInterval == 0) { + throw new IllegalArgumentException("confidenceInterval cannot be set to zero"); + } + } + + public Lucene99ScalarQuantizedVectorsWriter( + SegmentWriteState state, + Float confidenceInterval, + byte bits, + boolean compress, + FlatVectorsWriter rawVectorDelegate, + FlatVectorsScorer scorer + ) throws IOException { + this(state, VERSION_ADD_BITS, confidenceInterval, bits, 
compress, rawVectorDelegate, scorer); + } + + private Lucene99ScalarQuantizedVectorsWriter( + SegmentWriteState state, + int version, + Float confidenceInterval, + byte bits, + boolean compress, + FlatVectorsWriter rawVectorDelegate, + FlatVectorsScorer scorer + ) throws IOException { + super(scorer); + this.confidenceInterval = confidenceInterval; + this.bits = bits; + this.compress = compress; + this.version = version; + segmentWriteState = state; + String metaFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, META_EXTENSION); + + String quantizedVectorDataFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + VECTOR_DATA_EXTENSION + ); + this.rawVectorDelegate = rawVectorDelegate; + boolean success = false; + try { + meta = state.directory.createOutput(metaFileName, state.context); + quantizedVectorData = state.directory.createOutput(quantizedVectorDataFileName, state.context); + + CodecUtil.writeIndexHeader(meta, META_CODEC_NAME, version, state.segmentInfo.getId(), state.segmentSuffix); + CodecUtil.writeIndexHeader( + quantizedVectorData, + VECTOR_DATA_CODEC_NAME, + version, + state.segmentInfo.getId(), + state.segmentSuffix + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @Override + public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + if (bits <= 4 && fieldInfo.getVectorDimension() % 2 != 0) { + throw new IllegalArgumentException( + "bits=" + bits + " is not supported for odd vector dimensions; vector dimension=" + fieldInfo.getVectorDimension() + ); + } + @SuppressWarnings("unchecked") + FieldWriter quantizedWriter = new FieldWriter( + confidenceInterval, + bits, + compress, + fieldInfo, + segmentWriteState.infoStream, + 
(FlatFieldVectorsWriter) rawVectorDelegate + ); + fields.add(quantizedWriter); + return quantizedWriter; + } + return rawVectorDelegate; + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + // Since we know we will not be searching for additional indexing, we can just write the + // the vectors directly to the new segment. + // No need to use temporary file as we don't have to re-open for reading + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + ScalarQuantizer mergedQuantizationState = mergeAndRecalculateQuantiles(mergeState, fieldInfo, confidenceInterval, bits); + MergedQuantizedVectorValues byteVectorValues = MergedQuantizedVectorValues.mergeQuantizedByteVectorValues( + fieldInfo, + mergeState, + mergedQuantizationState + ); + long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); + DocsWithFieldSet docsWithField = writeQuantizedVectorData(quantizedVectorData, byteVectorValues, bits, compress); + long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset; + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + confidenceInterval, + bits, + compress, + mergedQuantizationState.getLowerQuantile(), + mergedQuantizationState.getUpperQuantile(), + docsWithField + ); + } + } + + @Override + public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + // Simply merge the underlying delegate, which just copies the raw vector data to a new + // segment file + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + ScalarQuantizer mergedQuantizationState = mergeAndRecalculateQuantiles(mergeState, fieldInfo, confidenceInterval, bits); + return mergeOneFieldToIndex(segmentWriteState, fieldInfo, 
mergeState, mergedQuantizationState); + } + // We only merge the delegate, since the field type isn't float32, quantization wasn't + // supported, so bypass it. + return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState); + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorDelegate.flush(maxDoc, sortMap); + for (FieldWriter field : fields) { + ScalarQuantizer quantizer = field.createQuantizer(); + if (sortMap == null) { + writeField(field, maxDoc, quantizer); + } else { + writeSortingField(field, maxDoc, sortMap, quantizer); + } + field.finish(); + } + } + + @Override + public void finish() throws IOException { + if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + rawVectorDelegate.finish(); + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + if (quantizedVectorData != null) { + CodecUtil.writeFooter(quantizedVectorData); + } + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_RAM_BYTES_USED; + for (FieldWriter field : fields) { + // the field tracks the delegate field usage + total += field.ramBytesUsed(); + } + return total; + } + + private void writeField(FieldWriter fieldData, int maxDoc, ScalarQuantizer scalarQuantizer) throws IOException { + // write vector values + long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); + writeQuantizedVectors(fieldData, scalarQuantizer); + long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset; + + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + vectorDataLength, + confidenceInterval, + bits, + compress, + scalarQuantizer.getLowerQuantile(), + scalarQuantizer.getUpperQuantile(), + fieldData.getDocsWithFieldSet() + ); + } + + private void writeMeta( + FieldInfo field, + int maxDoc, + long vectorDataOffset, + long vectorDataLength, + Float confidenceInterval, + byte bits, + 
boolean compress, + Float lowerQuantile, + Float upperQuantile, + DocsWithFieldSet docsWithField + ) throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + meta.writeVInt(field.getVectorDimension()); + int count = docsWithField.cardinality(); + meta.writeInt(count); + if (count > 0) { + assert Float.isFinite(lowerQuantile) && Float.isFinite(upperQuantile); + if (version >= VERSION_ADD_BITS) { + meta.writeInt(confidenceInterval == null ? -1 : Float.floatToIntBits(confidenceInterval)); + meta.writeByte(bits); + meta.writeByte(compress ? (byte) 1 : (byte) 0); + } else { + assert confidenceInterval == null || confidenceInterval != DYNAMIC_CONFIDENCE_INTERVAL; + meta.writeInt( + Float.floatToIntBits( + confidenceInterval == null ? calculateDefaultConfidenceInterval(field.getVectorDimension()) : confidenceInterval + ) + ); + } + meta.writeInt(Float.floatToIntBits(lowerQuantile)); + meta.writeInt(Float.floatToIntBits(upperQuantile)); + } + // write docIDs + OrdToDocDISIReaderConfiguration.writeStoredMeta( + DIRECT_MONOTONIC_BLOCK_SHIFT, + meta, + quantizedVectorData, + count, + maxDoc, + docsWithField + ); + } + + private void writeQuantizedVectors(FieldWriter fieldData, ScalarQuantizer scalarQuantizer) throws IOException { + byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()]; + byte[] compressedVector = fieldData.compress + ? OffHeapQuantizedByteVectorValues.compressedArray(fieldData.fieldInfo.getVectorDimension(), bits) + : null; + final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + float[] copy = fieldData.normalize ? 
new float[fieldData.fieldInfo.getVectorDimension()] : null; + assert fieldData.getVectors().isEmpty() || scalarQuantizer != null; + for (float[] v : fieldData.getVectors()) { + if (fieldData.normalize) { + System.arraycopy(v, 0, copy, 0, copy.length); + VectorUtil.l2normalize(copy); + v = copy; + } + + float offsetCorrection = scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction()); + if (compressedVector != null) { + OffHeapQuantizedByteVectorValues.compressBytes(vector, compressedVector); + quantizedVectorData.writeBytes(compressedVector, compressedVector.length); + } else { + quantizedVectorData.writeBytes(vector, vector.length); + } + offsetBuffer.putFloat(offsetCorrection); + quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length); + offsetBuffer.rewind(); + } + } + + private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap, ScalarQuantizer scalarQuantizer) + throws IOException { + final int[] ordMap = new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord + + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField); + + // write vector values + long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); + writeSortedQuantizedVectors(fieldData, ordMap, scalarQuantizer); + long quantizedVectorLength = quantizedVectorData.getFilePointer() - vectorDataOffset; + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + quantizedVectorLength, + confidenceInterval, + bits, + compress, + scalarQuantizer.getLowerQuantile(), + scalarQuantizer.getUpperQuantile(), + newDocsWithField + ); + } + + private void writeSortedQuantizedVectors(FieldWriter fieldData, int[] ordMap, ScalarQuantizer scalarQuantizer) throws IOException { + byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()]; + byte[] compressedVector = fieldData.compress + ? 
OffHeapQuantizedByteVectorValues.compressedArray(fieldData.fieldInfo.getVectorDimension(), bits) + : null; + final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null; + for (int ordinal : ordMap) { + float[] v = fieldData.getVectors().get(ordinal); + if (fieldData.normalize) { + System.arraycopy(v, 0, copy, 0, copy.length); + VectorUtil.l2normalize(copy); + v = copy; + } + float offsetCorrection = scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction()); + if (compressedVector != null) { + OffHeapQuantizedByteVectorValues.compressBytes(vector, compressedVector); + quantizedVectorData.writeBytes(compressedVector, compressedVector.length); + } else { + quantizedVectorData.writeBytes(vector, vector.length); + } + offsetBuffer.putFloat(offsetCorrection); + quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length); + offsetBuffer.rewind(); + } + } + + private ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneFieldToIndex( + SegmentWriteState segmentWriteState, + FieldInfo fieldInfo, + MergeState mergeState, + ScalarQuantizer mergedQuantizationState + ) throws IOException { + if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message( + QUANTIZED_VECTOR_COMPONENT, + "quantized field=" + + " confidenceInterval=" + + confidenceInterval + + " minQuantile=" + + mergedQuantizationState.getLowerQuantile() + + " maxQuantile=" + + mergedQuantizationState.getUpperQuantile() + ); + } + long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); + IndexOutput tempQuantizedVectorData = segmentWriteState.directory.createTempOutput( + quantizedVectorData.getName(), + "temp", + segmentWriteState.context + ); + IndexInput quantizationDataInput = null; + boolean success = false; + try { + MergedQuantizedVectorValues 
byteVectorValues = MergedQuantizedVectorValues.mergeQuantizedByteVectorValues( + fieldInfo, + mergeState, + mergedQuantizationState + ); + DocsWithFieldSet docsWithField = writeQuantizedVectorData(tempQuantizedVectorData, byteVectorValues, bits, compress); + CodecUtil.writeFooter(tempQuantizedVectorData); + IOUtils.close(tempQuantizedVectorData); + quantizationDataInput = segmentWriteState.directory.openInput(tempQuantizedVectorData.getName(), segmentWriteState.context); + quantizedVectorData.copyBytes(quantizationDataInput, quantizationDataInput.length() - CodecUtil.footerLength()); + long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset; + CodecUtil.retrieveChecksum(quantizationDataInput); + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + confidenceInterval, + bits, + compress, + mergedQuantizationState.getLowerQuantile(), + mergedQuantizationState.getUpperQuantile(), + docsWithField + ); + success = true; + final IndexInput finalQuantizationDataInput = quantizationDataInput; + return new ScalarQuantizedCloseableRandomVectorScorerSupplier(() -> { + IOUtils.close(finalQuantizationDataInput); + segmentWriteState.directory.deleteFile(tempQuantizedVectorData.getName()); + }, + docsWithField.cardinality(), + vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + mergedQuantizationState, + compress, + fieldInfo.getVectorSimilarityFunction(), + vectorsScorer, + quantizationDataInput + ) + ) + ); + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(tempQuantizedVectorData, quantizationDataInput); + IOUtils.deleteFilesIgnoringExceptions(segmentWriteState.directory, tempQuantizedVectorData.getName()); + } + } + } + + static ScalarQuantizer mergeQuantiles(List quantizationStates, IntArrayList 
segmentSizes, byte bits) { + assert quantizationStates.size() == segmentSizes.size(); + if (quantizationStates.isEmpty()) { + return null; + } + float lowerQuantile = 0f; + float upperQuantile = 0f; + int totalCount = 0; + for (int i = 0; i < quantizationStates.size(); i++) { + if (quantizationStates.get(i) == null) { + return null; + } + lowerQuantile += quantizationStates.get(i).getLowerQuantile() * segmentSizes.get(i); + upperQuantile += quantizationStates.get(i).getUpperQuantile() * segmentSizes.get(i); + totalCount += segmentSizes.get(i); + if (quantizationStates.get(i).getBits() != bits) { + return null; + } + } + lowerQuantile /= totalCount; + upperQuantile /= totalCount; + return new ScalarQuantizer(lowerQuantile, upperQuantile, bits); + } + + /** + * Returns true if the quantiles of the merged state are too far from the quantiles of the + * individual states. + * + * @param mergedQuantizationState The merged quantization state + * @param quantizationStates The quantization states of the individual segments + * @return true if the quantiles should be recomputed + */ + static boolean shouldRecomputeQuantiles(ScalarQuantizer mergedQuantizationState, List quantizationStates) { + // calculate the limit for the quantiles to be considered too far apart + // We utilize upper & lower here to determine if the new upper and merged upper would + // drastically + // change the quantization buckets for floats + // This is a fairly conservative check. 
+ float limit = (mergedQuantizationState.getUpperQuantile() - mergedQuantizationState.getLowerQuantile()) / QUANTILE_RECOMPUTE_LIMIT; + for (ScalarQuantizer quantizationState : quantizationStates) { + if (Math.abs(quantizationState.getUpperQuantile() - mergedQuantizationState.getUpperQuantile()) > limit) { + return true; + } + if (Math.abs(quantizationState.getLowerQuantile() - mergedQuantizationState.getLowerQuantile()) > limit) { + return true; + } + } + return false; + } + + private static QuantizedVectorsReader getQuantizedKnnVectorsReader(KnnVectorsReader vectorsReader, String fieldName) { + if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + vectorsReader = candidateReader.getFieldReader(fieldName); + } + if (vectorsReader instanceof QuantizedVectorsReader reader) { + return reader; + } + return null; + } + + private static ScalarQuantizer getQuantizedState(KnnVectorsReader vectorsReader, String fieldName) { + QuantizedVectorsReader reader = getQuantizedKnnVectorsReader(vectorsReader, fieldName); + if (reader != null) { + return reader.getQuantizationState(fieldName); + } + return null; + } + + /** + * Merges the quantiles of the segments and recalculates the quantiles if necessary. 
+ * + * @param mergeState The merge state + * @param fieldInfo The field info + * @param confidenceInterval The confidence interval + * @param bits The number of bits + * @return The merged quantiles + * @throws IOException If there is a low-level I/O error + */ + public static ScalarQuantizer mergeAndRecalculateQuantiles( + MergeState mergeState, + FieldInfo fieldInfo, + Float confidenceInterval, + byte bits + ) throws IOException { + assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32); + List quantizationStates = new ArrayList<>(mergeState.liveDocs.length); + IntArrayList segmentSizes = new IntArrayList(mergeState.liveDocs.length); + for (int i = 0; i < mergeState.liveDocs.length; i++) { + FloatVectorValues fvv; + if (hasVectorValues(mergeState.fieldInfos[i], fieldInfo.name) + && (fvv = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name)) != null + && fvv.size() > 0) { + ScalarQuantizer quantizationState = getQuantizedState(mergeState.knnVectorsReaders[i], fieldInfo.name); + // If we have quantization state, we can utilize that to make merging cheaper + quantizationStates.add(quantizationState); + segmentSizes.add(fvv.size()); + } + } + ScalarQuantizer mergedQuantiles = mergeQuantiles(quantizationStates, segmentSizes, bits); + // Segments no providing quantization state indicates that their quantiles were never + // calculated. + // To be safe, we should always recalculate given a sample set over all the float vectors in the + // merged + // segment view + if (mergedQuantiles == null + // For smaller `bits` values, we should always recalculate the quantiles + // TODO: this is very conservative, could we reuse information for even int4 quantization? 
+ || bits <= 4 + || shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) { + int numVectors = 0; + DocIdSetIterator iter = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState).iterator(); + // iterate vectorValues and increment numVectors + for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) { + numVectors++; + } + return buildScalarQuantizer( + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState), + numVectors, + fieldInfo.getVectorSimilarityFunction(), + confidenceInterval, + bits + ); + } + return mergedQuantiles; + } + + static ScalarQuantizer buildScalarQuantizer( + FloatVectorValues floatVectorValues, + int numVectors, + VectorSimilarityFunction vectorSimilarityFunction, + Float confidenceInterval, + byte bits + ) throws IOException { + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + vectorSimilarityFunction = VectorSimilarityFunction.DOT_PRODUCT; + } + if (confidenceInterval != null && confidenceInterval == DYNAMIC_CONFIDENCE_INTERVAL) { + return ScalarQuantizer.fromVectorsAutoInterval(floatVectorValues, vectorSimilarityFunction, numVectors, bits); + } + return ScalarQuantizer.fromVectors( + floatVectorValues, + confidenceInterval == null ? calculateDefaultConfidenceInterval(floatVectorValues.dimension()) : confidenceInterval, + numVectors, + bits + ); + } + + /** + * Returns true if the quantiles of the new quantization state are too far from the quantiles of + * the existing quantization state. This would imply that floating point values would slightly + * shift quantization buckets. 
+ * + * @param existingQuantiles The existing quantiles for a segment + * @param newQuantiles The new quantiles for a segment, could be merged, or fully re-calculated + * @return true if the floating point values should be requantized + */ + static boolean shouldRequantize(ScalarQuantizer existingQuantiles, ScalarQuantizer newQuantiles) { + float tol = REQUANTIZATION_LIMIT * (newQuantiles.getUpperQuantile() - newQuantiles.getLowerQuantile()) / 128f; + if (Math.abs(existingQuantiles.getUpperQuantile() - newQuantiles.getUpperQuantile()) > tol) { + return true; + } + return Math.abs(existingQuantiles.getLowerQuantile() - newQuantiles.getLowerQuantile()) > tol; + } + + /** + * Writes the vector values to the output and returns a set of documents that contains vectors. + */ + public static DocsWithFieldSet writeQuantizedVectorData( + IndexOutput output, + QuantizedByteVectorValues quantizedByteVectorValues, + byte bits, + boolean compress + ) throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + final byte[] compressedVector = compress + ? 
OffHeapQuantizedByteVectorValues.compressedArray(quantizedByteVectorValues.dimension(), bits) + : null; + KnnVectorValues.DocIndexIterator iter = quantizedByteVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { + // write vector + byte[] binaryValue = quantizedByteVectorValues.vectorValue(iter.index()); + assert binaryValue.length == quantizedByteVectorValues.dimension() + : "dim=" + quantizedByteVectorValues.dimension() + " len=" + binaryValue.length; + if (compressedVector != null) { + OffHeapQuantizedByteVectorValues.compressBytes(binaryValue, compressedVector); + output.writeBytes(compressedVector, compressedVector.length); + } else { + output.writeBytes(binaryValue, binaryValue.length); + } + output.writeInt(Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant(iter.index()))); + docsWithField.add(docV); + } + return docsWithField; + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, quantizedVectorData, rawVectorDelegate); + } + + static class FieldWriter extends FlatFieldVectorsWriter { + private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class); + private final FieldInfo fieldInfo; + private final Float confidenceInterval; + private final byte bits; + private final boolean compress; + private final InfoStream infoStream; + private final boolean normalize; + private boolean finished; + private final FlatFieldVectorsWriter flatFieldVectorsWriter; + + FieldWriter( + Float confidenceInterval, + byte bits, + boolean compress, + FieldInfo fieldInfo, + InfoStream infoStream, + FlatFieldVectorsWriter indexWriter + ) { + super(); + this.confidenceInterval = confidenceInterval; + this.bits = bits; + this.fieldInfo = fieldInfo; + this.normalize = fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE; + this.infoStream = infoStream; + this.compress = compress; + this.flatFieldVectorsWriter = 
Objects.requireNonNull(indexWriter); + } + + @Override + public boolean isFinished() { + return finished && flatFieldVectorsWriter.isFinished(); + } + + @Override + public void finish() throws IOException { + if (finished) { + return; + } + assert flatFieldVectorsWriter.isFinished(); + finished = true; + } + + ScalarQuantizer createQuantizer() throws IOException { + assert flatFieldVectorsWriter.isFinished(); + List floatVectors = flatFieldVectorsWriter.getVectors(); + if (floatVectors.size() == 0) { + return new ScalarQuantizer(0, 0, bits); + } + ScalarQuantizer quantizer = buildScalarQuantizer( + new FloatVectorWrapper(floatVectors), + floatVectors.size(), + fieldInfo.getVectorSimilarityFunction(), + confidenceInterval, + bits + ); + if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { + infoStream.message( + QUANTIZED_VECTOR_COMPONENT, + "quantized field=" + + " confidenceInterval=" + + confidenceInterval + + " bits=" + + bits + + " minQuantile=" + + quantizer.getLowerQuantile() + + " maxQuantile=" + + quantizer.getUpperQuantile() + ); + } + return quantizer; + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += flatFieldVectorsWriter.ramBytesUsed(); + return size; + } + + @Override + public void addValue(int docID, float[] vectorValue) throws IOException { + flatFieldVectorsWriter.addValue(docID, vectorValue); + } + + @Override + public float[] copyValue(float[] vectorValue) { + throw new UnsupportedOperationException(); + } + + @Override + public List getVectors() { + return flatFieldVectorsWriter.getVectors(); + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return flatFieldVectorsWriter.getDocsWithFieldSet(); + } + } + + static class FloatVectorWrapper extends FloatVectorValues { + private final List vectorList; + + FloatVectorWrapper(List vectorList) { + this.vectorList = vectorList; + } + + @Override + public int dimension() { + return vectorList.get(0).length; + } + + @Override + public int size() { 
+ return vectorList.size(); + } + + @Override + public FloatVectorValues copy() throws IOException { + return this; + } + + @Override + public float[] vectorValue(int ord) throws IOException { + if (ord < 0 || ord >= vectorList.size()) { + throw new IOException("vector ord " + ord + " out of bounds"); + } + return vectorList.get(ord); + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + } + + static class QuantizedByteVectorValueSub extends DocIDMerger.Sub { + private final QuantizedByteVectorValues values; + private final KnnVectorValues.DocIndexIterator iterator; + + QuantizedByteVectorValueSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) { + super(docMap); + this.values = values; + iterator = values.iterator(); + assert iterator.docID() == -1; + } + + @Override + public int nextDoc() throws IOException { + return iterator.nextDoc(); + } + + public int index() { + return iterator.index(); + } + } + + /** Returns a merged view over all the segment's {@link QuantizedByteVectorValues}. */ + static class MergedQuantizedVectorValues extends QuantizedByteVectorValues { + public static MergedQuantizedVectorValues mergeQuantizedByteVectorValues( + FieldInfo fieldInfo, + MergeState mergeState, + ScalarQuantizer scalarQuantizer + ) throws IOException { + assert fieldInfo != null && fieldInfo.hasVectorValues(); + + List subs = new ArrayList<>(); + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + if (hasVectorValues(mergeState.fieldInfos[i], fieldInfo.name)) { + QuantizedVectorsReader reader = getQuantizedKnnVectorsReader(mergeState.knnVectorsReaders[i], fieldInfo.name); + assert scalarQuantizer != null; + final QuantizedByteVectorValueSub sub; + // Either our quantization parameters are way different than the merged ones + // Or we have never been quantized. 
+ if (reader == null || reader.getQuantizationState(fieldInfo.name) == null + // For smaller `bits` values, we should always recalculate the quantiles + // TODO: this is very conservative, could we reuse information for even int4 + // quantization? + || scalarQuantizer.getBits() <= 4 + || shouldRequantize(reader.getQuantizationState(fieldInfo.name), scalarQuantizer)) { + FloatVectorValues toQuantize = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name); + if (fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE) { + toQuantize = new NormalizedFloatVectorValues(toQuantize); + } + sub = new QuantizedByteVectorValueSub( + mergeState.docMaps[i], + new QuantizedFloatVectorValues(toQuantize, fieldInfo.getVectorSimilarityFunction(), scalarQuantizer) + ); + } else { + sub = new QuantizedByteVectorValueSub( + mergeState.docMaps[i], + new OffsetCorrectedQuantizedByteVectorValues( + reader.getQuantizedVectorValues(fieldInfo.name), + fieldInfo.getVectorSimilarityFunction(), + scalarQuantizer, + reader.getQuantizationState(fieldInfo.name) + ) + ); + } + subs.add(sub); + } + } + return new MergedQuantizedVectorValues(subs, mergeState); + } + + private final List subs; + private final DocIDMerger docIdMerger; + private final int size; + + private QuantizedByteVectorValueSub current; + + private MergedQuantizedVectorValues(List subs, MergeState mergeState) throws IOException { + this.subs = subs; + docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort); + int totalSize = 0; + for (QuantizedByteVectorValueSub sub : subs) { + totalSize += sub.values.size(); + } + size = totalSize; + } + + @Override + public byte[] vectorValue(int ord) throws IOException { + return current.values.vectorValue(current.index()); + } + + @Override + public DocIndexIterator iterator() { + return new CompositeIterator(); + } + + @Override + public int size() { + return size; + } + + @Override + public int dimension() { + return subs.get(0).values.dimension(); + 
} + + @Override + public float getScoreCorrectionConstant(int ord) throws IOException { + return current.values.getScoreCorrectionConstant(current.index()); + } + + private class CompositeIterator extends DocIndexIterator { + private int docId; + private int ord; + + CompositeIterator() { + docId = -1; + ord = -1; + } + + @Override + public int index() { + return ord; + } + + @Override + public int docID() { + return docId; + } + + @Override + public int nextDoc() throws IOException { + current = docIdMerger.next(); + if (current == null) { + docId = NO_MORE_DOCS; + ord = NO_MORE_DOCS; + } else { + docId = current.mappedDocID; + ++ord; + } + return docId; + } + + @Override + public int advance(int target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public long cost() { + return size; + } + } + } + + static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { + private final FloatVectorValues values; + private final ScalarQuantizer quantizer; + private final byte[] quantizedVector; + private int lastOrd = -1; + private float offsetValue = 0f; + + private final VectorSimilarityFunction vectorSimilarityFunction; + + QuantizedFloatVectorValues(FloatVectorValues values, VectorSimilarityFunction vectorSimilarityFunction, ScalarQuantizer quantizer) { + this.values = values; + this.quantizer = quantizer; + this.quantizedVector = new byte[values.dimension()]; + this.vectorSimilarityFunction = vectorSimilarityFunction; + } + + @Override + public float getScoreCorrectionConstant(int ord) { + if (ord != lastOrd) { + throw new IllegalStateException( + "attempt to retrieve score correction for different ord " + ord + " than the quantization was done for: " + lastOrd + ); + } + return offsetValue; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public int size() { + return values.size(); + } + + @Override + public byte[] vectorValue(int ord) throws IOException { + if (ord != 
lastOrd) { + offsetValue = quantize(ord); + lastOrd = ord; + } + return quantizedVector; + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); + } + + private float quantize(int ord) throws IOException { + return quantizer.quantize(values.vectorValue(ord), quantizedVector, vectorSimilarityFunction); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + } + + static final class ScalarQuantizedCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier { + + private final RandomVectorScorerSupplier supplier; + private final Closeable onClose; + private final int numVectors; + + ScalarQuantizedCloseableRandomVectorScorerSupplier(Closeable onClose, int numVectors, RandomVectorScorerSupplier supplier) { + this.onClose = onClose; + this.supplier = supplier; + this.numVectors = numVectors; + } + + @Override + public UpdateableRandomVectorScorer scorer() throws IOException { + return supplier.scorer(); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return supplier.copy(); + } + + @Override + public void close() throws IOException { + onClose.close(); + } + + @Override + public int totalVectorCount() { + return numVectors; + } + } + + static final class OffsetCorrectedQuantizedByteVectorValues extends QuantizedByteVectorValues { + + private final QuantizedByteVectorValues in; + private final VectorSimilarityFunction vectorSimilarityFunction; + private final ScalarQuantizer scalarQuantizer, oldScalarQuantizer; + + OffsetCorrectedQuantizedByteVectorValues( + QuantizedByteVectorValues in, + VectorSimilarityFunction vectorSimilarityFunction, + ScalarQuantizer scalarQuantizer, + ScalarQuantizer oldScalarQuantizer + ) { + this.in = in; + this.vectorSimilarityFunction = vectorSimilarityFunction; + this.scalarQuantizer = 
scalarQuantizer; + this.oldScalarQuantizer = oldScalarQuantizer; + } + + @Override + public float getScoreCorrectionConstant(int ord) throws IOException { + return scalarQuantizer.recalculateCorrectiveOffset(in.vectorValue(ord), oldScalarQuantizer, vectorSimilarityFunction); + } + + @Override + public int dimension() { + return in.dimension(); + } + + @Override + public int size() { + return in.size(); + } + + @Override + public byte[] vectorValue(int ord) throws IOException { + return in.vectorValue(ord); + } + + @Override + public int ordToDoc(int ord) { + return in.ordToDoc(ord); + } + + @Override + public DocIndexIterator iterator() { + return in.iterator(); + } + } + + static final class NormalizedFloatVectorValues extends FloatVectorValues { + private final FloatVectorValues values; + private final float[] normalizedVector; + + NormalizedFloatVectorValues(FloatVectorValues values) { + this.values = values; + this.normalizedVector = new float[values.dimension()]; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public int size() { + return values.size(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + + @Override + public float[] vectorValue(int ord) throws IOException { + System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length); + VectorUtil.l2normalize(normalizedVector); + return normalizedVector; + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + + @Override + public NormalizedFloatVectorValues copy() throws IOException { + return new NormalizedFloatVectorValues(values.copy()); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/OffHeapQuantizedByteVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/OffHeapQuantizedByteVectorValues.java new file mode 100644 index 0000000000000..5f303fa60cee9 --- /dev/null +++ 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/OffHeapQuantizedByteVectorValues.java @@ -0,0 +1,403 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.packed.DirectMonotonicReader; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; +import org.apache.lucene.util.quantization.ScalarQuantizer; + +import java.io.IOException; +import java.nio.ByteBuffer; + +/** + * Copied from Lucene 10.3. + * Read the quantized vector values and their score correction values from the index input. This + * supports both iterated and random access. 
+ */ +public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues { + + final int dimension; + final int size; + final int numBytes; + final ScalarQuantizer scalarQuantizer; + final VectorSimilarityFunction similarityFunction; + final FlatVectorsScorer vectorsScorer; + final boolean compress; + + final IndexInput slice; + final byte[] binaryValue; + final ByteBuffer byteBuffer; + final int byteSize; + int lastOrd = -1; + final float[] scoreCorrectionConstant = new float[1]; + + static void decompressBytes(byte[] compressed, int numBytes) { + if (numBytes == compressed.length) { + return; + } + if (numBytes << 1 != compressed.length) { + throw new IllegalArgumentException("numBytes: " + numBytes + " does not match compressed length: " + compressed.length); + } + for (int i = 0; i < numBytes; ++i) { + compressed[numBytes + i] = (byte) (compressed[i] & 0x0F); + compressed[i] = (byte) ((compressed[i] & 0xFF) >> 4); + } + } + + static byte[] compressedArray(int dimension, byte bits) { + if (bits <= 4) { + return new byte[(dimension + 1) >> 1]; + } else { + return null; + } + } + + static void compressBytes(byte[] raw, byte[] compressed) { + if (compressed.length != ((raw.length + 1) >> 1)) { + throw new IllegalArgumentException("compressed length: " + compressed.length + " does not match raw length: " + raw.length); + } + for (int i = 0; i < compressed.length; ++i) { + int v = (raw[i] << 4) | raw[compressed.length + i]; + compressed[i] = (byte) v; + } + } + + OffHeapQuantizedByteVectorValues( + int dimension, + int size, + ScalarQuantizer scalarQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + boolean compress, + IndexInput slice + ) { + this.dimension = dimension; + this.size = size; + this.slice = slice; + this.scalarQuantizer = scalarQuantizer; + this.compress = compress; + if (scalarQuantizer.getBits() <= 4 && compress) { + this.numBytes = (dimension + 1) >> 1; + } else { + this.numBytes = 
dimension; + } + this.byteSize = this.numBytes + Float.BYTES; + byteBuffer = ByteBuffer.allocate(dimension); + binaryValue = byteBuffer.array(); + this.similarityFunction = similarityFunction; + this.vectorsScorer = vectorsScorer; + } + + @Override + public ScalarQuantizer getScalarQuantizer() { + return scalarQuantizer; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; + } + + @Override + public byte[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return binaryValue; + } + slice.seek((long) targetOrd * byteSize); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), numBytes); + slice.readFloats(scoreCorrectionConstant, 0, 1); + decompressBytes(binaryValue, numBytes); + lastOrd = targetOrd; + return binaryValue; + } + + @Override + public float getScoreCorrectionConstant(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return scoreCorrectionConstant[0]; + } + slice.seek(((long) targetOrd * byteSize) + numBytes); + slice.readFloats(scoreCorrectionConstant, 0, 1); + return scoreCorrectionConstant[0]; + } + + @Override + public IndexInput getSlice() { + return slice; + } + + @Override + public int getVectorByteLength() { + return numBytes; + } + + static OffHeapQuantizedByteVectorValues load( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + ScalarQuantizer scalarQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + boolean compress, + long quantizedVectorDataOffset, + long quantizedVectorDataLength, + IndexInput vectorData + ) throws IOException { + if (configuration.isEmpty()) { + return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer); + } + IndexInput bytesSlice = vectorData.slice("quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength); + if (configuration.isDense()) { + return new 
DenseOffHeapVectorValues(dimension, size, scalarQuantizer, compress, similarityFunction, vectorsScorer, bytesSlice); + } else { + return new SparseOffHeapVectorValues( + configuration, + dimension, + size, + scalarQuantizer, + compress, + vectorData, + similarityFunction, + vectorsScorer, + bytesSlice + ); + } + } + + /** + * Dense vector values that are stored off-heap. This is the most common case when every doc has a + * vector. + */ + public static class DenseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues { + + /** + * Create dense off-heap vector values + * + * @param dimension vector dimension + * @param size number of vectors + * @param scalarQuantizer the scalar quantizer + * @param compress whether the vectors are compressed + * @param similarityFunction the similarity function + * @param vectorsScorer the vectors scorer + * @param slice the index input slice containing the vector data + */ + public DenseOffHeapVectorValues( + int dimension, + int size, + ScalarQuantizer scalarQuantizer, + boolean compress, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice + ) { + super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice); + } + + @Override + public DenseOffHeapVectorValues copy() throws IOException { + return new DenseOffHeapVectorValues( + dimension, + size, + scalarQuantizer, + compress, + similarityFunction, + vectorsScorer, + slice.clone() + ); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return acceptDocs; + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer vectorScorer = vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return vectorScorer.score(iterator.index()); + } + + @Override + 
public DocIdSetIterator iterator() { + return iterator; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + } + + private static class SparseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues { + private final DirectMonotonicReader ordToDoc; + private final IndexedDISI disi; + // dataIn was used to init a new IndexedDIS for #randomAccess() + private final IndexInput dataIn; + private final OrdToDocDISIReaderConfiguration configuration; + + SparseOffHeapVectorValues( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + ScalarQuantizer scalarQuantizer, + boolean compress, + IndexInput dataIn, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice + ) throws IOException { + super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice); + this.configuration = configuration; + this.dataIn = dataIn; + this.ordToDoc = configuration.getDirectMonotonicReader(dataIn); + this.disi = configuration.getIndexedDISI(dataIn); + } + + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + + @Override + public SparseOffHeapVectorValues copy() throws IOException { + return new SparseOffHeapVectorValues( + configuration, + dimension, + size, + scalarQuantizer, + compress, + dataIn, + similarityFunction, + vectorsScorer, + slice.clone() + ); + } + + @Override + public int ordToDoc(int ord) { + return (int) ordToDoc.get(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + if (acceptDocs == null) { + return null; + } + return new Bits() { + @Override + public boolean get(int index) { + return acceptDocs.get(ordToDoc(index)); + } + + @Override + public int length() { + return size; + } + }; + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); 
+ RandomVectorScorer vectorScorer = vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return vectorScorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + } + + private static class EmptyOffHeapVectorValues extends OffHeapQuantizedByteVectorValues { + + EmptyOffHeapVectorValues(int dimension, VectorSimilarityFunction similarityFunction, FlatVectorsScorer vectorsScorer) { + super(dimension, 0, new ScalarQuantizer(-1, 1, (byte) 7), similarityFunction, vectorsScorer, false, null); + } + + @Override + public int dimension() { + return super.dimension(); + } + + @Override + public int size() { + return 0; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public EmptyOffHeapVectorValues copy() { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] vectorValue(int targetOrd) { + throw new UnsupportedOperationException(); + } + + @Override + public int ordToDoc(int ord) { + throw new UnsupportedOperationException(); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return null; + } + + @Override + public VectorScorer scorer(float[] target) { + return null; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java index eac3e708dfe66..cc373305fc9cc 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java @@ -188,6 +188,10 @@ private boolean optimizeIntervals(float[] initInterval, int[] destination, float if ((Math.abs(initInterval[0] - aOpt) < 1e-8 && Math.abs(initInterval[1] - bOpt) < 1e-8)) { return true; } + if (bOpt 
< aOpt) { + // This can happen if the optimal interval is very small and we have numerical instability, in which case we can stop + return true; + } double newLoss = ESVectorUtil.calculateOSQLoss(vector, aOpt, bOpt, points, norm2, lambda, destination); // If the new loss is worse, don't update the interval and exit // This optimization, unlike kMeans, does not always converge to better loss diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/NeighborHood.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/NeighborHood.java index 28db0361c02df..3f51b43d49977 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/NeighborHood.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/NeighborHood.java @@ -202,9 +202,10 @@ public float score(int node) { } @Override - public void bulkScore(int[] nodes, float[] scores, int numNodes) { + public float bulkScore(int[] nodes, float[] scores, int numNodes) { int i = 0; final int limit = numNodes - 3; + float max = Float.NEGATIVE_INFINITY; for (; i < limit; i += 4) { ESVectorUtil.squareDistanceBulk( centers[scoringOrdinal], @@ -216,11 +217,14 @@ public void bulkScore(int[] nodes, float[] scores, int numNodes) { ); for (int j = 0; j < 4; j++) { scores[i + j] = VectorUtil.normalizeDistanceToUnitInterval(distances[j]); + max = Math.max(max, scores[i + j]); } } for (; i < numNodes; i++) { scores[i] = score(nodes[i]); + max = Math.max(max, scores[i]); } + return max; } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java index 9939872415af7..f94b94b78434c 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java @@ 
-28,7 +28,6 @@ import java.io.IOException; -import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS; import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize; import static org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer.DEFAULT_LAMBDA; @@ -41,6 +40,9 @@ */ public class ES920DiskBBQVectorsReader extends IVFVectorsReader { + // QUERY_BITS value copied from Lucene102BinaryQuantizedVectorsFormat where it became package private + private static final byte QUERY_BITS = 4; + ES920DiskBBQVectorsReader(SegmentReadState state, GenericFlatVectorReaders.LoadFlatVectorsReader getFormatReader) throws IOException { super(state, getFormatReader); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java index 3dfce8cf43c29..3f768525ca1d3 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java @@ -268,12 +268,14 @@ public final void search(String field, float[] target, KnnCollector knnCollector "vector query dimension: " + target.length + " differs from field dimension: " + fieldInfo.getVectorDimension() ); } + final ESAcceptDocs esAcceptDocs; if (acceptDocs instanceof ESAcceptDocs) { esAcceptDocs = (ESAcceptDocs) acceptDocs; } else { esAcceptDocs = null; } + FloatVectorValues values = getReaderForField(field).getFloatVectorValues(field); int numVectors = values.size(); // TODO returning cost 0 in ESAcceptDocs.ESAcceptDocsAll feels wrong? cost is related to the number of matching documents? 
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsReader.java index 10ee20a2acde5..0258bb3176308 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsReader.java @@ -377,6 +377,11 @@ public VectorScorer scorer(float[] query) throws IOException { return quantizedVectorValues.scorer(query); } + @Override + public VectorScorer rescorer(float[] floats) throws IOException { + return in.rescorer(floats); + } + protected BinarizedByteVectorValues getQuantizedVectorValues() throws IOException { return quantizedVectorValues; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java index 2f5b425d60119..bc5fc14443a06 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java @@ -30,6 +30,10 @@ import java.io.IOException; import java.util.concurrent.ExecutorService; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; + /** * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 */ @@ -42,7 +46,7 @@ public class ES816HnswBinaryQuantizedVectorsFormat extends AbstractHnswVectorsFo /** Constructs a format using default graph construction parameters */ public 
ES816HnswBinaryQuantizedVectorsFormat() { - super(NAME); + super(NAME, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null, DEFAULT_HNSW_GRAPH_THRESHOLD); } /** @@ -52,7 +56,7 @@ public ES816HnswBinaryQuantizedVectorsFormat() { * @param beamWidth the size of the queue maintained during graph construction. */ public ES816HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) { - super(NAME, maxConn, beamWidth); + super(NAME, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, DEFAULT_HNSW_GRAPH_THRESHOLD); } /** @@ -66,7 +70,7 @@ public ES816HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) { * generated by this format to do the merge */ public ES816HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { - super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); + super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec, DEFAULT_HNSW_GRAPH_THRESHOLD); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java index bb11eba56d501..69c5d703e664b 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java @@ -101,8 +101,8 @@ public float score(int i) throws IOException { } @Override - public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { - scorer.scoreBulk( + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + return scorer.scoreBulk( quantized, queryCorrections.lowerInterval(), queryCorrections.upperInterval(), @@ -218,11 +218,11 @@ public float score(int targetOrd) throws IOException { } @Override - public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + 
public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { if (queryCorrections == null) { throw new IllegalStateException("bulkScore() called before setScoringOrdinal()"); } - scorer.scoreBulk( + return scorer.scoreBulk( quantizedQuery, queryCorrections.lowerInterval(), queryCorrections.upperInterval(), diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java index cba5d62ce3d39..d10821c62ce3f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java @@ -391,6 +391,11 @@ public VectorScorer scorer(float[] query) throws IOException { return quantizedVectorValues.scorer(query); } + @Override + public VectorScorer rescorer(float[] floats) throws IOException { + return in.rescorer(floats); + } + BinarizedByteVectorValues getQuantizedVectorValues() throws IOException { return quantizedVectorValues; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java index 483c717f06846..118bde2d7056a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java @@ -30,6 +30,10 @@ import java.io.IOException; import java.util.concurrent.ExecutorService; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static 
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; + /** * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 */ @@ -42,7 +46,7 @@ public class ES818HnswBinaryQuantizedVectorsFormat extends AbstractHnswVectorsFo /** Constructs a format using default graph construction parameters */ public ES818HnswBinaryQuantizedVectorsFormat() { - super(NAME); + super(NAME, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null, DEFAULT_HNSW_GRAPH_THRESHOLD); } /** @@ -52,7 +56,7 @@ public ES818HnswBinaryQuantizedVectorsFormat() { * @param beamWidth the size of the queue maintained during graph construction. */ public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) { - super(NAME, maxConn, beamWidth); + super(NAME, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, DEFAULT_HNSW_GRAPH_THRESHOLD); } /** @@ -66,7 +70,7 @@ public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) { * generated by this format to do the merge */ public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { - super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); + super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec, DEFAULT_HNSW_GRAPH_THRESHOLD); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java index 06d303b8dc97e..fcfc5a4c4682e 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java @@ -11,15 +11,32 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import 
org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene95.HasIndexSlice; import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsReader; import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; +import org.apache.lucene.search.ConjunctionUtils; +import org.apache.lucene.search.DocAndFloatFeatureBuffer; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.MergeReaderWrapper; import java.io.IOException; +import java.util.List; +import java.util.Map; public class DirectIOCapableLucene99FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat { @@ -61,11 +78,274 @@ public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectI ); // Use mmap for merges and direct I/O for searches. 
return new MergeReaderWrapper( - new Lucene99FlatVectorsReader(directIOState, vectorsScorer), - new Lucene99FlatVectorsReader(state, vectorsScorer) + new Lucene99FlatBulkScoringVectorsReader( + directIOState, + new Lucene99FlatVectorsReader(directIOState, vectorsScorer), + vectorsScorer, + true + ), + new Lucene99FlatBulkScoringVectorsReader(state, new Lucene99FlatVectorsReader(state, vectorsScorer), vectorsScorer, false) ); } else { - return new Lucene99FlatVectorsReader(state, vectorsScorer); + return new Lucene99FlatBulkScoringVectorsReader( + state, + new Lucene99FlatVectorsReader(state, vectorsScorer), + vectorsScorer, + false + ); + } + } + + static class Lucene99FlatBulkScoringVectorsReader extends FlatVectorsReader { + private final Lucene99FlatVectorsReader inner; + private final SegmentReadState state; + private final boolean forcePreFetching; + + Lucene99FlatBulkScoringVectorsReader( + SegmentReadState state, + Lucene99FlatVectorsReader inner, + FlatVectorsScorer scorer, + boolean forcePreFetching + ) { + super(scorer); + this.inner = inner; + this.state = state; + this.forcePreFetching = forcePreFetching; + } + + @Override + public void close() throws IOException { + inner.close(); + } + + @Override + public Map getOffHeapByteSize(FieldInfo fieldInfo) { + return inner.getOffHeapByteSize(fieldInfo); + } + + @Override + public void finishMerge() throws IOException { + inner.finishMerge(); + } + + @Override + public FlatVectorsReader getMergeInstance() throws IOException { + return inner.getMergeInstance(); + } + + @Override + public FlatVectorsScorer getFlatVectorScorer() { + return inner.getFlatVectorScorer(); + } + + @Override + public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { + inner.search(field, target, knnCollector, acceptDocs); + } + + @Override + public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { + 
inner.search(field, target, knnCollector, acceptDocs); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + return inner.getRandomVectorScorer(field, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + return inner.getRandomVectorScorer(field, target); + } + + @Override + public void checkIntegrity() throws IOException { + inner.checkIntegrity(); + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + FloatVectorValues vectorValues = inner.getFloatVectorValues(field); + if (vectorValues == null) { + return null; + } + if (vectorValues.size() == 0) { + return vectorValues; + } + FieldInfo info = state.fieldInfos.fieldInfo(field); + return new RescorerOffHeapVectorValues(vectorValues, info.getVectorSimilarityFunction(), vectorScorer, forcePreFetching); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + return inner.getByteVectorValues(field); + } + + @Override + public long ramBytesUsed() { + return inner.ramBytesUsed(); + } + } + + static class RescorerOffHeapVectorValues extends FloatVectorValues { + private final VectorSimilarityFunction similarityFunction; + private final FloatVectorValues inner; + private final IndexInput inputSlice; + private final FlatVectorsScorer scorer; + private final boolean forcePreFetching; + + RescorerOffHeapVectorValues( + FloatVectorValues inner, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer scorer, + boolean forcePreFetching + ) { + this.inner = inner; + if (inner instanceof HasIndexSlice slice) { + this.inputSlice = slice.getSlice(); + } else { + this.inputSlice = null; + } + this.similarityFunction = similarityFunction; + this.scorer = scorer; + this.forcePreFetching = forcePreFetching; + } + + @Override + public float[] vectorValue(int ord) throws IOException { + return 
inner.vectorValue(ord); + } + + @Override + public int dimension() { + return inner.dimension(); + } + + @Override + public int size() { + return inner.size(); + } + + @Override + public DocIndexIterator iterator() { + return inner.iterator(); + } + + @Override + public int ordToDoc(int ord) { + return inner.ordToDoc(ord); + } + + @Override + public RescorerOffHeapVectorValues copy() throws IOException { + return new RescorerOffHeapVectorValues(inner.copy(), similarityFunction, scorer, forcePreFetching); + } + + @Override + public VectorScorer rescorer(float[] target) throws IOException { + if (forcePreFetching && inputSlice != null) { + DocIndexIterator indexIterator = inner.iterator(); + RandomVectorScorer randomScorer = scorer.getRandomVectorScorer(similarityFunction, inner, target); + return new PreFetchingFloatBulkVectorScorer(randomScorer, indexIterator, inputSlice, dimension() * Float.BYTES); + } + return inner.rescorer(target); + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + return inner.scorer(target); + } + } + + private record PreFetchingFloatBulkVectorScorer( + RandomVectorScorer inner, + KnnVectorValues.DocIndexIterator indexIterator, + IndexInput inputSlice, + int byteSize + ) implements VectorScorer { + + @Override + public float score() throws IOException { + return inner.score(indexIterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return indexIterator; + } + + @Override + public VectorScorer.Bulk bulk(DocIdSetIterator matchingDocs) throws IOException { + DocIdSetIterator conjunctionScorer = matchingDocs == null + ? 
indexIterator + : ConjunctionUtils.intersectIterators(List.of(matchingDocs, indexIterator)); + if (conjunctionScorer.docID() == -1) { + conjunctionScorer.nextDoc(); + } + return new PrefetchingFloatBulkScorer(inner, inputSlice, byteSize, indexIterator, conjunctionScorer); + } + } + + private static class PrefetchingFloatBulkScorer implements VectorScorer.Bulk { + private static final int BULK_SIZE = 64; + private static final int SCORE_BULK_SIZE = 32; + private final KnnVectorValues.DocIndexIterator indexIterator; + private final DocIdSetIterator matchingDocs; + private final RandomVectorScorer inner; + private final IndexInput inputSlice; + private final int byteSize; + private final int[] docBuffer; + private final float[] scoreBuffer; + + PrefetchingFloatBulkScorer( + RandomVectorScorer fvv, + IndexInput inputSlice, + int byteSize, + KnnVectorValues.DocIndexIterator iterator, + DocIdSetIterator matchingDocs + ) { + this.indexIterator = iterator; + this.matchingDocs = matchingDocs; + this.inner = fvv; + this.inputSlice = inputSlice; + this.docBuffer = new int[BULK_SIZE]; + this.scoreBuffer = new float[BULK_SIZE]; + this.byteSize = byteSize; + } + + @Override + public float nextDocsAndScores(int upTo, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException { + buffer.growNoCopy(BULK_SIZE); + int size = 0; + for (int doc = matchingDocs.docID(); doc < upTo && size < BULK_SIZE; doc = matchingDocs.nextDoc()) { + if (liveDocs == null || liveDocs.get(doc)) { + buffer.docs[size++] = indexIterator.index(); + } + } + for (int j = 0; j < size; j++) { + final long ord = buffer.docs[j]; + inputSlice.prefetch(ord * byteSize, byteSize); + } + final int loopBound = size - (size % SCORE_BULK_SIZE); + int i = 0; + float maxScore = Float.NEGATIVE_INFINITY; + for (; i < loopBound; i += SCORE_BULK_SIZE) { + System.arraycopy(buffer.docs, i, docBuffer, 0, SCORE_BULK_SIZE); + maxScore = Math.max(inner.bulkScore(docBuffer, scoreBuffer, SCORE_BULK_SIZE), maxScore); + 
System.arraycopy(scoreBuffer, 0, buffer.features, i, SCORE_BULK_SIZE); + } + final int countLeft = size - i; + if (countLeft > 0) { + System.arraycopy(buffer.docs, i, docBuffer, 0, countLeft); + maxScore = Math.max(inner.bulkScore(docBuffer, scoreBuffer, countLeft), maxScore); + System.arraycopy(scoreBuffer, 0, buffer.features, i, countLeft); + } + buffer.size = size; + // fix the docIds in buffer + for (int j = 0; j < size; j++) { + buffer.docs[j] = inner.ordToDoc(buffer.docs[j]); + } + return maxScore; } } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java index 15dce38cb742c..48e30ebe7d3b5 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java @@ -32,16 +32,25 @@ import java.io.IOException; import java.util.concurrent.ExecutorService; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; + public class ES93HnswBinaryQuantizedVectorsFormat extends AbstractHnswVectorsFormat { public static final String NAME = "ES93HnswBinaryQuantizedVectorsFormat"; + /** + * For k=100, we do by default 3x oversampling, asking the HNSW graph to return 300 results. + * So the threshold is set to 300 to match this expected search cost. 
+ */ + public static final int BBQ_HNSW_GRAPH_THRESHOLD = 300; /** The format for storing, reading, merging vectors on disk */ private final FlatVectorsFormat flatVectorsFormat; /** Constructs a format using default graph construction parameters */ public ES93HnswBinaryQuantizedVectorsFormat() { - super(NAME); + super(NAME, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null, BBQ_HNSW_GRAPH_THRESHOLD); flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(); } @@ -51,7 +60,7 @@ public ES93HnswBinaryQuantizedVectorsFormat() { * @param useDirectIO whether to use direct IO when reading raw vectors */ public ES93HnswBinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType elementType, boolean useDirectIO) { - super(NAME); + super(NAME, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null, BBQ_HNSW_GRAPH_THRESHOLD); flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(elementType, useDirectIO); } @@ -68,7 +77,7 @@ public ES93HnswBinaryQuantizedVectorsFormat( DenseVectorFieldMapper.ElementType elementType, boolean useDirectIO ) { - super(NAME, maxConn, beamWidth); + super(NAME, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, BBQ_HNSW_GRAPH_THRESHOLD); flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(elementType, useDirectIO); } @@ -91,7 +100,32 @@ public ES93HnswBinaryQuantizedVectorsFormat( int numMergeWorkers, ExecutorService mergeExec ) { - super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); + super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec, BBQ_HNSW_GRAPH_THRESHOLD); + flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(elementType, useDirectIO); + } + + /** + * Constructs a format using the given graph construction parameters, scalar quantization, and HNSW graph threshold. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. 
+ * @param useDirectIO whether to use direct IO when reading raw vectors + * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If + * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec + * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are + * generated by this format to do the merge + * @param hnswGraphThreshold the minimum expected search cost before building an HNSW graph; if negative, use default + */ + public ES93HnswBinaryQuantizedVectorsFormat( + int maxConn, + int beamWidth, + DenseVectorFieldMapper.ElementType elementType, + boolean useDirectIO, + int numMergeWorkers, + ExecutorService mergeExec, + int hnswGraphThreshold + ) { + super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec, resolveThreshold(hnswGraphThreshold, BBQ_HNSW_GRAPH_THRESHOLD)); flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(elementType, useDirectIO); } @@ -102,7 +136,15 @@ protected FlatVectorsFormat flatVectorsFormat() { @Override public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); + return new Lucene99HnswVectorsWriter( + state, + maxConn, + beamWidth, + flatVectorsFormat.fieldsWriter(state), + numMergeWorkers, + mergeExec, + hnswGraphThreshold + ); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswScalarQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswScalarQuantizedVectorsFormat.java index f614ebf100b38..b86b59b60c59a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswScalarQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswScalarQuantizedVectorsFormat.java @@ -22,14 +22,17 @@ import java.io.IOException; import 
java.util.concurrent.ExecutorService; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; + public class ES93HnswScalarQuantizedVectorsFormat extends AbstractHnswVectorsFormat { static final String NAME = "ES93HnswScalarQuantizedVectorsFormat"; - private final FlatVectorsFormat flatVectorFormat; public ES93HnswScalarQuantizedVectorsFormat() { - super(NAME); + super(NAME, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null, ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD); flatVectorFormat = new ES93ScalarQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT); } @@ -42,7 +45,7 @@ public ES93HnswScalarQuantizedVectorsFormat( boolean compress, boolean useDirectIO ) { - super(NAME, maxConn, beamWidth); + super(NAME, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD); flatVectorFormat = new ES93ScalarQuantizedVectorsFormat(elementType, confidenceInterval, bits, compress, useDirectIO); } @@ -57,7 +60,30 @@ public ES93HnswScalarQuantizedVectorsFormat( int numMergeWorkers, ExecutorService mergeExec ) { - super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); + super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec, ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD); + flatVectorFormat = new ES93ScalarQuantizedVectorsFormat(elementType, confidenceInterval, bits, compress, useDirectIO); + } + + public ES93HnswScalarQuantizedVectorsFormat( + int maxConn, + int beamWidth, + DenseVectorFieldMapper.ElementType elementType, + Float confidenceInterval, + int bits, + boolean compress, + boolean useDirectIO, + int numMergeWorkers, + ExecutorService mergeExec, + int hnswGraphThreshold + ) { + super( + NAME, + maxConn, + beamWidth, + numMergeWorkers, + mergeExec, + 
resolveThreshold(hnswGraphThreshold, ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD) + ); flatVectorFormat = new ES93ScalarQuantizedVectorsFormat(elementType, confidenceInterval, bits, compress, useDirectIO); } @@ -68,7 +94,15 @@ protected FlatVectorsFormat flatVectorsFormat() { @Override public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorFormat.fieldsWriter(state), numMergeWorkers, mergeExec); + return new Lucene99HnswVectorsWriter( + state, + maxConn, + beamWidth, + flatVectorFormat.fieldsWriter(state), + numMergeWorkers, + mergeExec, + hnswGraphThreshold + ); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java index bfa14632b4a4a..0c78af8b7ac10 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java @@ -22,24 +22,33 @@ import java.io.IOException; import java.util.concurrent.ExecutorService; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; + public class ES93HnswVectorsFormat extends AbstractHnswVectorsFormat { static final String NAME = "ES93HnswVectorsFormat"; + /** + * For k=100, we ask by default to search a graph of 100*1.5=150 results. + * So the threshold is set to 150 to match this expected search cost. 
+ */ + public static final int HNSW_GRAPH_THRESHOLD = 150; private final FlatVectorsFormat flatVectorsFormat; public ES93HnswVectorsFormat() { - super(NAME); + super(NAME, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null, HNSW_GRAPH_THRESHOLD); flatVectorsFormat = new ES93GenericFlatVectorsFormat(); } public ES93HnswVectorsFormat(DenseVectorFieldMapper.ElementType elementType) { - super(NAME); + super(NAME, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null, HNSW_GRAPH_THRESHOLD); flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, false); } public ES93HnswVectorsFormat(int maxConn, int beamWidth, DenseVectorFieldMapper.ElementType elementType) { - super(NAME, maxConn, beamWidth); + super(NAME, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, HNSW_GRAPH_THRESHOLD); flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, false); } @@ -50,7 +59,19 @@ public ES93HnswVectorsFormat( int numMergeWorkers, ExecutorService mergeExec ) { - super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); + super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec, HNSW_GRAPH_THRESHOLD); + flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, false); + } + + public ES93HnswVectorsFormat( + int maxConn, + int beamWidth, + DenseVectorFieldMapper.ElementType elementType, + int numMergeWorkers, + ExecutorService mergeExec, + int hnswGraphThreshold + ) { + super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec, resolveThreshold(hnswGraphThreshold, HNSW_GRAPH_THRESHOLD)); flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, false); } @@ -61,7 +82,15 @@ protected FlatVectorsFormat flatVectorsFormat() { @Override public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); + return new Lucene99HnswVectorsWriter( + state, + maxConn, + 
beamWidth, + flatVectorsFormat.fieldsWriter(state), + numMergeWorkers, + mergeExec, + hnswGraphThreshold + ); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93ScalarQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93ScalarQuantizedVectorsFormat.java index dd8c535d5b895..07fe44233bc67 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93ScalarQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93ScalarQuantizedVectorsFormat.java @@ -9,13 +9,12 @@ package org.elasticsearch.index.codec.vectors.es93; +import org.apache.lucene.backward_codecs.lucene99.Lucene99ScalarQuantizedVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer; -import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsReader; -import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; @@ -30,6 +29,7 @@ import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.QuantizedVectorsReader; import org.apache.lucene.util.quantization.ScalarQuantizer; +import org.elasticsearch.index.codec.vectors.Lucene99ScalarQuantizedVectorsWriter; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.simdvec.VectorScorerFactory; import org.elasticsearch.simdvec.VectorSimilarityType; @@ -37,7 +37,7 @@ import java.io.IOException; import java.util.Map; -import static 
org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.DYNAMIC_CONFIDENCE_INTERVAL; +import static org.apache.lucene.backward_codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.DYNAMIC_CONFIDENCE_INTERVAL; import static org.elasticsearch.index.codec.vectors.VectorScoringUtils.scoreAndCollectAll; public class ES93ScalarQuantizedVectorsFormat extends FlatVectorsFormat { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java index ecae96b08c576..cbc57abb707c1 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java @@ -177,6 +177,11 @@ public float score() throws IOException { public DocIdSetIterator iterator() { return iterator; } + + @Override + public VectorScorer.Bulk bulk(DocIdSetIterator matchedDocs) { + return VectorScorer.Bulk.fromRandomScorerDense(randomVectorScorer, iterator, matchedDocs); + } }; } } @@ -263,6 +268,11 @@ public float score() throws IOException { public DocIdSetIterator iterator() { return iterator; } + + @Override + public VectorScorer.Bulk bulk(DocIdSetIterator matchedDocs) { + return VectorScorer.Bulk.fromRandomScorerSparse(randomVectorScorer, iterator, matchedDocs); + } }; } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es94/ES94HnswScalarQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es94/ES94HnswScalarQuantizedVectorsFormat.java new file mode 100644 index 0000000000000..eacf1641ec456 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es94/ES94HnswScalarQuantizedVectorsFormat.java @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es94; + +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; +import static org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD; + +public class ES94HnswScalarQuantizedVectorsFormat extends AbstractHnswVectorsFormat { + + static final String NAME = "ES94HnswScalarQuantizedVectorsFormat"; + private final FlatVectorsFormat flatVectorFormat; + + public ES94HnswScalarQuantizedVectorsFormat() { + super(NAME, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null, HNSW_GRAPH_THRESHOLD); + flatVectorFormat = new ES94ScalarQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT); + } + + public ES94HnswScalarQuantizedVectorsFormat( + 
int maxConn, + int beamWidth, + DenseVectorFieldMapper.ElementType elementType, + int bits, + boolean useDirectIO + ) { + super(NAME, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, HNSW_GRAPH_THRESHOLD); + flatVectorFormat = new ES94ScalarQuantizedVectorsFormat(elementType, bits, useDirectIO); + } + + public ES94HnswScalarQuantizedVectorsFormat( + int maxConn, + int beamWidth, + DenseVectorFieldMapper.ElementType elementType, + int bits, + boolean useDirectIO, + int numMergeWorkers, + ExecutorService mergeExec + ) { + super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec, HNSW_GRAPH_THRESHOLD); + flatVectorFormat = new ES94ScalarQuantizedVectorsFormat(elementType, bits, useDirectIO); + } + + public ES94HnswScalarQuantizedVectorsFormat( + int maxConn, + int beamWidth, + DenseVectorFieldMapper.ElementType elementType, + int bits, + boolean useDirectIO, + int numMergeWorkers, + ExecutorService mergeExec, + int hnswGraphThreshold + ) { + super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec, resolveThreshold(hnswGraphThreshold, HNSW_GRAPH_THRESHOLD)); + flatVectorFormat = new ES94ScalarQuantizedVectorsFormat(elementType, bits, useDirectIO); + } + + @Override + protected FlatVectorsFormat flatVectorsFormat() { + return flatVectorFormat; + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter( + state, + maxConn, + beamWidth, + flatVectorFormat.fieldsWriter(state), + numMergeWorkers, + mergeExec, + hnswGraphThreshold + ); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorFormat.fieldsReader(state)); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es94/ES94ScalarQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es94/ES94ScalarQuantizedVectorsFormat.java new file mode 100644 index 
0000000000000..7caf2a294e59e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es94/ES94ScalarQuantizedVectorsFormat.java @@ -0,0 +1,190 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es94; + +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorScorer; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsReader; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsWriter; +import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; +import org.elasticsearch.index.codec.vectors.es93.ES93FlatVectorScorer; +import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import 
org.elasticsearch.simdvec.VectorScorerFactory; +import org.elasticsearch.simdvec.VectorSimilarityType; + +import java.io.IOException; + +public class ES94ScalarQuantizedVectorsFormat extends FlatVectorsFormat { + + static final String NAME = "ES94ScalarQuantizedVectorsFormat"; + private static final int ALLOWED_BITS = (1 << 7) | (1 << 4) | (1 << 2) | (1 << 1); + + static final Lucene104ScalarQuantizedVectorScorer flatVectorScorer = new ESQuantizedFlatVectorsScorer(ES93FlatVectorScorer.INSTANCE); + private final FlatVectorsFormat rawVectorFormat; + private final Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding encoding; + + public ES94ScalarQuantizedVectorsFormat() { + this(DenseVectorFieldMapper.ElementType.FLOAT, 7, false); + } + + public ES94ScalarQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType elementType) { + this(elementType, 7, false); + } + + public ES94ScalarQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType elementType, int bits, boolean useDirectIO) { + super(NAME); + if (bits < 1 || bits > 8 || (ALLOWED_BITS & (1 << bits)) == 0) { + throw new IllegalArgumentException("bits must be one of: 1, 2, 4, 7; bits=" + bits); + } + assert elementType != DenseVectorFieldMapper.ElementType.BIT : "BIT should not be used with scalar quantization"; + + this.rawVectorFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO); + this.encoding = switch (bits) { + case 1 -> Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SINGLE_BIT_QUERY_NIBBLE; + case 2 -> Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.DIBIT_QUERY_NIBBLE; + case 4 -> Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.PACKED_NIBBLE; + case 7 -> Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SEVEN_BIT; + default -> throw new IllegalArgumentException("bits must be one of: 1, 2, 4, 7; bits=" + bits); + }; + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new 
Lucene104ScalarQuantizedVectorsWriter(state, encoding, rawVectorFormat.fieldsWriter(state), flatVectorScorer); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene104ScalarQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), flatVectorScorer); + } + + @Override + public String toString() { + return NAME + + "(name=" + + NAME + + ", encoding=" + + encoding + + ", flatVectorScorer=" + + flatVectorScorer + + ", rawVectorFormat=" + + rawVectorFormat + + ")"; + } + + static final class ESQuantizedFlatVectorsScorer extends Lucene104ScalarQuantizedVectorScorer { + + final FlatVectorsScorer delegate; + final VectorScorerFactory factory; + + ESQuantizedFlatVectorsScorer(FlatVectorsScorer delegate) { + super(delegate); + this.delegate = delegate; + factory = VectorScorerFactory.instance().orElse(null); + } + + @Override + public String toString() { + return "ESQuantizedFlatVectorsScorer(" + "delegate=" + delegate + ", factory=" + factory + ')'; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarityFunction sim, KnnVectorValues values) + throws IOException { + if (values instanceof QuantizedByteVectorValues quantizedValues && quantizedValues.getSlice() != null) { + // TODO: optimize int4, 2, and single bit quantization + if (quantizedValues.getScalarEncoding() != Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SEVEN_BIT) { + return super.getRandomVectorScorerSupplier(sim, values); + } + if (factory != null) { + var scorer = factory.getInt7uOSQVectorScorerSupplier( + VectorSimilarityType.of(sim), + quantizedValues.getSlice(), + quantizedValues + ); + if (scorer.isPresent()) { + return scorer.get(); + } + } + } + return super.getRandomVectorScorerSupplier(sim, values); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction sim, KnnVectorValues values, float[] query) + throws IOException { + if (values 
instanceof QuantizedByteVectorValues quantizedValues && quantizedValues.getSlice() != null) { + // TODO: optimize int4 quantization + if (quantizedValues.getScalarEncoding() != Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SEVEN_BIT) { + return super.getRandomVectorScorer(sim, values, query); + } + if (factory != null) { + OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(sim); + float[] residualScratch = new float[query.length]; + int[] quantizedQuery = new int[query.length]; + var correctiveComponents = scalarQuantizer.scalarQuantize( + query, + residualScratch, + quantizedQuery, + quantizedValues.getScalarEncoding().getQueryBits(), + quantizedValues.getCentroid() + ); + byte[] quantizedQueryBytes = new byte[quantizedQuery.length]; + for (int i = 0; i < quantizedQuery.length; i++) { + quantizedQueryBytes[i] = (byte) quantizedQuery[i]; + } + + var scorer = factory.getInt7uOSQVectorScorer( + sim, + quantizedValues, + quantizedQueryBytes, + correctiveComponents.lowerInterval(), + correctiveComponents.upperInterval(), + correctiveComponents.additionalCorrection(), + correctiveComponents.quantizedComponentSum() + ); + if (scorer.isPresent()) { + return scorer.get(); + } + } + } + return super.getRandomVectorScorer(sim, values, query); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction sim, KnnVectorValues values, byte[] query) + throws IOException { + return super.getRandomVectorScorer(sim, values, query); + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + QuantizedByteVectorValues scoringVectors, + QuantizedByteVectorValues targetVectors + ) { + // TODO improve merge-times for HNSW through off-heap optimized search + return super.getRandomVectorScorerSupplier(similarityFunction, scoringVectors, targetVectors); + } + + } +} diff --git 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/reflect/VectorsFormatReflectionUtils.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/reflect/VectorsFormatReflectionUtils.java index b42a46fd85ad4..a7e5fe45ffd3f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/reflect/VectorsFormatReflectionUtils.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/reflect/VectorsFormatReflectionUtils.java @@ -10,10 +10,10 @@ package org.elasticsearch.index.codec.vectors.reflect; import org.apache.lucene.codecs.lucene95.HasIndexSlice; -import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.elasticsearch.index.codec.vectors.Lucene99ScalarQuantizedVectorsWriter; import org.elasticsearch.simdvec.QuantizedByteVectorValuesAccess; import java.lang.invoke.MethodHandles; diff --git a/server/src/main/java/org/elasticsearch/index/engine/Engine.java b/server/src/main/java/org/elasticsearch/index/engine/Engine.java index f6cb3aa09144d..926d796d6e4e1 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/Engine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/Engine.java @@ -34,9 +34,12 @@ import org.apache.lucene.store.AlreadyClosedException; import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.DenseLiveDocs; import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.LiveDocs; import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.SetOnce; +import org.apache.lucene.util.SparseLiveDocs; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.indices.flush.FlushRequest; @@ -325,6 +328,12 @@ protected static ShardFieldStats 
shardFieldStats(List leaves, // Would prefer to use FixedBitSet#ramBytesUsed() however FixedBits / Bits interface don't expose that. // This simulates FixedBitSet#ramBytesUsed() does: private static long getLiveDocsBytes(Bits liveDocs) { + if (liveDocs instanceof DenseLiveDocs dld) { + return dld.ramBytesUsed(); + } + if (liveDocs instanceof SparseLiveDocs sld) { + return sld.ramBytesUsed(); + } int words = FixedBitSet.bits2words(liveDocs.length()); return ShardFieldStats.FIXED_BITSET_BASE_RAM_BYTES_USED + RamUsageEstimator.alignObjectSize( RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + (long) Long.BYTES * words @@ -332,10 +341,14 @@ private static long getLiveDocsBytes(Bits liveDocs) { } private static boolean validateLiveDocsClass(Bits liveDocs) { + if (liveDocs instanceof LiveDocs) { + return true; + } // These classes are package protected in Lucene and therefor we compare fully qualified classnames as strings here: String fullClassName = liveDocs.getClass().getName(); assert fullClassName.equals("org.apache.lucene.util.FixedBits") - || fullClassName.equals("org.apache.lucene.tests.codecs.asserting.AssertingLiveDocsFormat$AssertingBits") + || fullClassName.contains("org.apache.lucene.tests.codecs.asserting.AssertingLiveDocsFormat$Asserting") + || fullClassName.contains("org.apache.lucene.tests.codecs.asserting.AssertLeafReader$Asserting") : "unexpected class [" + fullClassName + "]"; return true; } diff --git a/server/src/main/java/org/elasticsearch/index/engine/NoOpEngine.java b/server/src/main/java/org/elasticsearch/index/engine/NoOpEngine.java index 044e6f6712c77..a97aa05c1cfb0 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/NoOpEngine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/NoOpEngine.java @@ -31,6 +31,7 @@ import java.io.UncheckedIOException; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutorService; import java.util.function.Function; /** @@ -85,16 +86,31 @@ protected DirectoryReader 
doOpenIfChanged() { return null; } + @Override + protected DirectoryReader doOpenIfChanged(ExecutorService executorService) { + return null; + } + @Override protected DirectoryReader doOpenIfChanged(IndexCommit commit) { return null; } + @Override + protected DirectoryReader doOpenIfChanged(IndexCommit commit, ExecutorService executorService) { + return null; + } + @Override protected DirectoryReader doOpenIfChanged(IndexWriter writer, boolean applyAllDeletes) { return null; } + @Override + protected DirectoryReader doOpenIfChanged(IndexWriter writer, boolean applyAllDeletes, ExecutorService executorService) { + return null; + } + @Override public long getVersion() { return 0; diff --git a/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java b/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java index 4459abe0a7fc1..99457fd7bfa40 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java @@ -607,8 +607,8 @@ public ShardLongFieldRange getRawFieldRange(String field) throws IOException { return ShardLongFieldRange.of(LongPoint.decodeDimension(minPackedValue, 0), LongPoint.decodeDimension(maxPackedValue, 0)); } - long minValue = DocValuesSkipper.globalMinValue(searcher, field); - long maxValue = DocValuesSkipper.globalMaxValue(searcher, field); + long minValue = DocValuesSkipper.globalMinValue(searcher.getIndexReader(), field); + long maxValue = DocValuesSkipper.globalMaxValue(searcher.getIndexReader(), field); if (minValue == Long.MAX_VALUE && maxValue == Long.MIN_VALUE) { // no skipper return ShardLongFieldRange.EMPTY; diff --git a/server/src/main/java/org/elasticsearch/index/engine/TranslogDirectoryReader.java b/server/src/main/java/org/elasticsearch/index/engine/TranslogDirectoryReader.java index 5704db7663f24..c09fe8c87e5af 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/TranslogDirectoryReader.java +++ 
b/server/src/main/java/org/elasticsearch/index/engine/TranslogDirectoryReader.java @@ -70,6 +70,7 @@ import java.io.IOException; import java.util.Collections; import java.util.Set; +import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicReference; /** @@ -138,16 +139,31 @@ protected DirectoryReader doOpenIfChanged() { throw unsupported(); } + @Override + protected DirectoryReader doOpenIfChanged(ExecutorService executorService) { + throw unsupported(); + } + @Override protected DirectoryReader doOpenIfChanged(IndexCommit commit) { throw unsupported(); } + @Override + protected DirectoryReader doOpenIfChanged(IndexCommit commit, ExecutorService executorService) { + throw unsupported(); + } + @Override protected DirectoryReader doOpenIfChanged(IndexWriter writer, boolean applyAllDeletes) { throw unsupported(); } + @Override + protected DirectoryReader doOpenIfChanged(IndexWriter writer, boolean applyAllDeletes, ExecutorService executorService) { + throw unsupported(); + } + @Override public long getVersion() { throw unsupported(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenormalizedCosineFloatVectorValues.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenormalizedCosineFloatVectorValues.java index 775168e69e477..20970d9aed5a9 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenormalizedCosineFloatVectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenormalizedCosineFloatVectorValues.java @@ -59,6 +59,11 @@ public VectorScorer scorer(float[] floats) throws IOException { return in.scorer(floats); } + @Override + public VectorScorer rescorer(float[] floats) throws IOException { + return in.rescorer(floats); + } + @Override public int ordToDoc(int ord) { return in.ordToDoc(ord); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java 
b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 2ed4edc0cb901..a3760958d93af 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -41,6 +41,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Nullable; @@ -54,9 +55,9 @@ import org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93FlatVectorFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; -import org.elasticsearch.index.codec.vectors.es93.ES93HnswScalarQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat; -import org.elasticsearch.index.codec.vectors.es93.ES93ScalarQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es94.ES94HnswScalarQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es94.ES94ScalarQuantizedVectorsFormat; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.BlockLoader; @@ -138,6 +139,9 @@ public class DenseVectorFieldMapper extends FieldMapper { public static final String COSINE_MAGNITUDE_FIELD_SUFFIX = "._magnitude"; public static final int BBQ_MIN_DIMS = 64; + public static final String CONFIDENCE_INTERVAL_DEPRECATION_MESSAGE = + "Parameter [confidence_interval] in [index_options] for dense_vector field [{}] " + + "is deprecated and will be removed in a future version"; private static final int 
DEFAULT_BBQ_IVF_QUANTIZE_BITS = 1; @@ -407,15 +411,16 @@ private DenseVectorIndexOptions defaultIndexOptions(boolean defaultInt8Hnsw, boo Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, false, - new RescoreVector(DEFAULT_OVERSAMPLE) + new RescoreVector(DEFAULT_OVERSAMPLE), + -1 ); } else if (defaultInt8Hnsw) { return new Int8HnswIndexOptions( Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, - null, false, - null + null, + -1 ); } return null; @@ -1483,10 +1488,24 @@ public final int hashCode() { abstract static class QuantizedIndexOptions extends DenseVectorIndexOptions { final RescoreVector rescoreVector; + final Float confidenceInterval; QuantizedIndexOptions(VectorIndexType type, RescoreVector rescoreVector) { + this(type, rescoreVector, null); + } + + QuantizedIndexOptions(VectorIndexType type, RescoreVector rescoreVector, Float confidenceInterval) { super(type); this.rescoreVector = rescoreVector; + this.confidenceInterval = confidenceInterval; + } + + public Float confidenceInterval() { + return confidenceInterval; + } + + public RescoreVector rescoreVector() { + return rescoreVector; } } @@ -1501,17 +1520,19 @@ public DenseVectorIndexOptions parseIndexOptions( ) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); + Object flatIndexThresholdNode = indexOptionsMap.remove("flat_index_threshold"); Object onDiskRescoreNode = indexOptionsMap.remove("on_disk_rescore"); int m = XContentMapValues.nodeIntegerValue(mNode, Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH); + int flatIndexThreshold = XContentMapValues.nodeIntegerValue(flatIndexThresholdNode, -1); boolean onDiskRescore = XContentMapValues.nodeBooleanValue(onDiskRescoreNode, false); if (onDiskRescore) { throw new 
IllegalArgumentException("on_disk_rescore is only supported for indexed and quantized vector types"); } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new HnswIndexOptions(m, efConstruction); + return new HnswIndexOptions(m, efConstruction, flatIndexThreshold); } @Override @@ -1534,23 +1555,21 @@ public DenseVectorIndexOptions parseIndexOptions( ) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); - Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); + Float confidenceInterval = parseConfidenceInterval(fieldName, indexOptionsMap, indexVersion); Object onDiskRescoreNode = indexOptionsMap.remove("on_disk_rescore"); + Object flatIndexThresholdNode = indexOptionsMap.remove("flat_index_threshold"); int m = XContentMapValues.nodeIntegerValue(mNode, Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH); boolean onDiskRescore = XContentMapValues.nodeBooleanValue(onDiskRescoreNode, false); + int flatIndexThreshold = XContentMapValues.nodeIntegerValue(flatIndexThresholdNode, -1); - Float confidenceInterval = null; - if (confidenceIntervalNode != null) { - confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); - } RescoreVector rescoreVector = null; if (hasRescoreIndexVersion(indexVersion)) { rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap, indexVersion); } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int8HnswIndexOptions(m, efConstruction, confidenceInterval, onDiskRescore, rescoreVector); + return new Int8HnswIndexOptions(m, efConstruction, onDiskRescore, rescoreVector, flatIndexThreshold, confidenceInterval); } @Override @@ -1572,24 +1591,22 @@ public DenseVectorIndexOptions parseIndexOptions( ) { Object mNode = indexOptionsMap.remove("m"); Object 
efConstructionNode = indexOptionsMap.remove("ef_construction"); - Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); + Float confidenceInterval = parseConfidenceInterval(fieldName, indexOptionsMap, indexVersion); Object onDiskRescoreNode = indexOptionsMap.remove("on_disk_rescore"); + Object flatIndexThresholdNode = indexOptionsMap.remove("flat_index_threshold"); int m = XContentMapValues.nodeIntegerValue(mNode, Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH); boolean onDiskRescore = XContentMapValues.nodeBooleanValue(onDiskRescoreNode, false); + int flatIndexThreshold = XContentMapValues.nodeIntegerValue(flatIndexThresholdNode, -1); - Float confidenceInterval = null; - if (confidenceIntervalNode != null) { - confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); - } RescoreVector rescoreVector = null; if (hasRescoreIndexVersion(indexVersion)) { rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap, indexVersion); } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int4HnswIndexOptions(m, efConstruction, confidenceInterval, onDiskRescore, rescoreVector); + return new Int4HnswIndexOptions(m, efConstruction, onDiskRescore, rescoreVector, flatIndexThreshold, confidenceInterval); } @Override @@ -1638,11 +1655,7 @@ public DenseVectorIndexOptions parseIndexOptions( boolean experimentalFeaturesEnabled ) { Object onDiskRescoreNode = indexOptionsMap.remove("on_disk_rescore"); - Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); - Float confidenceInterval = null; - if (confidenceIntervalNode != null) { - confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); - } + Float confidenceInterval = parseConfidenceInterval(fieldName, indexOptionsMap, indexVersion); RescoreVector rescoreVector = null; if 
(hasRescoreIndexVersion(indexVersion)) { rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap, indexVersion); @@ -1652,7 +1665,7 @@ public DenseVectorIndexOptions parseIndexOptions( throw new IllegalArgumentException("on_disk_rescore is only supported for indexed and quantized vector types"); } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int8FlatIndexOptions(confidenceInterval, rescoreVector); + return new Int8FlatIndexOptions(rescoreVector, confidenceInterval); } @Override @@ -1674,11 +1687,7 @@ public DenseVectorIndexOptions parseIndexOptions( boolean experimentalFeaturesEnabled ) { Object onDiskRescoreNode = indexOptionsMap.remove("on_disk_rescore"); - Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); - Float confidenceInterval = null; - if (confidenceIntervalNode != null) { - confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); - } + Float confidenceInterval = parseConfidenceInterval(fieldName, indexOptionsMap, indexVersion); RescoreVector rescoreVector = null; if (hasRescoreIndexVersion(indexVersion)) { rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap, indexVersion); @@ -1688,7 +1697,7 @@ public DenseVectorIndexOptions parseIndexOptions( throw new IllegalArgumentException("on_disk_rescore is only supported for indexed and quantized vector types"); } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int4FlatIndexOptions(confidenceInterval, rescoreVector); + return new Int4FlatIndexOptions(rescoreVector, confidenceInterval); } @Override @@ -1712,10 +1721,12 @@ public DenseVectorIndexOptions parseIndexOptions( Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); Object onDiskRescoreNode = indexOptionsMap.remove("on_disk_rescore"); + Object flatIndexThresholdNode = indexOptionsMap.remove("flat_index_threshold"); int m = 
XContentMapValues.nodeIntegerValue(mNode, Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH); boolean onDiskRescore = XContentMapValues.nodeBooleanValue(onDiskRescoreNode, false); + int flatIndexThreshold = XContentMapValues.nodeIntegerValue(flatIndexThresholdNode, -1); RescoreVector rescoreVector = null; if (hasRescoreIndexVersion(indexVersion)) { @@ -1726,7 +1737,7 @@ public DenseVectorIndexOptions parseIndexOptions( } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new BBQHnswIndexOptions(m, efConstruction, onDiskRescore, rescoreVector); + return new BBQHnswIndexOptions(m, efConstruction, onDiskRescore, rescoreVector, flatIndexThreshold); } @Override @@ -1922,23 +1933,24 @@ public String toString() { } static class Int8FlatIndexOptions extends QuantizedIndexOptions { - private final Float confidenceInterval; - - Int8FlatIndexOptions(Float confidenceInterval, RescoreVector rescoreVector) { + Int8FlatIndexOptions(RescoreVector rescoreVector) { super(VectorIndexType.INT8_FLAT, rescoreVector); - this.confidenceInterval = confidenceInterval; + } + + Int8FlatIndexOptions(RescoreVector rescoreVector, Float confidenceInterval) { + super(VectorIndexType.INT8_FLAT, rescoreVector, confidenceInterval); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field("type", type); - if (confidenceInterval != null) { - builder.field("confidence_interval", confidenceInterval); - } if (rescoreVector != null) { rescoreVector.toXContent(builder, params); } + if (confidenceInterval != null) { + builder.field("confidence_interval", confidenceInterval); + } builder.endObject(); return builder; } @@ -1946,18 +1958,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override KnnVectorsFormat getVectorsFormat(ElementType 
elementType, ExecutorService mergingExecutorService, int numMergeWorkers) { assert elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16; - return new ES93ScalarQuantizedVectorsFormat(elementType, confidenceInterval, 7, false, false); + return new ES94ScalarQuantizedVectorsFormat(elementType, 7, false); } @Override boolean doEquals(DenseVectorIndexOptions o) { Int8FlatIndexOptions that = (Int8FlatIndexOptions) o; - return Objects.equals(confidenceInterval, that.confidenceInterval) && Objects.equals(rescoreVector, that.rescoreVector); + return Objects.equals(rescoreVector, that.rescoreVector) && Objects.equals(confidenceInterval, that.confidenceInterval); } @Override int doHashCode() { - return Objects.hash(confidenceInterval, rescoreVector); + return Objects.hash(rescoreVector, confidenceInterval); } @Override @@ -2020,38 +2032,40 @@ public boolean isFlat() { public static class Int4HnswIndexOptions extends QuantizedIndexOptions { private final int m; private final int efConstruction; - private final float confidenceInterval; private final boolean onDiskRescore; + private final int flatIndexThreshold; + + public Int4HnswIndexOptions(int m, int efConstruction, boolean onDiskRescore, RescoreVector rescoreVector, int flatIndexThreshold) { + this(m, efConstruction, onDiskRescore, rescoreVector, flatIndexThreshold, null); + } public Int4HnswIndexOptions( int m, int efConstruction, - Float confidenceInterval, boolean onDiskRescore, - RescoreVector rescoreVector + RescoreVector rescoreVector, + int flatIndexThreshold, + Float confidenceInterval ) { - super(VectorIndexType.INT4_HNSW, rescoreVector); + super(VectorIndexType.INT4_HNSW, rescoreVector, confidenceInterval); this.m = m; this.efConstruction = efConstruction; - // The default confidence interval for int4 is dynamic quantiles, this provides the best relevancy and is - // effectively required for int4 to behave well across a wide range of data. 
- this.confidenceInterval = confidenceInterval == null ? 0f : confidenceInterval; this.onDiskRescore = onDiskRescore; + this.flatIndexThreshold = flatIndexThreshold; } @Override public KnnVectorsFormat getVectorsFormat(ElementType elementType, ExecutorService mergingExecutorService, int numMergeWorkers) { assert elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16; - return new ES93HnswScalarQuantizedVectorsFormat( + return new ES94HnswScalarQuantizedVectorsFormat( m, efConstruction, elementType, - confidenceInterval, 4, - true, onDiskRescore, numMergeWorkers, - mergingExecutorService + mergingExecutorService, + flatIndexThreshold ); } @@ -2061,13 +2075,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field("type", type); builder.field("m", m); builder.field("ef_construction", efConstruction); - builder.field("confidence_interval", confidenceInterval); + if (confidenceInterval != null) { + builder.field("confidence_interval", confidenceInterval); + } if (onDiskRescore) { builder.field("on_disk_rescore", true); } if (rescoreVector != null) { rescoreVector.toXContent(builder, params); } + if (flatIndexThreshold >= 0) { + builder.field("flat_index_threshold", flatIndexThreshold); + } builder.endObject(); return builder; } @@ -2077,14 +2096,15 @@ public boolean doEquals(DenseVectorIndexOptions o) { Int4HnswIndexOptions that = (Int4HnswIndexOptions) o; return m == that.m && efConstruction == that.efConstruction - && Objects.equals(confidenceInterval, that.confidenceInterval) && onDiskRescore == that.onDiskRescore - && Objects.equals(rescoreVector, that.rescoreVector); + && Objects.equals(rescoreVector, that.rescoreVector) + && Objects.equals(confidenceInterval, that.confidenceInterval) + && flatIndexThreshold == that.flatIndexThreshold; } @Override public int doHashCode() { - return Objects.hash(m, efConstruction, confidenceInterval, onDiskRescore, rescoreVector); + return Objects.hash(m, efConstruction, 
onDiskRescore, rescoreVector, confidenceInterval, flatIndexThreshold); } @Override @@ -2092,6 +2112,10 @@ public boolean isFlat() { return false; } + public int flatIndexThreshold() { + return flatIndexThreshold; + } + @Override public String toString() { return "{type=" @@ -2100,12 +2124,12 @@ public String toString() { + m + ", ef_construction=" + efConstruction - + ", confidence_interval=" - + confidenceInterval + ", on_disk_rescore=" + onDiskRescore + ", rescore_vector=" + (rescoreVector == null ? "none" : rescoreVector) + + ", flat_index_threshold=" + + flatIndexThreshold + "}"; } @@ -2115,8 +2139,7 @@ public boolean updatableTo(DenseVectorIndexOptions update) { if (update.type.equals(VectorIndexType.INT4_HNSW)) { Int4HnswIndexOptions int4HnswIndexOptions = (Int4HnswIndexOptions) update; // fewer connections would break assumptions on max number of connections (based on largest previous graph) during merge - // quantization could not behave as expected with different confidence intervals (and quantiles) to be created - updatable = int4HnswIndexOptions.m >= this.m && confidenceInterval == int4HnswIndexOptions.confidenceInterval; + updatable = int4HnswIndexOptions.m >= this.m; } else if (update.type.equals(VectorIndexType.BBQ_HNSW)) { updatable = ((BBQHnswIndexOptions) update).m >= m; } @@ -2125,29 +2148,30 @@ public boolean updatableTo(DenseVectorIndexOptions update) { } static class Int4FlatIndexOptions extends QuantizedIndexOptions { - private final float confidenceInterval; - - Int4FlatIndexOptions(Float confidenceInterval, RescoreVector rescoreVector) { + Int4FlatIndexOptions(RescoreVector rescoreVector) { super(VectorIndexType.INT4_FLAT, rescoreVector); - // The default confidence interval for int4 is dynamic quantiles, this provides the best relevancy and is - // effectively required for int4 to behave well across a wide range of data. - this.confidenceInterval = confidenceInterval == null ? 
0f : confidenceInterval; + } + + Int4FlatIndexOptions(RescoreVector rescoreVector, Float confidenceInterval) { + super(VectorIndexType.INT4_FLAT, rescoreVector, confidenceInterval); } @Override public KnnVectorsFormat getVectorsFormat(ElementType elementType, ExecutorService mergingExecutorService, int numMergeWorkers) { assert elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16; - return new ES93ScalarQuantizedVectorsFormat(elementType, confidenceInterval, 4, true, false); + return new ES94ScalarQuantizedVectorsFormat(elementType, 4, false); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field("type", type); - builder.field("confidence_interval", confidenceInterval); if (rescoreVector != null) { rescoreVector.toXContent(builder, params); } + if (confidenceInterval != null) { + builder.field("confidence_interval", confidenceInterval); + } builder.endObject(); return builder; } @@ -2157,12 +2181,12 @@ public boolean doEquals(DenseVectorIndexOptions o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Int4FlatIndexOptions that = (Int4FlatIndexOptions) o; - return Objects.equals(confidenceInterval, that.confidenceInterval) && Objects.equals(rescoreVector, that.rescoreVector); + return Objects.equals(rescoreVector, that.rescoreVector) && Objects.equals(confidenceInterval, that.confidenceInterval); } @Override public int doHashCode() { - return Objects.hash(confidenceInterval, rescoreVector); + return Objects.hash(rescoreVector, confidenceInterval); } @Override @@ -2172,7 +2196,7 @@ public boolean isFlat() { @Override public String toString() { - return "{type=" + type + ", confidence_interval=" + confidenceInterval + ", rescore_vector=" + rescoreVector + "}"; + return "{type=" + type + ", rescore_vector=" + rescoreVector + "}"; } @Override @@ -2190,36 +2214,40 @@ public boolean updatableTo(DenseVectorIndexOptions update) 
{ public static class Int8HnswIndexOptions extends QuantizedIndexOptions { private final int m; private final int efConstruction; - private final Float confidenceInterval; private final boolean onDiskRescore; + private final int flatIndexThreshold; + + public Int8HnswIndexOptions(int m, int efConstruction, boolean onDiskRescore, RescoreVector rescoreVector, int flatIndexThreshold) { + this(m, efConstruction, onDiskRescore, rescoreVector, flatIndexThreshold, null); + } public Int8HnswIndexOptions( int m, int efConstruction, - Float confidenceInterval, boolean onDiskRescore, - RescoreVector rescoreVector + RescoreVector rescoreVector, + int flatIndexThreshold, + Float confidenceInterval ) { - super(VectorIndexType.INT8_HNSW, rescoreVector); + super(VectorIndexType.INT8_HNSW, rescoreVector, confidenceInterval); this.m = m; this.efConstruction = efConstruction; - this.confidenceInterval = confidenceInterval; this.onDiskRescore = onDiskRescore; + this.flatIndexThreshold = flatIndexThreshold; } @Override public KnnVectorsFormat getVectorsFormat(ElementType elementType, ExecutorService mergingExecutorService, int numMergeWorkers) { assert elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16; - return new ES93HnswScalarQuantizedVectorsFormat( + return new ES94HnswScalarQuantizedVectorsFormat( m, efConstruction, elementType, - confidenceInterval, 7, - false, onDiskRescore, numMergeWorkers, - mergingExecutorService + mergingExecutorService, + flatIndexThreshold ); } @@ -2238,6 +2266,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (rescoreVector != null) { rescoreVector.toXContent(builder, params); } + if (flatIndexThreshold >= 0) { + builder.field("flat_index_threshold", flatIndexThreshold); + } builder.endObject(); return builder; } @@ -2249,14 +2280,15 @@ public boolean doEquals(DenseVectorIndexOptions o) { Int8HnswIndexOptions that = (Int8HnswIndexOptions) o; return m == that.m && efConstruction == 
that.efConstruction - && Objects.equals(confidenceInterval, that.confidenceInterval) && onDiskRescore == that.onDiskRescore - && Objects.equals(rescoreVector, that.rescoreVector); + && Objects.equals(rescoreVector, that.rescoreVector) + && Objects.equals(confidenceInterval, that.confidenceInterval) + && flatIndexThreshold == that.flatIndexThreshold; } @Override public int doHashCode() { - return Objects.hash(m, efConstruction, confidenceInterval, onDiskRescore, rescoreVector); + return Objects.hash(m, efConstruction, onDiskRescore, rescoreVector, confidenceInterval, flatIndexThreshold); } @Override @@ -2272,6 +2304,10 @@ public int efConstruction() { return efConstruction; } + public int flatIndexThreshold() { + return flatIndexThreshold; + } + public Float confidenceInterval() { return confidenceInterval; } @@ -2284,12 +2320,12 @@ public String toString() { + m + ", ef_construction=" + efConstruction - + ", confidence_interval=" - + confidenceInterval + ", on_disk_rescore=" + onDiskRescore + ", rescore_vector=" + (rescoreVector == null ? 
"none" : rescoreVector) + + ", flat_index_threshold=" + + flatIndexThreshold + "}"; } @@ -2299,11 +2335,7 @@ public boolean updatableTo(DenseVectorIndexOptions update) { if (update.type.equals(this.type)) { Int8HnswIndexOptions int8HnswIndexOptions = (Int8HnswIndexOptions) update; // fewer connections would break assumptions on max number of connections (based on largest previous graph) during merge - // quantization could not behave as expected with different confidence intervals (and quantiles) to be created updatable = int8HnswIndexOptions.m >= this.m; - updatable &= confidenceInterval == null - || int8HnswIndexOptions.confidenceInterval != null - && confidenceInterval.equals(int8HnswIndexOptions.confidenceInterval); } else { updatable = update.type.equals(VectorIndexType.INT4_HNSW) && ((Int4HnswIndexOptions) update).m >= this.m || (update.type.equals(VectorIndexType.BBQ_HNSW) && ((BBQHnswIndexOptions) update).m >= m); @@ -2315,16 +2347,18 @@ public boolean updatableTo(DenseVectorIndexOptions update) { public static class HnswIndexOptions extends DenseVectorIndexOptions { private final int m; private final int efConstruction; + private final int flatIndexThreshold; - HnswIndexOptions(int m, int efConstruction) { + HnswIndexOptions(int m, int efConstruction, int flatIndexThreshold) { super(VectorIndexType.HNSW); this.m = m; this.efConstruction = efConstruction; + this.flatIndexThreshold = flatIndexThreshold; } @Override public KnnVectorsFormat getVectorsFormat(ElementType elementType, ExecutorService mergingExecutorService, int numMergeWorkers) { - return new ES93HnswVectorsFormat(m, efConstruction, elementType, numMergeWorkers, mergingExecutorService); + return new ES93HnswVectorsFormat(m, efConstruction, elementType, numMergeWorkers, mergingExecutorService, flatIndexThreshold); } @Override @@ -2347,6 +2381,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field("type", type); builder.field("m", m); 
builder.field("ef_construction", efConstruction); + if (flatIndexThreshold >= 0) { + builder.field("flat_index_threshold", flatIndexThreshold); + } builder.endObject(); return builder; } @@ -2356,12 +2393,12 @@ public boolean doEquals(DenseVectorIndexOptions o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; HnswIndexOptions that = (HnswIndexOptions) o; - return m == that.m && efConstruction == that.efConstruction; + return m == that.m && efConstruction == that.efConstruction && flatIndexThreshold == that.flatIndexThreshold; } @Override public int doHashCode() { - return Objects.hash(m, efConstruction); + return Objects.hash(m, efConstruction, flatIndexThreshold); } @Override @@ -2377,9 +2414,21 @@ public int efConstruction() { return efConstruction; } + public int flatIndexThreshold() { + return flatIndexThreshold; + } + @Override public String toString() { - return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + "}"; + return "{type=" + + type + + ", m=" + + m + + ", ef_construction=" + + efConstruction + + ", flat_index_threshold=" + + flatIndexThreshold + + "}"; } } @@ -2387,12 +2436,14 @@ public static class BBQHnswIndexOptions extends QuantizedIndexOptions { private final int m; private final int efConstruction; private final boolean onDiskRescore; + private final int flatIndexThreshold; - public BBQHnswIndexOptions(int m, int efConstruction, boolean onDiskRescore, RescoreVector rescoreVector) { + public BBQHnswIndexOptions(int m, int efConstruction, boolean onDiskRescore, RescoreVector rescoreVector, int flatIndexThreshold) { super(VectorIndexType.BBQ_HNSW, rescoreVector); this.m = m; this.efConstruction = efConstruction; this.onDiskRescore = onDiskRescore; + this.flatIndexThreshold = flatIndexThreshold; } @Override @@ -2404,7 +2455,8 @@ KnnVectorsFormat getVectorsFormat(ElementType elementType, ExecutorService mergi elementType, onDiskRescore, numMergeWorkers, - mergingExecutorService + 
mergingExecutorService, + flatIndexThreshold ); } @@ -2419,12 +2471,13 @@ boolean doEquals(DenseVectorIndexOptions other) { return m == that.m && efConstruction == that.efConstruction && onDiskRescore == that.onDiskRescore - && Objects.equals(rescoreVector, that.rescoreVector); + && Objects.equals(rescoreVector, that.rescoreVector) + && flatIndexThreshold == that.flatIndexThreshold; } @Override int doHashCode() { - return Objects.hash(m, efConstruction, onDiskRescore, rescoreVector); + return Objects.hash(m, efConstruction, onDiskRescore, rescoreVector, flatIndexThreshold); } @Override @@ -2432,6 +2485,10 @@ public boolean isFlat() { return false; } + public int flatIndexThreshold() { + return flatIndexThreshold; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -2444,6 +2501,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (rescoreVector != null) { rescoreVector.toXContent(builder, params); } + if (flatIndexThreshold >= 0) { + builder.field("flat_index_threshold", flatIndexThreshold); + } builder.endObject(); return builder; } @@ -3483,6 +3543,24 @@ private static DenseVectorIndexOptions parseIndexOptions( return parsedType.parseIndexOptions(fieldName, indexOptionsMap, indexVersion, experimentalFeaturesEnabled); } + private static Float parseConfidenceInterval(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { + Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); + if (confidenceIntervalNode == null) { + return null; + } + float confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); + boolean shouldWarn = indexVersion.onOrAfter(IndexVersions.UPGRADE_TO_LUCENE_10_4_0); + if (shouldWarn) { + deprecationLogger.warn( + DeprecationCategory.MAPPINGS, + "dense_vector_confidence_interval_deprecated", + CONFIDENCE_INTERVAL_DEPRECATION_MESSAGE, + fieldName + ); + } + return 
confidenceInterval; + } + /** * @return the custom kNN vectors format that is configured for this field or * {@code null} if the default format should be used. @@ -3513,7 +3591,8 @@ public KnnVectorsFormat getKnnVectorsFormatForField( DEFAULT_MAX_CONN, elementType, maxMergingWorkers, - mergingExecutorService + mergingExecutorService, + -1 ); }; } else { diff --git a/server/src/main/java/org/elasticsearch/index/shard/DenseVectorStats.java b/server/src/main/java/org/elasticsearch/index/shard/DenseVectorStats.java index 2c996586b95ac..597bf17d71751 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/DenseVectorStats.java +++ b/server/src/main/java/org/elasticsearch/index/shard/DenseVectorStats.java @@ -156,7 +156,7 @@ private void toXContentWithPerFieldStats(XContentBuilder builder) throws IOExcep builder.startObject(key); for (var eKey : entry.keySet().stream().sorted().toList()) { long value = entry.get(eKey); - assert value > 0L; + assert value >= 0L; builder.humanReadableField(eKey + "_size_bytes", eKey + "_size", ofBytes(value)); } builder.endObject(); diff --git a/server/src/main/java/org/elasticsearch/lucene/comparators/XNumericComparator.java b/server/src/main/java/org/elasticsearch/lucene/comparators/XNumericComparator.java index 4c7f191c0d56f..b9a17beb6061b 100644 --- a/server/src/main/java/org/elasticsearch/lucene/comparators/XNumericComparator.java +++ b/server/src/main/java/org/elasticsearch/lucene/comparators/XNumericComparator.java @@ -22,6 +22,7 @@ import org.apache.lucene.search.Pruning; import org.apache.lucene.search.Scorable; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.SkipBlockRangeIterator; import org.apache.lucene.util.DocIdSetBuilder; import org.apache.lucene.util.IntsRef; @@ -474,7 +475,7 @@ protected int docCount() { @Override protected void doUpdateCompetitiveIterator() { - competitiveIterator.update(new XSkipBlockRangeIterator(skipper, minValueAsLong, maxValueAsLong)); + 
competitiveIterator.update(new SkipBlockRangeIterator(skipper, minValueAsLong, maxValueAsLong)); } } } diff --git a/server/src/main/java/org/elasticsearch/lucene/comparators/XSkipBlockRangeIterator.java b/server/src/main/java/org/elasticsearch/lucene/comparators/XSkipBlockRangeIterator.java deleted file mode 100644 index 367fbd702d3c2..0000000000000 --- a/server/src/main/java/org/elasticsearch/lucene/comparators/XSkipBlockRangeIterator.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.lucene.comparators; - -import org.apache.lucene.index.DocValuesSkipper; -import org.apache.lucene.search.AbstractDocIdSetIterator; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.util.FixedBitSet; -import org.elasticsearch.index.IndexVersion; - -import java.io.IOException; - -public class XSkipBlockRangeIterator extends AbstractDocIdSetIterator { - - static { - if (IndexVersion.current().luceneVersion().onOrAfter(org.apache.lucene.util.Version.fromBits(10, 4, 0))) { - throw new IllegalStateException("Remove this class after upgrading to lucene 10.4"); - } - } - - private final DocValuesSkipper skipper; - private final long minValue; - private final long maxValue; - - public XSkipBlockRangeIterator(DocValuesSkipper skipper, long minValue, long maxValue) { - this.skipper = skipper; - this.minValue = minValue; - this.maxValue = maxValue; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) 
throws IOException { - if (target <= skipper.maxDocID(0)) { - // within current block - if (doc > -1) { - // already positioned, so we've checked bounds and know that we're in a matching block - return doc = target; - } else { - // first call, the skipper might already be positioned so ask it to find the next - // matching block (which could be the current one) - skipper.advance(minValue, maxValue); - return doc = Math.max(target, skipper.minDocID(0)); - } - } - // Advance to target - skipper.advance(target); - - // Find the next matching block (could be the current block) - skipper.advance(minValue, maxValue); - - return doc = Math.max(target, skipper.minDocID(0)); - } - - @Override - public long cost() { - return DocIdSetIterator.NO_MORE_DOCS; - } - - @Override - public int docIDRunEnd() throws IOException { - int maxDoc = skipper.maxDocID(0); - int nextLevel = 1; - while (nextLevel < skipper.numLevels() && skipper.minValue(nextLevel) < maxValue && skipper.maxValue(nextLevel) > minValue) { - maxDoc = skipper.maxDocID(nextLevel); - nextLevel++; - } - return maxDoc + 1; - } - - @Override - public void intoBitSet(int upTo, FixedBitSet bitSet, int offset) throws IOException { - while (doc < upTo) { - int end = Math.min(upTo, docIDRunEnd()); - bitSet.set(doc - offset, end - offset); - advance(end); - } - } -} diff --git a/server/src/main/java/org/elasticsearch/search/SearchFeatures.java b/server/src/main/java/org/elasticsearch/search/SearchFeatures.java index a3aa07c3a2cf7..fc98c3a64d065 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchFeatures.java +++ b/server/src/main/java/org/elasticsearch/search/SearchFeatures.java @@ -18,6 +18,7 @@ public final class SearchFeatures implements FeatureSpecification { public static final NodeFeature LUCENE_10_0_0_UPGRADE = new NodeFeature("lucene_10_upgrade"); public static final NodeFeature LUCENE_10_1_0_UPGRADE = new NodeFeature("lucene_10_1_upgrade"); + public static final NodeFeature LUCENE_10_4_0_UPGRADE_TEST = 
new NodeFeature("lucene_10_4_upgrade"); @Override public Set getFeatures() { @@ -33,6 +34,7 @@ public Set getFeatures() { static final NodeFeature MULTI_MATCH_CHECKS_POSITIONS = new NodeFeature("search.multi.match.checks.positions"); public static final NodeFeature BBQ_HNSW_DEFAULT_INDEXING = new NodeFeature("search.vectors.mappers.default_bbq_hnsw"); public static final NodeFeature SEARCH_WITH_NO_DIMENSIONS_BUGFIX = new NodeFeature("search.vectors.no_dimensions_bugfix"); + public static final NodeFeature HNSW_FLAT_INDEX_THRESHOLD = new NodeFeature("search.vectors.flat_index_threshold"); public static final NodeFeature SEARCH_RESCORE_SCRIPT = new NodeFeature("search.rescore.script"); public static final NodeFeature NEGATIVE_FUNCTION_SCORE_BAD_REQUEST = new NodeFeature("search.negative.function.score.bad.request"); public static final NodeFeature INDICES_BOOST_REMOTE_INDEX_FIX = new NodeFeature("search.indices_boost_remote_index_fix"); @@ -66,6 +68,7 @@ public Set getTestFeatures() { MULTI_MATCH_CHECKS_POSITIONS, BBQ_HNSW_DEFAULT_INDEXING, SEARCH_WITH_NO_DIMENSIONS_BUGFIX, + HNSW_FLAT_INDEX_THRESHOLD, SEARCH_RESCORE_SCRIPT, NEGATIVE_FUNCTION_SCORE_BAD_REQUEST, INDICES_BOOST_REMOTE_INDEX_FIX, @@ -78,6 +81,7 @@ public Set getTestFeatures() { EXPONENTIAL_HISTOGRAM_QUERYDSL_PERCENTILE_RANKS, CLOSING_INVALID_PIT_ID, FUNCTION_SCORE_NAMED_QUERIES, + LUCENE_10_4_0_UPGRADE_TEST, EXPONENTIAL_HISTOGRAM_QUERYDSL_BOXPLOT, EXPONENTIAL_HISTOGRAM_QUERYDSL_RANGE ); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/CoreValuesSourceType.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/CoreValuesSourceType.java index eb2923b3703d3..3f1044bbe3e2a 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/CoreValuesSourceType.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/CoreValuesSourceType.java @@ -318,9 +318,9 @@ public Function roundingPreparer(AggregationContext } else if 
(dft.hasDocValuesSkipper()) { log.trace("Attempting to apply skipper-based data rounding"); range[0] = dft.resolution() - .roundDownToMillis(DocValuesSkipper.globalMinValue(context.searcher(), fieldContext.field())); + .roundDownToMillis(DocValuesSkipper.globalMinValue(context.searcher().getIndexReader(), fieldContext.field())); range[1] = dft.resolution() - .roundDownToMillis(DocValuesSkipper.globalMaxValue(context.searcher(), fieldContext.field())); + .roundDownToMillis(DocValuesSkipper.globalMaxValue(context.searcher().getIndexReader(), fieldContext.field())); } log.trace("Bounds after index bound date rounding: {}, {}", range[0], range[1]); diff --git a/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java b/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java index b663f8c00f1ed..62bf8b7125bc5 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java +++ b/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java @@ -471,6 +471,38 @@ public VectorScorer scorer(byte[] bytes) throws IOException { return new VectorScorer() { private final DocIdSetIterator iterator = exitableIterator(scorerIterator, queryCancellation); + @Override + public VectorScorer.Bulk bulk(DocIdSetIterator matchingDocs) throws IOException { + return scorer.bulk(matchingDocs); + } + + @Override + public float score() throws IOException { + return scorer.score(); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + + @Override + public VectorScorer rescorer(byte[] target) throws IOException { + VectorScorer scorer = in.rescorer(target); + if (scorer == null) { + return null; + } + DocIdSetIterator scorerIterator = scorer.iterator(); + return new VectorScorer() { + private final DocIdSetIterator iterator = exitableIterator(scorerIterator, queryCancellation); + + @Override + public VectorScorer.Bulk bulk(DocIdSetIterator matchingDocs) 
throws IOException { + return scorer.bulk(matchingDocs); + } + @Override public float score() throws IOException { return scorer.score(); @@ -529,6 +561,38 @@ public VectorScorer scorer(float[] target) throws IOException { return new VectorScorer() { private final DocIdSetIterator iterator = exitableIterator(scorerIterator, queryCancellation); + @Override + public VectorScorer.Bulk bulk(DocIdSetIterator matchingDocs) throws IOException { + return scorer.bulk(matchingDocs); + } + + @Override + public float score() throws IOException { + return scorer.score(); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + + @Override + public VectorScorer rescorer(float[] target) throws IOException { + VectorScorer scorer = in.rescorer(target); + if (scorer == null) { + return null; + } + DocIdSetIterator scorerIterator = scorer.iterator(); + return new VectorScorer() { + private final DocIdSetIterator iterator = exitableIterator(scorerIterator, queryCancellation); + + @Override + public VectorScorer.Bulk bulk(DocIdSetIterator matchingDocs) throws IOException { + return scorer.bulk(matchingDocs); + } + @Override public float score() throws IOException { return scorer.score(); diff --git a/server/src/main/java/org/elasticsearch/search/sort/FieldSortBuilder.java b/server/src/main/java/org/elasticsearch/search/sort/FieldSortBuilder.java index 46c8620d1a6c8..7607772994e32 100644 --- a/server/src/main/java/org/elasticsearch/search/sort/FieldSortBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/sort/FieldSortBuilder.java @@ -15,7 +15,6 @@ import org.apache.lucene.index.MultiTerms; import org.apache.lucene.index.PointValues; import org.apache.lucene.index.Terms; -import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.SortField; import org.apache.lucene.util.NumericUtils; import org.elasticsearch.ElasticsearchParseException; @@ -619,9 +618,8 @@ private static MinAndMax extractNumericMinAndMaxFromSkipper( 
FieldSortBuilder sortBuilder, String fieldName ) throws IOException { - IndexSearcher searcher = new IndexSearcher(reader); - long min = DocValuesSkipper.globalMinValue(searcher, fieldName); - long max = DocValuesSkipper.globalMaxValue(searcher, fieldName); + long min = DocValuesSkipper.globalMinValue(reader, fieldName); + long max = DocValuesSkipper.globalMaxValue(reader, fieldName); if (min == Long.MIN_VALUE || max == Long.MAX_VALUE || min > max) { // Skipper not available for some segments, or no data return null; diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec index cdb6b947db323..6716dc733825c 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec @@ -3,4 +3,5 @@ org.elasticsearch.index.codec.Elasticsearch816Codec org.elasticsearch.index.codec.Elasticsearch900Codec org.elasticsearch.index.codec.Elasticsearch900Lucene101Codec org.elasticsearch.index.codec.Elasticsearch92Lucene103Codec +org.elasticsearch.index.codec.Elasticsearch93Lucene104Codec org.elasticsearch.index.codec.tsdb.ES93TSDBDefaultCompressionLucene103Codec diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index 0dc34ea2e808d..aaf0115b99ee5 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -15,3 +15,5 @@ org.elasticsearch.index.codec.vectors.es93.ES93ScalarQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es93.ES93HnswScalarQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat 
org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.es94.ES94ScalarQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.es94.ES94HnswScalarQuantizedVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzerTests.java b/server/src/test/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzerTests.java index 70fe8dae04374..b4f90f6a5953a 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzerTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzerTests.java @@ -13,7 +13,7 @@ import org.apache.lucene.codecs.DocValuesFormat; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.PostingsFormat; -import org.apache.lucene.codecs.lucene103.Lucene103Codec; +import org.apache.lucene.codecs.lucene104.Lucene104Codec; import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldDocValuesFormat; @@ -58,7 +58,7 @@ import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; -import org.apache.lucene.search.suggest.document.Completion101PostingsFormat; +import org.apache.lucene.search.suggest.document.Completion104PostingsFormat; import org.apache.lucene.search.suggest.document.SuggestField; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FilterDirectory; @@ -345,11 +345,11 @@ public void testTriangle() throws Exception { public void testCompletionField() throws Exception { IndexWriterConfig config = new IndexWriterConfig().setCommitOnClose(true) .setUseCompoundFile(false) - .setCodec(new Lucene103Codec(Lucene103Codec.Mode.BEST_SPEED) { + .setCodec(new 
Lucene104Codec(Lucene104Codec.Mode.BEST_SPEED) { @Override public PostingsFormat getPostingsFormatForField(String field) { if (field.startsWith("suggest_")) { - return new Completion101PostingsFormat(); + return new Completion104PostingsFormat(); } else { return super.postingsFormat(); } @@ -477,23 +477,23 @@ private static void addFieldsToDoc(Document doc, IndexableField[] fields) { enum CodecMode { BEST_SPEED { @Override - Lucene103Codec.Mode mode() { - return Lucene103Codec.Mode.BEST_SPEED; + Lucene104Codec.Mode mode() { + return Lucene104Codec.Mode.BEST_SPEED; } }, BEST_COMPRESSION { @Override - Lucene103Codec.Mode mode() { - return Lucene103Codec.Mode.BEST_COMPRESSION; + Lucene104Codec.Mode mode() { + return Lucene104Codec.Mode.BEST_COMPRESSION; } }; - abstract Lucene103Codec.Mode mode(); + abstract Lucene104Codec.Mode mode(); } static void indexRandomly(Directory directory, CodecMode codecMode, int numDocs, Consumer addFields) throws IOException { - indexRandomly(directory, new Lucene103Codec(codecMode.mode()), numDocs, addFields); + indexRandomly(directory, new Lucene104Codec(codecMode.mode()), numDocs, addFields); } static void indexRandomly(Directory directory, Codec codec, int numDocs, Consumer addFields) throws IOException { @@ -732,7 +732,7 @@ static void rewriteIndexWithPerFieldCodec(Directory source, CodecMode mode, Dire try (DirectoryReader reader = DirectoryReader.open(source)) { IndexWriterConfig config = new IndexWriterConfig().setSoftDeletesField(Lucene.SOFT_DELETES_FIELD) .setUseCompoundFile(randomBoolean()) - .setCodec(new Lucene103Codec(mode.mode()) { + .setCodec(new Lucene104Codec(mode.mode()) { @Override public PostingsFormat getPostingsFormatForField(String field) { return new ES812PostingsFormat(); @@ -970,7 +970,7 @@ public BytesRef binaryValue() { } } - static class CodecWithBloomFilter extends Lucene103Codec { + static class CodecWithBloomFilter extends Lucene104Codec { private final ES94BloomFilterDocValuesFormat 
bloomFilterDocValuesFormat; CodecWithBloomFilter(Mode mode, int bloomFilterSize) { diff --git a/server/src/test/java/org/elasticsearch/index/codec/CodecTests.java b/server/src/test/java/org/elasticsearch/index/codec/CodecTests.java index fdc6795c925ec..8d31b2e4bf2c2 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/CodecTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/CodecTests.java @@ -53,7 +53,7 @@ public void testResolveDefaultCodecs() throws Exception { assumeTrue("Only when zstd_stored_fields feature flag is enabled", CodecService.ZSTD_STORED_FIELDS_FEATURE_FLAG); CodecService codecService = createCodecService(); assertThat(codecService.codec("default"), instanceOf(PerFieldMapperCodec.class)); - assertThat(codecService.codec("default"), instanceOf(Elasticsearch92Lucene103Codec.class)); + assertThat(codecService.codec("default"), instanceOf(Elasticsearch93Lucene104Codec.class)); } public void testDefault() throws Exception { diff --git a/server/src/test/java/org/elasticsearch/index/codec/PerFieldMapperCodecTests.java b/server/src/test/java/org/elasticsearch/index/codec/PerFieldMapperCodecTests.java index bccf076f46970..ba87c928cf977 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/PerFieldMapperCodecTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/PerFieldMapperCodecTests.java @@ -10,7 +10,7 @@ package org.elasticsearch.index.codec; import org.apache.lucene.codecs.PostingsFormat; -import org.apache.lucene.codecs.lucene103.Lucene103PostingsFormat; +import org.apache.lucene.codecs.lucene104.Lucene104PostingsFormat; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.common.settings.Settings; @@ -104,7 +104,7 @@ public void testUseBloomFilter() throws IOException { ); assertThat(perFieldMapperCodec.useBloomFilter("another_field"), is(false)); - Class expectedPostingsFormat = timeSeries ? 
ES812PostingsFormat.class : Lucene103PostingsFormat.class; + Class expectedPostingsFormat = timeSeries ? ES812PostingsFormat.class : Lucene104PostingsFormat.class; assertThat(perFieldMapperCodec.getPostingsFormatForField("another_field"), instanceOf(expectedPostingsFormat)); } @@ -145,7 +145,7 @@ public void testUseEs812PostingsFormat() throws IOException { // standard index mode perFieldMapperCodec = createFormatSupplier(false, false, IndexMode.STANDARD, MAPPING_1); - assertThat(perFieldMapperCodec.getPostingsFormatForField("gauge"), instanceOf(Lucene103PostingsFormat.class)); + assertThat(perFieldMapperCodec.getPostingsFormatForField("gauge"), instanceOf(Lucene104PostingsFormat.class)); perFieldMapperCodec = createFormatSupplier(false, true, IndexMode.STANDARD, MAPPING_1); assertThat(perFieldMapperCodec.getPostingsFormatForField("gauge"), instanceOf(ES812PostingsFormat.class)); diff --git a/server/src/test/java/org/elasticsearch/index/codec/postings/ES812PostingsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/postings/ES812PostingsFormatTests.java index f59e075d6ec5a..bb649c82e7175 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/postings/ES812PostingsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/postings/ES812PostingsFormatTests.java @@ -26,7 +26,7 @@ import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.Impact; +import org.apache.lucene.index.FreqAndNormBuffer; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.store.ByteArrayDataInput; @@ -39,9 +39,6 @@ import org.apache.lucene.tests.util.TestUtil; import java.io.IOException; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; public class ES812PostingsFormatTests extends BasePostingsFormatTestCase { private final Codec codec = 
TestUtil.alwaysPostingsFormat(new ES812PostingsFormat()); @@ -78,47 +75,55 @@ public void testFinalBlock() throws Exception { public void testImpactSerialization() throws IOException { // omit norms and omit freqs - doTestImpactSerialization(Collections.singletonList(new Impact(1, 1L))); + FreqAndNormBuffer freqAndNormBuffer = new FreqAndNormBuffer(); + freqAndNormBuffer.add(1, 1L); + doTestImpactSerialization(freqAndNormBuffer); // omit freqs - doTestImpactSerialization(Collections.singletonList(new Impact(1, 42L))); + freqAndNormBuffer = new FreqAndNormBuffer(); + freqAndNormBuffer.add(1, 42L); + doTestImpactSerialization(freqAndNormBuffer); + // omit freqs with very large norms - doTestImpactSerialization(Collections.singletonList(new Impact(1, -100L))); + freqAndNormBuffer = new FreqAndNormBuffer(); + freqAndNormBuffer.add(1, -100); + doTestImpactSerialization(freqAndNormBuffer); // omit norms - doTestImpactSerialization(Collections.singletonList(new Impact(30, 1L))); + freqAndNormBuffer = new FreqAndNormBuffer(); + freqAndNormBuffer.add(30, 1L); + doTestImpactSerialization(freqAndNormBuffer); + // omit norms with large freq - doTestImpactSerialization(Collections.singletonList(new Impact(500, 1L))); + freqAndNormBuffer = new FreqAndNormBuffer(); + freqAndNormBuffer.add(500, 1L); + doTestImpactSerialization(freqAndNormBuffer); // freqs and norms, basic - doTestImpactSerialization( - Arrays.asList( - new Impact(1, 7L), - new Impact(3, 9L), - new Impact(7, 10L), - new Impact(15, 11L), - new Impact(20, 13L), - new Impact(28, 14L) - ) - ); + freqAndNormBuffer = new FreqAndNormBuffer(); + freqAndNormBuffer.add(1, 7L); + freqAndNormBuffer.add(3, 9L); + freqAndNormBuffer.add(7, 10L); + freqAndNormBuffer.add(15, 11L); + freqAndNormBuffer.add(20, 13L); + freqAndNormBuffer.add(28, 14L); + doTestImpactSerialization(freqAndNormBuffer); // freqs and norms, high values - doTestImpactSerialization( - Arrays.asList( - new Impact(2, 2L), - new Impact(10, 10L), - new Impact(12, 
50L), - new Impact(50, -100L), - new Impact(1000, -80L), - new Impact(1005, -3L) - ) - ); + freqAndNormBuffer = new FreqAndNormBuffer(); + freqAndNormBuffer.add(2, 2L); + freqAndNormBuffer.add(10, 10L); + freqAndNormBuffer.add(12, 50L); + freqAndNormBuffer.add(50, -100L); + freqAndNormBuffer.add(1000, -80L); + freqAndNormBuffer.add(1005, -3L); + doTestImpactSerialization(freqAndNormBuffer); } - private void doTestImpactSerialization(List impacts) throws IOException { + private void doTestImpactSerialization(FreqAndNormBuffer impacts) throws IOException { CompetitiveImpactAccumulator acc = new CompetitiveImpactAccumulator(); - for (Impact impact : impacts) { - acc.add(impact.freq, impact.norm); + for (int i = 0; i < impacts.size; i++) { + acc.add(impacts.freqs[i], impacts.norms[i]); } try (Directory dir = newDirectory()) { try (IndexOutput out = dir.createOutput("foo", IOContext.DEFAULT)) { @@ -127,11 +132,12 @@ private void doTestImpactSerialization(List impacts) throws IOException try (IndexInput in = dir.openInput("foo", IOContext.DEFAULT)) { byte[] b = new byte[Math.toIntExact(in.length())]; in.readBytes(b, 0, b.length); - List impacts2 = ES812ScoreSkipReader.readImpacts( - new ByteArrayDataInput(b), - new ES812ScoreSkipReader.MutableImpactList() - ); - assertEquals(impacts, impacts2); + FreqAndNormBuffer impacts2 = ES812ScoreSkipReader.readImpacts(new ByteArrayDataInput(b), new FreqAndNormBuffer()); + assertEquals(impacts.size, impacts2.size); + for (int i = 0; i < impacts.size; i++) { + assertEquals(impacts.freqs[i], impacts2.freqs[i]); + assertEquals(impacts.norms[i], impacts2.norms[i]); + } } } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/tsdb/DocValuesCodecDuelTests.java b/server/src/test/java/org/elasticsearch/index/codec/tsdb/DocValuesCodecDuelTests.java index 06c404bd7307a..7dfc78c64241a 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/tsdb/DocValuesCodecDuelTests.java +++ 
b/server/src/test/java/org/elasticsearch/index/codec/tsdb/DocValuesCodecDuelTests.java @@ -26,7 +26,7 @@ import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.util.BytesRef; -import org.elasticsearch.index.codec.Elasticsearch92Lucene103Codec; +import org.elasticsearch.index.codec.Elasticsearch93Lucene104Codec; import org.elasticsearch.index.codec.tsdb.ES87TSDBDocValuesFormatTests.TestES87TSDBDocValuesFormat; import org.elasticsearch.index.codec.tsdb.es819.ES819TSDBDocValuesFormatTests; import org.elasticsearch.index.codec.tsdb.es819.ES819Version3TSDBDocValuesFormat; @@ -56,7 +56,7 @@ public void testDuel() throws IOException { baselineConfig.setCodec(TestUtil.alwaysDocValuesFormat(new Lucene90DocValuesFormat())); var contenderConf = newIndexWriterConfig(); contenderConf.setMergePolicy(mergePolicy); - Codec codec = new Elasticsearch92Lucene103Codec() { + Codec codec = new Elasticsearch93Lucene104Codec() { final DocValuesFormat docValuesFormat = randomBoolean() ? 
new ES819Version3TSDBDocValuesFormat( diff --git a/server/src/test/java/org/elasticsearch/index/codec/tsdb/TSDBSyntheticIdPostingsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/tsdb/TSDBSyntheticIdPostingsFormatTests.java index bba3b4591e143..a0ebf980d465a 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/tsdb/TSDBSyntheticIdPostingsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/tsdb/TSDBSyntheticIdPostingsFormatTests.java @@ -9,7 +9,7 @@ package org.elasticsearch.index.codec.tsdb; -import org.apache.lucene.codecs.lucene103.Lucene103Codec; +import org.apache.lucene.codecs.lucene104.Lucene104Codec; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; @@ -483,7 +483,7 @@ private static void runTest(CheckedBiConsumer engine.config().getCodec()); Function docValuesFormatProvider; - if (codec instanceof Elasticsearch92Lucene103Codec es92103codec) { - docValuesFormatProvider = es92103codec::getDocValuesFormatForField; + if (codec instanceof Elasticsearch93Lucene104Codec es93104codec) { + docValuesFormatProvider = es93104codec::getDocValuesFormatForField; } else if (codec instanceof CodecService.DeduplicateFieldInfosCodec deduplicateFieldInfosCodec) { if (deduplicateFieldInfosCodec.delegate() instanceof LegacyPerFieldMapperCodec legacyPerFieldMapperCodec) { docValuesFormatProvider = legacyPerFieldMapperCodec::getDocValuesFormatForField; diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseBFloat16KnnVectorsFormatTestCase.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseBFloat16KnnVectorsFormatTestCase.java index dbe65eac2aedf..0c71c99986c8c 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseBFloat16KnnVectorsFormatTestCase.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseBFloat16KnnVectorsFormatTestCase.java @@ -64,6 +64,11 @@ 
public abstract class BaseBFloat16KnnVectorsFormatTestCase extends BaseKnnVectorsFormatTestCase { + @Override + protected boolean supportsFloatVectorFallback() { + return false; + } + @Override protected VectorEncoding randomVectorEncoding() { return VectorEncoding.FLOAT32; diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseHnswVectorsFormatTestCase.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseHnswVectorsFormatTestCase.java index 097cd70b7344e..2e27a8a78b460 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseHnswVectorsFormatTestCase.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseHnswVectorsFormatTestCase.java @@ -50,6 +50,11 @@ public abstract class BaseHnswVectorsFormatTestCase extends BaseKnnVectorsFormat LogConfigurator.configureESLogging(); // native access requires logging to be initialized } + @Override + protected boolean supportsFloatVectorFallback() { + return false; + } + protected abstract KnnVectorsFormat createFormat(); protected abstract KnnVectorsFormat createFormat(int maxConn, int beamWidth); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormatTests.java index 5c9404cc034e4..fb97c69c12e5b 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormatTests.java @@ -35,6 +35,11 @@ public class ES813FlatVectorFormatTests extends BaseKnnVectorsFormatTestCase { LogConfigurator.configureESLogging(); // native access requires logging to be initialized } + @Override + protected boolean supportsFloatVectorFallback() { + return false; + } + static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES813FlatRWVectorFormat()); @Override diff --git 
a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES813Int8FlatVectorFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES813Int8FlatVectorFormatTests.java index 2063dd0d7b2e7..f1b3f11364b76 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES813Int8FlatVectorFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES813Int8FlatVectorFormatTests.java @@ -35,6 +35,11 @@ public class ES813Int8FlatVectorFormatTests extends BaseKnnVectorsFormatTestCase LogConfigurator.configureESLogging(); // native access requires logging to be initialized } + @Override + protected boolean supportsFloatVectorFallback() { + return false; + } + static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES813Int8FlatRWVectorFormat()); @Override diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedRWVectorsFormat.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedRWVectorsFormat.java index 80a291082f2ea..5cc4221bbee1c 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedRWVectorsFormat.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedRWVectorsFormat.java @@ -18,6 +18,14 @@ class ES814HnswScalarQuantizedRWVectorsFormat extends ES814HnswScalarQuantizedVectorsFormat { @Override public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); + return new Lucene99HnswVectorsWriter( + state, + maxConn, + beamWidth, + flatVectorsFormat.fieldsWriter(state), + numMergeWorkers, + mergeExec, + 0 + ); } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormatTests.java 
b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormatTests.java index ecffeb2db1d2d..6046201cbb937 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormatTests.java @@ -49,6 +49,11 @@ public class ES814HnswScalarQuantizedVectorsFormatTests extends BaseKnnVectorsFo LogConfigurator.configureESLogging(); // native access requires logging to be initialized } + @Override + protected boolean supportsFloatVectorFallback() { + return false; + } + static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES814HnswScalarQuantizedRWVectorsFormat()); @Override diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES815HnswBitRWVectorsFormat.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES815HnswBitRWVectorsFormat.java index b0e6b6c7bfc48..7850c275beaf5 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES815HnswBitRWVectorsFormat.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES815HnswBitRWVectorsFormat.java @@ -22,6 +22,14 @@ class ES815HnswBitRWVectorsFormat extends ES815HnswBitVectorsFormat { @Override public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); + return new Lucene99HnswVectorsWriter( + state, + maxConn, + beamWidth, + flatVectorsFormat.fieldsWriter(state), + numMergeWorkers, + mergeExec, + 0 + ); } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java index f58cab557a85f..a69db9c38f79c 100644 --- 
a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java @@ -61,6 +61,11 @@ public class ES920DiskBBQVectorsFormatTests extends BaseKnnVectorsFormatTestCase LogConfigurator.configureESLogging(); // native access requires logging to be initialized } + @Override + protected boolean supportsFloatVectorFallback() { + return false; + } + private KnnVectorsFormat format; private ExecutorService executorService; diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormatTests.java index 5e92c10019625..c7b4dcb299ddf 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormatTests.java @@ -67,6 +67,12 @@ public class ESNextDiskBBQVectorsFormatTests extends BaseKnnVectorsFormatTestCas LogConfigurator.loadLog4jPlugins(); LogConfigurator.configureESLogging(); // native access requires logging to be initialized } + + @Override + protected boolean supportsFloatVectorFallback() { + return false; + } + KnnVectorsFormat format; @Before diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java index baf12f9e099c6..55925bcd6b40a 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java @@ -75,6 +75,11 @@ public class ES816BinaryQuantizedVectorsFormatTests 
extends BaseKnnVectorsFormat LogConfigurator.configureESLogging(); // native access requires logging to be initialized } + @Override + protected boolean supportsFloatVectorFallback() { + return false; + } + static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES816BinaryQuantizedRWVectorsFormat()); @Override diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedRWVectorsFormat.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedRWVectorsFormat.java index d9e4c60033485..eaba727156670 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedRWVectorsFormat.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedRWVectorsFormat.java @@ -43,6 +43,6 @@ class ES816HnswBinaryQuantizedRWVectorsFormat extends ES816HnswBinaryQuantizedVe @Override public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null); + return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null, 0); } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java index c498d6d85fc53..8de9682731a9b 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java @@ -81,6 +81,11 @@ public class ES818BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormat LogConfigurator.configureESLogging(); // native access requires logging to be initialized } + @Override + protected boolean 
supportsFloatVectorFallback() { + return false; + } + static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES818BinaryQuantizedRWVectorsFormat()); @Override diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedRWVectorsFormat.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedRWVectorsFormat.java index 02aa3a70880de..a391047401506 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedRWVectorsFormat.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedRWVectorsFormat.java @@ -32,6 +32,14 @@ public ES818HnswBinaryQuantizedRWVectorsFormat(int maxConn, int beamWidth, int n @Override public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); + return new Lucene99HnswVectorsWriter( + state, + maxConn, + beamWidth, + flatVectorsFormat.fieldsWriter(state), + numMergeWorkers, + mergeExec, + 0 + ); } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java index 10e3ea7fc1088..6c88c393b493e 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java @@ -81,6 +81,11 @@ public class ES93BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatT LogConfigurator.configureESLogging(); // native access requires logging to be initialized } + @Override + protected boolean supportsFloatVectorFallback() { + return false; + } + private KnnVectorsFormat format; @Override diff --git 
a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93FlatVectorFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93FlatVectorFormatTests.java index 0d636330c3dfd..403112bd58f8d 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93FlatVectorFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93FlatVectorFormatTests.java @@ -66,6 +66,11 @@ protected VectorEncoding randomVectorEncoding() { }; } + @Override + protected boolean supportsFloatVectorFallback() { + return false; + } + @Override protected Codec getCodec() { return TestUtil.alwaysKnnVectorsFormat(new ES93FlatVectorFormat(elementType)); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java index 0220ba831e75d..3476779d373e4 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java @@ -10,7 +10,9 @@ package org.elasticsearch.index.codec.vectors.es93; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.util.TestUtil; import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.codec.vectors.BaseHnswBFloat16VectorsFormatTestCase; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -45,7 +47,8 @@ protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMerge } public void testToString() { - String expected = "ES93HnswVectorsFormat(name=ES93HnswVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=%s)"; + String expected = + "ES93HnswVectorsFormat(name=ES93HnswVectorsFormat, 
maxConn=10, beamWidth=20, hnswGraphThreshold=150, flatVectorFormat=%s)"; expected = format(Locale.ROOT, expected, "ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat, format=%s)"); expected = format( Locale.ROOT, @@ -61,10 +64,13 @@ public void testToString() { public void testSimpleOffHeapSize() throws IOException { float[] vector = randomVector(random().nextInt(12, 500)); + // Use threshold=0 to ensure HNSW graph is always built + var format = new ES93HnswVectorsFormat(16, 100, DenseVectorFieldMapper.ElementType.BFLOAT16, 1, null, 0); + IndexWriterConfig config = newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(format)); try (Directory dir = newDirectory()) { testSimpleOffHeapSize( dir, - newIndexWriterConfig(), + config, vector, allOf(aMapWithSize(2), hasEntry("vec", (long) vector.length * BFloat16.BYTES), hasEntry("vex", 1L)) ); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java index 8d09eafb81ca2..cf976d7df04d1 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java @@ -23,6 +23,7 @@ import java.util.concurrent.ExecutorService; import static java.lang.String.format; +import static org.apache.lucene.tests.util.TestUtil.alwaysKnnVectorsFormat; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.equalTo; @@ -62,7 +63,7 @@ protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMerge public void testToString() { String expected = "ES93HnswBinaryQuantizedVectorsFormat(" - + "name=ES93HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20, 
flatVectorFormat=%s)"; + + "name=ES93HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20, hnswGraphThreshold=300, flatVectorFormat=%s)"; expected = format( Locale.ROOT, expected, @@ -96,6 +97,17 @@ public void testSimpleOffHeapSizeMMapDir() throws IOException { public void testSimpleOffHeapSizeImpl(Directory dir, IndexWriterConfig config, boolean expectVecOffHeap) throws IOException { float[] vector = randomVector(random().nextInt(12, 500)); + // Use threshold=0 to ensure HNSW graph is always built + var format = new ES93HnswBinaryQuantizedVectorsFormat( + 16, + 100, + DenseVectorFieldMapper.ElementType.BFLOAT16, + random().nextBoolean(), + 1, + null, + 0 + ); + config.setCodec(alwaysKnnVectorsFormat(format)); var matcher = expectVecOffHeap ? allOf( aMapWithSize(3), diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java index a34df72a11819..2bd68595b6e8c 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java @@ -28,23 +28,23 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; -import java.util.Locale; import java.util.concurrent.ExecutorService; -import static java.lang.String.format; +import static org.apache.lucene.tests.util.TestUtil.alwaysKnnVectorsFormat; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.hasToString; -import static org.hamcrest.Matchers.oneOf; +import static 
org.hamcrest.Matchers.is; public class ES93HnswBinaryQuantizedVectorsFormatTests extends BaseHnswVectorsFormatTestCase { @Override protected KnnVectorsFormat createFormat() { - return new ES93HnswBinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT, random().nextBoolean()); + return new ES93HnswBinaryQuantizedVectorsFormat(); } @Override @@ -69,26 +69,63 @@ protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMerge ); } - public void testToString() { - String expected = "ES93HnswBinaryQuantizedVectorsFormat(" - + "name=ES93HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=%s)"; - expected = format( - Locale.ROOT, - expected, - "ES93BinaryQuantizedVectorsFormat(name=ES93BinaryQuantizedVectorsFormat, rawVectorFormat=%s," - + " scorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate={}))" + protected KnnVectorsFormat createFormat( + int maxConn, + int beamWidth, + int numMergeWorkers, + ExecutorService service, + int hnswGraphThreshold + ) { + return new ES93HnswBinaryQuantizedVectorsFormat( + maxConn, + beamWidth, + DenseVectorFieldMapper.ElementType.FLOAT, + random().nextBoolean(), + numMergeWorkers, + service, + hnswGraphThreshold ); - expected = format(Locale.ROOT, expected, "ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat, format=%s)"); - expected = format( - Locale.ROOT, - expected, - "Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=ES93FlatVectorScorer(delegate={}))" + } + + public void testToString() { + int hnswGraphThreshold = random().nextInt(1, 1001); + KnnVectorsFormat format = createFormat(10, 20, 1, null, hnswGraphThreshold); + assertThat(format, hasToString(containsString("name=ES93HnswBinaryQuantizedVectorsFormat"))); + assertThat(format, hasToString(containsString("maxConn=10"))); + assertThat(format, hasToString(containsString("beamWidth=20"))); + assertThat(format, hasToString(containsString("hnswGraphThreshold=" + hnswGraphThreshold))); + 
assertThat(format, hasToString(containsString("ES93BinaryQuantizedVectorsFormat"))); + assertThat(format, hasToString(containsString("ES93GenericFlatVectorsFormat"))); + assertThat(format, hasToString(containsString("Lucene99FlatVectorsFormat"))); + assertThat(format, hasToString(containsString("ES93FlatVectorScorer"))); + } + + public void testDefaultHnswGraphThreshold() { + KnnVectorsFormat format = createFormat(16, 100); + assertThat( + format, + hasToString(containsString("hnswGraphThreshold=" + ES93HnswBinaryQuantizedVectorsFormat.BBQ_HNSW_GRAPH_THRESHOLD)) ); - String defaultScorer = expected.replaceAll("\\{}", "DefaultFlatVectorScorer()"); - String memSegScorer = expected.replaceAll("\\{}", "Lucene99MemorySegmentFlatVectorsScorer()"); + } + + public void testHnswGraphThresholdWithCustomValue() { + int customThreshold = random().nextInt(1, 1001); + KnnVectorsFormat format = createFormat(16, 100, 1, null, customThreshold); + assertThat(format, hasToString(containsString("hnswGraphThreshold=" + customThreshold))); + } + + public void testHnswGraphThresholdWithZeroValue() { + // When threshold is 0, hnswGraphThreshold is omitted from toString (always build graph) + KnnVectorsFormat format = createFormat(16, 100, 1, null, 0); + assertThat(format.toString().contains("hnswGraphThreshold"), is(false)); + } - KnnVectorsFormat format = createFormat(10, 20, 1, null); - assertThat(format, hasToString(oneOf(defaultScorer, memSegScorer))); + public void testHnswGraphThresholdWithNegativeValueFallsBackToDefault() { + KnnVectorsFormat format = createFormat(16, 100, 1, null, -1); + assertThat( + format, + hasToString(containsString("hnswGraphThreshold=" + ES93HnswBinaryQuantizedVectorsFormat.BBQ_HNSW_GRAPH_THRESHOLD)) + ); } public void testSimpleOffHeapSize() throws IOException { @@ -105,6 +142,17 @@ public void testSimpleOffHeapSizeMMapDir() throws IOException { public void testSimpleOffHeapSizeImpl(Directory dir, IndexWriterConfig config, boolean expectVecOffHeap) throws 
IOException { float[] vector = randomVector(random().nextInt(12, 500)); + // Use threshold=0 to ensure HNSW graph is always built + var format = new ES93HnswBinaryQuantizedVectorsFormat( + 16, + 100, + DenseVectorFieldMapper.ElementType.FLOAT, + random().nextBoolean(), + 1, + null, + 0 + ); + config.setCodec(alwaysKnnVectorsFormat(format)); var matcher = expectVecOffHeap ? allOf( aMapWithSize(3), diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBitVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBitVectorsFormatTests.java index b54db35d77273..69bbb2786a5b9 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBitVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBitVectorsFormatTests.java @@ -29,9 +29,11 @@ import java.io.IOException; import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.hasToString; public class ES93HnswBitVectorsFormatTests extends BaseKnnBitVectorsFormatTestCase { @@ -45,9 +47,20 @@ public void init() { similarityFunction = VectorSimilarityFunction.EUCLIDEAN; } + public void testToString() { + var format = new ES93HnswVectorsFormat(10, 20, DenseVectorFieldMapper.ElementType.BIT); + assertThat(format, hasToString(containsString("name=ES93HnswVectorsFormat"))); + assertThat(format, hasToString(containsString("maxConn=10"))); + assertThat(format, hasToString(containsString("beamWidth=20"))); + assertThat(format, hasToString(containsString("hnswGraphThreshold=" + ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD))); + } + public void testSimpleOffHeapSize() throws IOException { byte[] vector = randomVector(random().nextInt(12, 500)); - try (Directory dir = 
newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + // Use threshold=0 to ensure HNSW graph is always built + var format = new ES93HnswVectorsFormat(16, 100, DenseVectorFieldMapper.ElementType.BIT, 1, null, 0); + var config = newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(format)); + try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, config)) { Document doc = new Document(); doc.add(new KnnByteVectorField("f", vector, VectorSimilarityFunction.EUCLIDEAN)); w.addDocument(doc); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswScalarQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswScalarQuantizedBFloat16VectorsFormatTests.java index 3c8bb8cc25d97..66670b3e1ed57 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswScalarQuantizedBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswScalarQuantizedBFloat16VectorsFormatTests.java @@ -10,7 +10,9 @@ package org.elasticsearch.index.codec.vectors.es93; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.util.TestUtil; import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.codec.vectors.BaseHnswBFloat16VectorsFormatTestCase; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -19,21 +21,21 @@ import java.io.IOException; import java.util.concurrent.ExecutorService; -import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; -import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; import 
static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.hasToString; public class ES93HnswScalarQuantizedBFloat16VectorsFormatTests extends BaseHnswBFloat16VectorsFormatTestCase { @Override protected KnnVectorsFormat createFormat() { return new ES93HnswScalarQuantizedVectorsFormat( - DEFAULT_MAX_CONN, - DEFAULT_BEAM_WIDTH, + 16, + 100, DenseVectorFieldMapper.ElementType.BFLOAT16, null, 7, @@ -75,12 +77,34 @@ public void testSingleVectorCase() throws Exception { throw new AssumptionViolatedException("Scalar quantization changes the score significantly for MAXIMUM_INNER_PRODUCT"); } + public void testToString() { + KnnVectorsFormat format = createFormat(10, 20, 1, null); + assertThat(format, hasToString(containsString("name=ES93HnswScalarQuantizedVectorsFormat"))); + assertThat(format, hasToString(containsString("maxConn=10"))); + assertThat(format, hasToString(containsString("beamWidth=20"))); + assertThat(format, hasToString(containsString("hnswGraphThreshold=" + ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD))); + } + public void testSimpleOffHeapSize() throws IOException { float[] vector = randomVector(random().nextInt(12, 500)); + // Use threshold=0 to ensure HNSW graph is always built + var format = new ES93HnswScalarQuantizedVectorsFormat( + 16, + 100, + DenseVectorFieldMapper.ElementType.BFLOAT16, + null, + 7, + false, + random().nextBoolean(), + 1, + null, + 0 + ); + IndexWriterConfig config = newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(format)); try (Directory dir = newDirectory()) { testSimpleOffHeapSize( dir, - newIndexWriterConfig(), + config, vector, allOf( aMapWithSize(3), diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswScalarQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswScalarQuantizedVectorsFormatTests.java index 
d70eb94d3227e..cd4c8abb6005e 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswScalarQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswScalarQuantizedVectorsFormatTests.java @@ -10,7 +10,9 @@ package org.elasticsearch.index.codec.vectors.es93; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.util.TestUtil; import org.elasticsearch.index.codec.vectors.BaseHnswVectorsFormatTestCase; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.junit.AssumptionViolatedException; @@ -18,21 +20,27 @@ import java.io.IOException; import java.util.concurrent.ExecutorService; -import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; -import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.hasToString; +import static org.hamcrest.Matchers.is; public class ES93HnswScalarQuantizedVectorsFormatTests extends BaseHnswVectorsFormatTestCase { @Override protected KnnVectorsFormat createFormat() { + return new ES93HnswScalarQuantizedVectorsFormat(); + } + + @Override + protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { return new ES93HnswScalarQuantizedVectorsFormat( - DEFAULT_MAX_CONN, - DEFAULT_BEAM_WIDTH, + maxConn, + beamWidth, DenseVectorFieldMapper.ElementType.FLOAT, null, 7, @@ -42,7 +50,7 @@ protected KnnVectorsFormat createFormat() { } @Override - protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { + protected 
KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service) { return new ES93HnswScalarQuantizedVectorsFormat( maxConn, beamWidth, @@ -50,12 +58,19 @@ protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { null, 7, false, - random().nextBoolean() + random().nextBoolean(), + numMergeWorkers, + service ); } - @Override - protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service) { + protected KnnVectorsFormat createFormat( + int maxConn, + int beamWidth, + int numMergeWorkers, + ExecutorService service, + int hnswGraphThreshold + ) { return new ES93HnswScalarQuantizedVectorsFormat( maxConn, beamWidth, @@ -65,10 +80,42 @@ protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMerge false, random().nextBoolean(), numMergeWorkers, - service + service, + hnswGraphThreshold ); } + public void testDefaultHnswGraphThreshold() { + KnnVectorsFormat format = createFormat(16, 100); + assertThat(format, hasToString(containsString("hnswGraphThreshold=" + ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD))); + } + + public void testHnswGraphThresholdWithCustomValue() { + int customThreshold = random().nextInt(1, 1001); + KnnVectorsFormat format = createFormat(16, 100, 1, null, customThreshold); + assertThat(format, hasToString(containsString("hnswGraphThreshold=" + customThreshold))); + } + + public void testHnswGraphThresholdWithZeroValue() { + // When threshold is 0, hnswGraphThreshold is omitted from toString (always build graph) + KnnVectorsFormat format = createFormat(16, 100, 1, null, 0); + assertThat(format.toString().contains("hnswGraphThreshold"), is(false)); + } + + public void testHnswGraphThresholdWithNegativeValueFallsBackToDefault() { + KnnVectorsFormat format = createFormat(16, 100, 1, null, -1); + assertThat(format, hasToString(containsString("hnswGraphThreshold=" + ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD))); + } + + public void 
testToString() { + int hnswGraphThreshold = random().nextInt(1, 1001); + KnnVectorsFormat format = createFormat(10, 20, 1, null, hnswGraphThreshold); + assertThat(format, hasToString(containsString("name=ES93HnswScalarQuantizedVectorsFormat"))); + assertThat(format, hasToString(containsString("maxConn=10"))); + assertThat(format, hasToString(containsString("beamWidth=20"))); + assertThat(format, hasToString(containsString("hnswGraphThreshold=" + hnswGraphThreshold))); + } + @Override public void testSingleVectorCase() throws Exception { throw new AssumptionViolatedException("Scalar quantization changes the score significantly for MAXIMUM_INNER_PRODUCT"); @@ -76,10 +123,24 @@ public void testSingleVectorCase() throws Exception { public void testSimpleOffHeapSize() throws IOException { float[] vector = randomVector(random().nextInt(12, 500)); + // Use threshold=0 to ensure HNSW graph is always built + var format = new ES93HnswScalarQuantizedVectorsFormat( + 16, + 100, + DenseVectorFieldMapper.ElementType.FLOAT, + null, + 7, + false, + random().nextBoolean(), + 1, + null, + 0 + ); + IndexWriterConfig config = newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(format)); try (Directory dir = newDirectory()) { testSimpleOffHeapSize( dir, - newIndexWriterConfig(), + config, vector, allOf( aMapWithSize(3), diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java index 2e3d32e016c6d..03e97e5a7fa03 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java @@ -10,7 +10,9 @@ package org.elasticsearch.index.codec.vectors.es93; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.store.Directory; 
+import org.apache.lucene.tests.util.TestUtil; import org.elasticsearch.index.codec.vectors.BaseHnswVectorsFormatTestCase; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -21,6 +23,7 @@ import static java.lang.String.format; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.hasToString; import static org.hamcrest.Matchers.is; @@ -43,24 +46,69 @@ protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMerge return new ES93HnswVectorsFormat(maxConn, beamWidth, DenseVectorFieldMapper.ElementType.FLOAT, numMergeWorkers, service); } + protected KnnVectorsFormat createFormat( + int maxConn, + int beamWidth, + int numMergeWorkers, + ExecutorService service, + int hnswGraphThreshold + ) { + return new ES93HnswVectorsFormat( + maxConn, + beamWidth, + DenseVectorFieldMapper.ElementType.FLOAT, + numMergeWorkers, + service, + hnswGraphThreshold + ); + } + + public void testDefaultHnswGraphThreshold() { + KnnVectorsFormat format = createFormat(16, 100); + assertThat(format, hasToString(containsString("hnswGraphThreshold=" + ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD))); + } + + public void testHnswGraphThresholdWithCustomValue() { + int customThreshold = random().nextInt(1, 1001); + KnnVectorsFormat format = createFormat(16, 100, 1, null, customThreshold); + assertThat(format, hasToString(containsString("hnswGraphThreshold=" + customThreshold))); + } + + public void testHnswGraphThresholdWithZeroValue() { + // When threshold is 0, hnswGraphThreshold is omitted from toString (always build graph) + KnnVectorsFormat format = createFormat(16, 100, 1, null, 0); + assertThat(format.toString().contains("hnswGraphThreshold"), is(false)); + } + + public void testHnswGraphThresholdWithNegativeValueFallsBackToDefault() { + KnnVectorsFormat format = createFormat(16, 100, 
1, null, -1); + assertThat(format, hasToString(containsString("hnswGraphThreshold=" + ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD))); + } + public void testToString() { - String expected = "ES93HnswVectorsFormat(name=ES93HnswVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=%s)"; + int hnswGraphThreshold = random().nextInt(1, 1001); + String expected = "ES93HnswVectorsFormat(name=ES93HnswVectorsFormat, maxConn=10, beamWidth=20, hnswGraphThreshold=" + + hnswGraphThreshold + + ", flatVectorFormat=%s)"; expected = format(Locale.ROOT, expected, "ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat, format=%s)"); expected = format(Locale.ROOT, expected, "Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=%s)"); expected = format(Locale.ROOT, expected, "ES93FlatVectorScorer(delegate=%s)"); String defaultScorer = format(Locale.ROOT, expected, "DefaultFlatVectorScorer()"); String memSegScorer = format(Locale.ROOT, expected, "Lucene99MemorySegmentFlatVectorsScorer()"); - KnnVectorsFormat format = createFormat(10, 20, 1, null); + KnnVectorsFormat format = createFormat(10, 20, 1, null, hnswGraphThreshold); assertThat(format, hasToString(is(oneOf(defaultScorer, memSegScorer)))); } public void testSimpleOffHeapSize() throws IOException { float[] vector = randomVector(random().nextInt(12, 500)); + // Use threshold=0 to ensure HNSW graph is always built + var format = new ES93HnswVectorsFormat(16, 100, DenseVectorFieldMapper.ElementType.FLOAT, 1, null, 0); + IndexWriterConfig config = newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(format)); try (Directory dir = newDirectory()) { testSimpleOffHeapSize( dir, - newIndexWriterConfig(), + config, vector, allOf(aMapWithSize(2), hasEntry("vec", (long) vector.length * Float.BYTES), hasEntry("vex", 1L)) ); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93ScalarQuantizedVectorsFormatTests.java 
b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93ScalarQuantizedVectorsFormatTests.java index 5b6b381945c41..87ea73b01b316 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93ScalarQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93ScalarQuantizedVectorsFormatTests.java @@ -41,6 +41,11 @@ public class ES93ScalarQuantizedVectorsFormatTests extends BaseKnnVectorsFormatT LogConfigurator.configureESLogging(); // native access requires logging to be initialized } + @Override + protected boolean supportsFloatVectorFallback() { + return false; + } + @Override protected Codec getCodec() { return TestUtil.alwaysKnnVectorsFormat(new ES93ScalarQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT)); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es94/ES94HnswScalarQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es94/ES94HnswScalarQuantizedBFloat16VectorsFormatTests.java new file mode 100644 index 0000000000000..99d5eac58f12c --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es94/ES94HnswScalarQuantizedBFloat16VectorsFormatTests.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors.es94; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.util.TestUtil; +import org.elasticsearch.index.codec.vectors.BFloat16; +import org.elasticsearch.index.codec.vectors.BaseHnswBFloat16VectorsFormatTestCase; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.junit.AssumptionViolatedException; +import org.junit.Before; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.elasticsearch.test.ESTestCase.randomFrom; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasEntry; + +public class ES94HnswScalarQuantizedBFloat16VectorsFormatTests extends BaseHnswBFloat16VectorsFormatTestCase { + + private int bits; + + @Before + @Override + public void setUp() throws Exception { + bits = randomFrom(1, 2, 4, 7); + super.setUp(); + } + + @Override + protected KnnVectorsFormat createFormat() { + return new ES94HnswScalarQuantizedVectorsFormat( + DEFAULT_MAX_CONN, + DEFAULT_BEAM_WIDTH, + DenseVectorFieldMapper.ElementType.BFLOAT16, + bits, + false + ); + } + + @Override + protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { + return new ES94HnswScalarQuantizedVectorsFormat(maxConn, beamWidth, DenseVectorFieldMapper.ElementType.BFLOAT16, bits, false); + } + + @Override + protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service) { + return new ES94HnswScalarQuantizedVectorsFormat( + 
maxConn, + beamWidth, + DenseVectorFieldMapper.ElementType.BFLOAT16, + bits, + false, + numMergeWorkers, + service + ); + } + + @Override + public void testSingleVectorCase() throws Exception { + throw new AssumptionViolatedException("Scalar quantization changes the score significantly for MAXIMUM_INNER_PRODUCT"); + } + + public void testSimpleOffHeapSize() throws IOException { + float[] vector = randomVector(random().nextInt(12, 500)); + // Use threshold=0 to ensure HNSW graph is always built, but keep assertion tolerant to implementation details. + KnnVectorsFormat format = new ES94HnswScalarQuantizedVectorsFormat( + 16, + 100, + DenseVectorFieldMapper.ElementType.BFLOAT16, + bits, + false, + 1, + null, + 0 + ); + var config = newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(format)); + try (Directory dir = newDirectory()) { + testSimpleOffHeapSize( + dir, + config, + vector, + allOf( + aMapWithSize(3), + hasEntry("vec", (long) vector.length * BFloat16.BYTES), + hasEntry(equalTo("vex"), greaterThanOrEqualTo(0L)), + hasEntry(equalTo("veq"), greaterThan(0L)) + ) + ); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es94/ES94HnswScalarQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es94/ES94HnswScalarQuantizedVectorsFormatTests.java new file mode 100644 index 0000000000000..4ecfe7e6febf0 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es94/ES94HnswScalarQuantizedVectorsFormatTests.java @@ -0,0 +1,163 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es94; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.util.TestUtil; +import org.elasticsearch.index.codec.vectors.BaseHnswVectorsFormatTestCase; +import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.junit.AssumptionViolatedException; +import org.junit.Before; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.elasticsearch.test.ESTestCase.randomFrom; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.hasToString; +import static org.hamcrest.Matchers.is; + +public class ES94HnswScalarQuantizedVectorsFormatTests extends BaseHnswVectorsFormatTestCase { + + private int bits; + + @Before + @Override + public void setUp() throws Exception { + bits = randomFrom(1, 2, 4, 7); + super.setUp(); + } + + @Override + protected KnnVectorsFormat createFormat() { + return new ES94HnswScalarQuantizedVectorsFormat( + 
DEFAULT_MAX_CONN, + DEFAULT_BEAM_WIDTH, + DenseVectorFieldMapper.ElementType.FLOAT, + bits, + false + ); + } + + @Override + protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { + return new ES94HnswScalarQuantizedVectorsFormat(maxConn, beamWidth, DenseVectorFieldMapper.ElementType.FLOAT, bits, false); + } + + @Override + protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service) { + return new ES94HnswScalarQuantizedVectorsFormat( + maxConn, + beamWidth, + DenseVectorFieldMapper.ElementType.FLOAT, + bits, + false, + numMergeWorkers, + service + ); + } + + @Override + public void testSingleVectorCase() throws Exception { + throw new AssumptionViolatedException("Scalar quantization changes the score significantly for MAXIMUM_INNER_PRODUCT"); + } + + public void testSimpleOffHeapSize() throws IOException { + float[] vector = randomVector(random().nextInt(12, 500)); + // Use threshold=0 to ensure HNSW graph is always built, but keep assertion tolerant to implementation details. 
+ KnnVectorsFormat format = createFormat(16, 100, 1, null, 0); + var config = newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(format)); + try (Directory dir = newDirectory()) { + testSimpleOffHeapSize( + dir, + config, + vector, + allOf( + aMapWithSize(3), + hasEntry("vec", (long) vector.length * Float.BYTES), + hasEntry(equalTo("vex"), greaterThanOrEqualTo(0L)), + hasEntry(equalTo("veq"), greaterThan(0L)) + ) + ); + } + } + + public void testToString() { + KnnVectorsFormat format = new ES94HnswScalarQuantizedVectorsFormat(10, 20, DenseVectorFieldMapper.ElementType.FLOAT, 2, false); + assertThat( + format, + hasToString( + is( + "ES94HnswScalarQuantizedVectorsFormat(name=ES94HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, " + + "hnswGraphThreshold=" + + ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD + + ", flatVectorFormat=ES94ScalarQuantizedVectorsFormat(" + + "name=ES94ScalarQuantizedVectorsFormat, encoding=DIBIT_QUERY_NIBBLE, " + + "flatVectorScorer=" + + ES94ScalarQuantizedVectorsFormat.flatVectorScorer + + ", rawVectorFormat=" + + new ES93GenericFlatVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT, false) + + "))" + ) + ) + ); + } + + protected KnnVectorsFormat createFormat( + int maxConn, + int beamWidth, + int numMergeWorkers, + ExecutorService service, + int hnswGraphThreshold + ) { + return new ES94HnswScalarQuantizedVectorsFormat( + maxConn, + beamWidth, + DenseVectorFieldMapper.ElementType.FLOAT, + bits, + false, + numMergeWorkers, + service, + hnswGraphThreshold + ); + } + + public void testDefaultHnswGraphThreshold() { + KnnVectorsFormat format = createFormat(16, 100); + assertThat(format.toString().contains("hnswGraphThreshold=" + ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD), is(true)); + } + + public void testHnswGraphThresholdWithCustomValue() { + int customThreshold = random().nextInt(1, 1001); + KnnVectorsFormat format = createFormat(16, 100, 1, null, customThreshold); + 
assertThat(format.toString().contains("hnswGraphThreshold=" + customThreshold), is(true)); + } + + public void testHnswGraphThresholdWithZeroValue() { + // When threshold is 0, hnswGraphThreshold is omitted from toString (always build graph) + KnnVectorsFormat format = createFormat(16, 100, 1, null, 0); + assertThat(format.toString().contains("hnswGraphThreshold"), is(false)); + } + + public void testHnswGraphThresholdWithNegativeValueFallsBackToDefault() { + KnnVectorsFormat format = createFormat(16, 100, 1, null, -1); + assertThat(format.toString().contains("hnswGraphThreshold=" + ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD), is(true)); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es94/ES94ScalarQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es94/ES94ScalarQuantizedBFloat16VectorsFormatTests.java new file mode 100644 index 0000000000000..e65e17ffd48f2 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es94/ES94ScalarQuantizedBFloat16VectorsFormatTests.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors.es94; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.util.TestUtil; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.BFloat16; +import org.elasticsearch.index.codec.vectors.BaseBFloat16KnnVectorsFormatTestCase; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.junit.AssumptionViolatedException; +import org.junit.Before; + +import java.io.IOException; + +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.elasticsearch.test.ESTestCase.randomFrom; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasEntry; + +public class ES94ScalarQuantizedBFloat16VectorsFormatTests extends BaseBFloat16KnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + private KnnVectorsFormat format; + + @Before + @Override + public void setUp() throws Exception { + int bits = randomFrom(1, 2, 4, 7); + format = new ES94ScalarQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.BFLOAT16, bits, false); + super.setUp(); + } + + @Override + protected Codec getCodec() { + return 
TestUtil.alwaysKnnVectorsFormat(format); + } + + public void testSearchWithVisitedLimit() { + throw new AssumptionViolatedException("requires graph vector codec"); + } + + public void testSimpleOffHeapSize() throws IOException { + float[] vector = randomVector(random().nextInt(12, 500)); + try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + knnVectorsReader = fieldsReader.getFieldReader("f"); + } + var fieldInfo = r.getFieldInfos().fieldInfo("f"); + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + assertThat(offHeap, aMapWithSize(2)); + assertThat(offHeap, hasEntry("vec", (long) vector.length * BFloat16.BYTES)); + assertThat(offHeap, hasEntry(equalTo("veq"), greaterThan(0L))); + } + } + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es94/ES94ScalarQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es94/ES94ScalarQuantizedVectorsFormatTests.java new file mode 100644 index 0000000000000..6ac9a0e305b84 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es94/ES94ScalarQuantizedVectorsFormatTests.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es94; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.junit.AssumptionViolatedException; +import org.junit.Before; + +import java.io.IOException; + +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.elasticsearch.test.ESTestCase.randomFrom; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.is; + +public class ES94ScalarQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires 
logging to be initialized + } + + @Override + protected boolean supportsFloatVectorFallback() { + return false; + } + + private KnnVectorsFormat format; + + @Before + @Override + public void setUp() throws Exception { + int bits = randomFrom(1, 2, 4, 7); + format = new ES94ScalarQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT, bits, false); + super.setUp(); + } + + @Override + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(format); + } + + public void testSearchWithVisitedLimit() { + throw new AssumptionViolatedException("requires graph vector codec"); + } + + public void testToString() { + var format = new ES94ScalarQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT, 4, false); + String expected = "ES94ScalarQuantizedVectorsFormat(name=ES94ScalarQuantizedVectorsFormat, encoding=PACKED_NIBBLE, " + + "flatVectorScorer=" + + ES94ScalarQuantizedVectorsFormat.flatVectorScorer + + ", rawVectorFormat=" + + new ES93GenericFlatVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT, false) + + ")"; + assertThat(format.toString(), is(expected)); + } + + public void testSimpleOffHeapSize() throws IOException { + float[] vector = randomVector(random().nextInt(12, 500)); + try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + knnVectorsReader = fieldsReader.getFieldReader("f"); + } + var fieldInfo = r.getFieldInfos().fieldInfo("f"); + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + assertThat(offHeap, aMapWithSize(2)); + 
assertThat(offHeap, hasEntry("vec", (long) vector.length * Float.BYTES)); + assertThat(offHeap, hasEntry(equalTo("veq"), greaterThan(0L))); + } + } + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/zstd/StoredFieldCodecDuelTests.java b/server/src/test/java/org/elasticsearch/index/codec/zstd/StoredFieldCodecDuelTests.java index ffcfe22c1b6ab..dad786a95efb8 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/zstd/StoredFieldCodecDuelTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/zstd/StoredFieldCodecDuelTests.java @@ -10,7 +10,7 @@ package org.elasticsearch.index.codec.zstd; import org.apache.lucene.codecs.Codec; -import org.apache.lucene.codecs.lucene103.Lucene103Codec; +import org.apache.lucene.codecs.lucene104.Lucene104Codec; import org.apache.lucene.document.Document; import org.apache.lucene.document.StoredField; import org.apache.lucene.index.DirectoryReader; @@ -35,13 +35,13 @@ public class StoredFieldCodecDuelTests extends ESTestCase { private static final String DOUBLE_FIELD = "double_field_5"; public void testDuelBestSpeed() throws IOException { - var baseline = new LegacyPerFieldMapperCodec(Lucene103Codec.Mode.BEST_SPEED, null, BigArrays.NON_RECYCLING_INSTANCE, null); + var baseline = new LegacyPerFieldMapperCodec(Lucene104Codec.Mode.BEST_SPEED, null, BigArrays.NON_RECYCLING_INSTANCE, null); var contender = new PerFieldMapperCodec(Zstd814StoredFieldsFormat.Mode.BEST_SPEED, null, BigArrays.NON_RECYCLING_INSTANCE, null); doTestDuel(baseline, contender); } public void testDuelBestCompression() throws IOException { - var baseline = new LegacyPerFieldMapperCodec(Lucene103Codec.Mode.BEST_COMPRESSION, null, BigArrays.NON_RECYCLING_INSTANCE, null); + var baseline = new LegacyPerFieldMapperCodec(Lucene104Codec.Mode.BEST_COMPRESSION, null, BigArrays.NON_RECYCLING_INSTANCE, null); var contender = new PerFieldMapperCodec( Zstd814StoredFieldsFormat.Mode.BEST_COMPRESSION, null, diff --git 
a/server/src/test/java/org/elasticsearch/index/codec/zstd/Zstd814BestCompressionStoredFieldsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/zstd/Zstd814BestCompressionStoredFieldsFormatTests.java index f89fa52256e15..a9404e6f94254 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/zstd/Zstd814BestCompressionStoredFieldsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/zstd/Zstd814BestCompressionStoredFieldsFormatTests.java @@ -11,11 +11,11 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.tests.index.BaseStoredFieldsFormatTestCase; -import org.elasticsearch.index.codec.Elasticsearch92Lucene103Codec; +import org.elasticsearch.index.codec.Elasticsearch93Lucene104Codec; public class Zstd814BestCompressionStoredFieldsFormatTests extends BaseStoredFieldsFormatTestCase { - private final Codec codec = new Elasticsearch92Lucene103Codec(Zstd814StoredFieldsFormat.Mode.BEST_COMPRESSION); + private final Codec codec = new Elasticsearch93Lucene104Codec(Zstd814StoredFieldsFormat.Mode.BEST_COMPRESSION); @Override protected Codec getCodec() { diff --git a/server/src/test/java/org/elasticsearch/index/codec/zstd/Zstd814BestSpeedStoredFieldsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/zstd/Zstd814BestSpeedStoredFieldsFormatTests.java index f3d120ed185e7..4bca8e2430b6a 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/zstd/Zstd814BestSpeedStoredFieldsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/zstd/Zstd814BestSpeedStoredFieldsFormatTests.java @@ -11,11 +11,11 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.tests.index.BaseStoredFieldsFormatTestCase; -import org.elasticsearch.index.codec.Elasticsearch92Lucene103Codec; +import org.elasticsearch.index.codec.Elasticsearch93Lucene104Codec; public class Zstd814BestSpeedStoredFieldsFormatTests extends BaseStoredFieldsFormatTestCase { - private final Codec codec = new 
Elasticsearch92Lucene103Codec(Zstd814StoredFieldsFormat.Mode.BEST_SPEED); + private final Codec codec = new Elasticsearch93Lucene104Codec(Zstd814StoredFieldsFormat.Mode.BEST_SPEED); @Override protected Codec getCodec() { diff --git a/server/src/test/java/org/elasticsearch/index/engine/RecoverySourcePruneMergePolicyTests.java b/server/src/test/java/org/elasticsearch/index/engine/RecoverySourcePruneMergePolicyTests.java index d70a2002d8acf..e948e81fdd5d4 100644 --- a/server/src/test/java/org/elasticsearch/index/engine/RecoverySourcePruneMergePolicyTests.java +++ b/server/src/test/java/org/elasticsearch/index/engine/RecoverySourcePruneMergePolicyTests.java @@ -9,7 +9,7 @@ package org.elasticsearch.index.engine; -import org.apache.lucene.codecs.lucene103.Lucene103Codec; +import org.apache.lucene.codecs.lucene104.Lucene104Codec; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.NumericDocValuesField; @@ -756,7 +756,7 @@ private static IndexWriterConfig createIndexWriterConfig(final IndexSettings ind if (IndexSettings.SYNTHETIC_ID.get(indexSettings.getSettings())) { iwc.setCodec( new ES93TSDBDefaultCompressionLucene103Codec( - new LegacyPerFieldMapperCodec(Lucene103Codec.Mode.BEST_SPEED, mapperService, BigArrays.NON_RECYCLING_INSTANCE, null) + new LegacyPerFieldMapperCodec(Lucene104Codec.Mode.BEST_SPEED, mapperService, BigArrays.NON_RECYCLING_INSTANCE, null) ) ); } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/CompletionFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/CompletionFieldMapperTests.java index 0428e37698190..b1f7b147eb24c 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/CompletionFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/CompletionFieldMapperTests.java @@ -16,7 +16,7 @@ import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexableField; import 
org.apache.lucene.search.Query; -import org.apache.lucene.search.suggest.document.Completion101PostingsFormat; +import org.apache.lucene.search.suggest.document.Completion104PostingsFormat; import org.apache.lucene.search.suggest.document.CompletionAnalyzer; import org.apache.lucene.search.suggest.document.ContextSuggestField; import org.apache.lucene.search.suggest.document.FuzzyCompletionQuery; @@ -149,7 +149,7 @@ protected IndexAnalyzers createIndexAnalyzers(IndexSettings indexSettings) { } public void testPostingsFormat() throws IOException { - final Class latestLuceneCPClass = Completion101PostingsFormat.class; + final Class latestLuceneCPClass = Completion104PostingsFormat.class; MapperService mapperService = createMapperService(fieldMapping(this::minimalMapping)); CodecService codecService = new CodecService(mapperService, BigArrays.NON_RECYCLING_INSTANCE, null); Codec codec = codecService.codec("default"); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 525060da6bf58..4b83ce37d29da 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -36,6 +36,7 @@ import org.elasticsearch.index.codec.PerFieldMapperCodec; import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat; import org.elasticsearch.index.mapper.DocumentMapper; import org.elasticsearch.index.mapper.DocumentParsingException; @@ -80,6 +81,7 @@ import static org.elasticsearch.common.util.concurrent.EsExecutors.NODE_PROCESSORS_SETTING; import static 
org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.DYNAMIC_VISIT_RATIO; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DEFAULT_OVERSAMPLE; +import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasToString; @@ -520,8 +522,8 @@ protected void registerParameters(ParameterChecker checker) throws IOException { registerIndexOptionsUpdate( checker, b -> b.field("type", "dense_vector").field("dims", dims).field("index", true), - b -> b.field("type", "int4_hnsw").field("confidence_interval", 0.03).field("m", 4), - b -> b.field("type", "int4_hnsw").field("confidence_interval", 0.03).field("m", 100), + b -> b.field("type", "int4_hnsw").field("m", 4), + b -> b.field("type", "int4_hnsw").field("m", 100), hasToString(containsString("\"m\":100")) ); registerConflict( @@ -531,13 +533,6 @@ protected void registerParameters(ParameterChecker checker) throws IOException { b -> b.startObject("index_options").field("type", "int4_hnsw").field("m", 32).endObject(), b -> b.startObject("index_options").field("type", "int4_hnsw").field("m", 16).endObject() ); - registerConflict( - checker, - "index_options", - b -> b.field("type", "dense_vector").field("dims", dims).field("index", true), - b -> b.startObject("index_options").field("type", "int4_hnsw").endObject(), - b -> b.startObject("index_options").field("type", "int4_hnsw").field("confidence_interval", 0.3).endObject() - ); registerConflict( checker, "index_options", @@ -1958,21 +1953,24 @@ public void testKnnVectorsFormat() throws IOException { final int efConstruction = randomIntBetween(1, DEFAULT_BEAM_WIDTH + 10); boolean setM = randomBoolean(); boolean setEfConstruction = randomBoolean(); - MapperService mapperService = createMapperService(fieldMapping(b -> { - b.field("type", "dense_vector"); - b.field("dims", dims); - b.field("index", true); - 
b.field("similarity", "dot_product"); - b.startObject("index_options"); - b.field("type", "hnsw"); - if (setM) { - b.field("m", m); - } - if (setEfConstruction) { - b.field("ef_construction", efConstruction); - } - b.endObject(); - })); + MapperService mapperService = createMapperService( + IndexVersionUtils.getPreviousVersion(IndexVersions.UPGRADE_TO_LUCENE_10_4_0), + fieldMapping(b -> { + b.field("type", "dense_vector"); + b.field("dims", dims); + b.field("index", true); + b.field("similarity", "dot_product"); + b.startObject("index_options"); + b.field("type", "hnsw"); + if (setM) { + b.field("m", m); + } + if (setEfConstruction) { + b.field("ef_construction", efConstruction); + } + b.endObject(); + }) + ); CodecService codecService = new CodecService(mapperService, BigArrays.NON_RECYCLING_INSTANCE, null); Codec codec = codecService.codec("default"); KnnVectorsFormat knnVectorsFormat; @@ -1990,28 +1988,64 @@ public void testKnnVectorsFormat() throws IOException { + (setM ? m : DEFAULT_MAX_CONN) + ", beamWidth=" + (setEfConstruction ? 
efConstruction : DEFAULT_BEAM_WIDTH) + + ", hnswGraphThreshold=" + + ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD + ", flatVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat, format=" + "Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=" + "ES93FlatVectorScorer(delegate=Lucene99MemorySegmentFlatVectorsScorer()))))"; assertEquals(expectedString, knnVectorsFormat.toString()); } - public void testKnnQuantizedFlatVectorsFormat() throws IOException { - boolean setConfidenceInterval = randomBoolean(); - float confidenceInterval = (float) randomDoubleBetween(0.90f, 1.0f, true); - for (String quantizedFlatFormat : new String[] { "int8_flat", "int4_flat" }) { - MapperService mapperService = createMapperService(fieldMapping(b -> { + public void testConfidenceIntervalDeprecationOnLatestIndexVersion() throws IOException { + DocumentMapper mapper = createDocumentMapper(IndexVersions.UPGRADE_TO_LUCENE_10_4_0, fieldMapping(b -> { + b.field("type", "dense_vector"); + b.field("dims", 6); + b.field("index", true); + b.field("similarity", "dot_product"); + b.startObject("index_options"); + b.field("type", "int8_hnsw"); + b.field("confidence_interval", 0.95f); + b.endObject(); + })); + assertTrue(mapper.mappingSource().string().contains("\"confidence_interval\":0.95")); + assertWarnings( + "Parameter [confidence_interval] in [index_options] for dense_vector field " + + "[field] is deprecated and will be removed in a future version" + ); + } + + public void testConfidenceIntervalNoDeprecationBeforeLatestIndexVersion() throws IOException { + DocumentMapper mapper = createDocumentMapper( + IndexVersionUtils.getPreviousVersion(IndexVersions.UPGRADE_TO_LUCENE_10_4_0), + fieldMapping(b -> { b.field("type", "dense_vector"); - b.field("dims", dims); + b.field("dims", 6); b.field("index", true); b.field("similarity", "dot_product"); b.startObject("index_options"); - b.field("type", quantizedFlatFormat); - if (setConfidenceInterval) { - 
b.field("confidence_interval", confidenceInterval); - } + b.field("type", "int8_hnsw"); + b.field("confidence_interval", 0.95f); b.endObject(); - })); + }) + ); + assertTrue(mapper.mappingSource().string().contains("\"confidence_interval\":0.95")); + assertWarnings(); + } + + public void testKnnQuantizedFlatVectorsFormat() throws IOException { + for (String quantizedFlatFormat : new String[] { "int8_flat", "int4_flat" }) { + MapperService mapperService = createMapperService( + IndexVersionUtils.getPreviousVersion(IndexVersions.UPGRADE_TO_LUCENE_10_4_0), + fieldMapping(b -> { + b.field("type", "dense_vector"); + b.field("dims", dims); + b.field("index", true); + b.field("similarity", "dot_product"); + b.startObject("index_options"); + b.field("type", quantizedFlatFormat); + b.endObject(); + }) + ); CodecService codecService = new CodecService(mapperService, BigArrays.NON_RECYCLING_INSTANCE, null); Codec codec = codecService.codec("default"); KnnVectorsFormat knnVectorsFormat; @@ -2026,45 +2060,43 @@ public void testKnnQuantizedFlatVectorsFormat() throws IOException { knnVectorsFormat = ((LegacyPerFieldMapperCodec) codec).getKnnVectorsFormatForField("field"); } VectorScorerFactory factory = VectorScorerFactory.instance().orElse(null); - String expectedString = "ES93ScalarQuantizedVectorsFormat(name=ES93ScalarQuantizedVectorsFormat," - + " confidenceInterval=" - + (setConfidenceInterval ? Float.toString(confidenceInterval) : (quantizedFlatFormat.equals("int4_flat") ? "0.0" : null)) - + ", bits=" - + (quantizedFlatFormat.equals("int4_flat") ? 4 : 7) - + ", compressed=" - + quantizedFlatFormat.equals("int4_flat") - + ", flatVectorScorer=ESQuantizedFlatVectorsScorer(" - + "delegate=ScalarQuantizedVectorScorer(nonQuantizedDelegate=" - + "ES93FlatVectorScorer(delegate=Lucene99MemorySegmentFlatVectorsScorer()))" - + ", factory=" - + (factory != null ? 
factory : "null") - + "), " - + "rawVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat, format=" - + "Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=" - + "ES93FlatVectorScorer(delegate=Lucene99MemorySegmentFlatVectorsScorer()))))"; - assertThat(knnVectorsFormat, hasToString(expectedString)); + String encoding = quantizedFlatFormat.equals("int4_flat") ? "PACKED_NIBBLE" : "SEVEN_BIT"; + assertThat( + knnVectorsFormat, + hasToString( + allOf( + containsString("ES94ScalarQuantizedVectorsFormat(name=ES94ScalarQuantizedVectorsFormat"), + containsString("encoding=" + encoding), + containsString("flatVectorScorer=ESQuantizedFlatVectorsScorer("), + containsString("factory=" + (factory != null ? factory : "null")), + containsString( + "rawVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat, format=" + + "Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=" + + "ES93FlatVectorScorer(delegate=Lucene99MemorySegmentFlatVectorsScorer())))" + ) + ) + ) + ); } } public void testKnnQuantizedHNSWVectorsFormat() throws IOException { final int m = randomIntBetween(1, DEFAULT_MAX_CONN + 10); final int efConstruction = randomIntBetween(1, DEFAULT_BEAM_WIDTH + 10); - boolean setConfidenceInterval = randomBoolean(); - float confidenceInterval = (float) randomDoubleBetween(0.90f, 1.0f, true); - MapperService mapperService = createMapperService(fieldMapping(b -> { - b.field("type", "dense_vector"); - b.field("dims", dims); - b.field("index", true); - b.field("similarity", "dot_product"); - b.startObject("index_options"); - b.field("type", "int8_hnsw"); - b.field("m", m); - b.field("ef_construction", efConstruction); - if (setConfidenceInterval) { - b.field("confidence_interval", confidenceInterval); - } - b.endObject(); - })); + MapperService mapperService = createMapperService( + IndexVersionUtils.getPreviousVersion(IndexVersions.UPGRADE_TO_LUCENE_10_4_0), + fieldMapping(b -> { + 
b.field("type", "dense_vector"); + b.field("dims", dims); + b.field("index", true); + b.field("similarity", "dot_product"); + b.startObject("index_options"); + b.field("type", "int8_hnsw"); + b.field("m", m); + b.field("ef_construction", efConstruction); + b.endObject(); + }) + ); CodecService codecService = new CodecService(mapperService, BigArrays.NON_RECYCLING_INSTANCE, null); Codec codec = codecService.codec("default"); KnnVectorsFormat knnVectorsFormat; @@ -2079,39 +2111,49 @@ public void testKnnQuantizedHNSWVectorsFormat() throws IOException { knnVectorsFormat = ((LegacyPerFieldMapperCodec) codec).getKnnVectorsFormatForField("field"); } VectorScorerFactory factory = VectorScorerFactory.instance().orElse(null); - String expectedString = "ES93HnswScalarQuantizedVectorsFormat(name=ES93HnswScalarQuantizedVectorsFormat, maxConn=" - + m - + ", beamWidth=" - + efConstruction - + ", flatVectorFormat=ES93ScalarQuantizedVectorsFormat(name=ES93ScalarQuantizedVectorsFormat, confidenceInterval=" - + (setConfidenceInterval ? confidenceInterval : null) - + ", bits=7, compressed=false, " - + "flatVectorScorer=ESQuantizedFlatVectorsScorer(delegate=" - + "ScalarQuantizedVectorScorer(nonQuantizedDelegate=" - + "ES93FlatVectorScorer(delegate=Lucene99MemorySegmentFlatVectorsScorer())), " - + "factory=" - + (factory != null ? 
factory : "null") - + "), rawVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat, format=" - + "Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=" - + "ES93FlatVectorScorer(delegate=Lucene99MemorySegmentFlatVectorsScorer())))))"; - assertThat(knnVectorsFormat, hasToString(expectedString)); + assertThat( + knnVectorsFormat, + hasToString( + allOf( + startsWith( + "ES94HnswScalarQuantizedVectorsFormat(name=ES94HnswScalarQuantizedVectorsFormat, maxConn=" + + m + + ", beamWidth=" + + efConstruction + + ", hnswGraphThreshold=" + + ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD + + ", flatVectorFormat=ES94ScalarQuantizedVectorsFormat(name=ES94ScalarQuantizedVectorsFormat" + ), + containsString("encoding=SEVEN_BIT"), + containsString("factory=" + (factory != null ? factory : "null")), + containsString( + "rawVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat, format=" + + "Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=" + + "ES93FlatVectorScorer(delegate=Lucene99MemorySegmentFlatVectorsScorer())))" + ) + ) + ) + ); } public void testKnnBBQHNSWVectorsFormat() throws IOException { final int m = randomIntBetween(1, DEFAULT_MAX_CONN + 10); final int efConstruction = randomIntBetween(1, DEFAULT_BEAM_WIDTH + 10); final int dims = randomIntBetween(64, 4096); - MapperService mapperService = createMapperService(fieldMapping(b -> { - b.field("type", "dense_vector"); - b.field("dims", dims); - b.field("index", true); - b.field("similarity", "dot_product"); - b.startObject("index_options"); - b.field("type", "bbq_hnsw"); - b.field("m", m); - b.field("ef_construction", efConstruction); - b.endObject(); - })); + MapperService mapperService = createMapperService( + IndexVersionUtils.getPreviousVersion(IndexVersions.UPGRADE_TO_LUCENE_10_4_0), + fieldMapping(b -> { + b.field("type", "dense_vector"); + b.field("dims", dims); + b.field("index", true); + b.field("similarity", "dot_product"); + 
b.startObject("index_options"); + b.field("type", "bbq_hnsw"); + b.field("m", m); + b.field("ef_construction", efConstruction); + b.endObject(); + }) + ); CodecService codecService = new CodecService(mapperService, BigArrays.NON_RECYCLING_INSTANCE, null); Codec codec = codecService.codec("default"); KnnVectorsFormat knnVectorsFormat; @@ -2125,15 +2167,17 @@ public void testKnnBBQHNSWVectorsFormat() throws IOException { assertThat(codec, instanceOf(LegacyPerFieldMapperCodec.class)); knnVectorsFormat = ((LegacyPerFieldMapperCodec) codec).getKnnVectorsFormatForField("field"); } - String expectedString = "ES93HnswBinaryQuantizedVectorsFormat(name=ES93HnswBinaryQuantizedVectorsFormat, maxConn=" + String expectedPrefix = "ES93HnswBinaryQuantizedVectorsFormat(name=ES93HnswBinaryQuantizedVectorsFormat, maxConn=" + m + ", beamWidth=" + efConstruction + + ", hnswGraphThreshold=" + + ES93HnswBinaryQuantizedVectorsFormat.BBQ_HNSW_GRAPH_THRESHOLD + ", flatVectorFormat=ES93BinaryQuantizedVectorsFormat(" + "name=ES93BinaryQuantizedVectorsFormat, " + "rawVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat," + " format=Lucene99FlatVectorsFormat"; - assertThat(knnVectorsFormat, hasToString(startsWith(expectedString))); + assertThat(knnVectorsFormat, hasToString(startsWith(expectedPrefix))); } public void testInvalidVectorDimensionsBBQ() { @@ -2155,8 +2199,6 @@ public void testInvalidVectorDimensionsBBQ() { public void testKnnHalfByteQuantizedHNSWVectorsFormat() throws IOException { final int m = randomIntBetween(1, DEFAULT_MAX_CONN + 10); final int efConstruction = randomIntBetween(1, DEFAULT_BEAM_WIDTH + 10); - boolean setConfidenceInterval = randomBoolean(); - float confidenceInterval = (float) randomDoubleBetween(0.90f, 1.0f, true); MapperService mapperService = createMapperService(fieldMapping(b -> { b.field("type", "dense_vector"); b.field("dims", dims); @@ -2166,9 +2208,6 @@ public void testKnnHalfByteQuantizedHNSWVectorsFormat() throws IOException { 
b.field("type", "int4_hnsw"); b.field("m", m); b.field("ef_construction", efConstruction); - if (setConfidenceInterval) { - b.field("confidence_interval", confidenceInterval); - } b.endObject(); })); CodecService codecService = new CodecService(mapperService, BigArrays.NON_RECYCLING_INSTANCE, null); @@ -2185,20 +2224,29 @@ public void testKnnHalfByteQuantizedHNSWVectorsFormat() throws IOException { knnVectorsFormat = ((LegacyPerFieldMapperCodec) codec).getKnnVectorsFormatForField("field"); } VectorScorerFactory factory = VectorScorerFactory.instance().orElse(null); - String expectedString = "ES93HnswScalarQuantizedVectorsFormat(name=ES93HnswScalarQuantizedVectorsFormat, maxConn=" - + m - + ", beamWidth=" - + efConstruction - + ", flatVectorFormat=ES93ScalarQuantizedVectorsFormat(name=ES93ScalarQuantizedVectorsFormat, confidenceInterval=" - + (setConfidenceInterval ? confidenceInterval : 0.0f) - + ", bits=4, compressed=true, flatVectorScorer=ESQuantizedFlatVectorsScorer(delegate=" - + "ScalarQuantizedVectorScorer(nonQuantizedDelegate=" - + "ES93FlatVectorScorer(delegate=Lucene99MemorySegmentFlatVectorsScorer())), factory=" - + (factory != null ? factory : "null") - + "), rawVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat, format=" - + "Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=" - + "ES93FlatVectorScorer(delegate=Lucene99MemorySegmentFlatVectorsScorer())))))"; - assertThat(knnVectorsFormat, hasToString(expectedString)); + assertThat( + knnVectorsFormat, + hasToString( + allOf( + startsWith( + "ES94HnswScalarQuantizedVectorsFormat(name=ES94HnswScalarQuantizedVectorsFormat, maxConn=" + + m + + ", beamWidth=" + + efConstruction + + ", hnswGraphThreshold=" + + ES93HnswVectorsFormat.HNSW_GRAPH_THRESHOLD + + ", flatVectorFormat=ES94ScalarQuantizedVectorsFormat(name=ES94ScalarQuantizedVectorsFormat" + ), + containsString("encoding=PACKED_NIBBLE"), + containsString("factory=" + (factory != null ? 
factory : "null")), + containsString( + "rawVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat, format=" + + "Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=" + + "ES93FlatVectorScorer(delegate=Lucene99MemorySegmentFlatVectorsScorer())))" + ) + ) + ) + ); } public void testInvalidVectorDimensions() { @@ -2219,7 +2267,7 @@ public void testInvalidVectorDimensions() { public void testPushingDownExecutorAndThreads() { TestDenseVectorIndexOptions testIndexOptions = new TestDenseVectorIndexOptions( - new DenseVectorFieldMapper.HnswIndexOptions(16, 200) + new DenseVectorFieldMapper.HnswIndexOptions(16, 200, -1) ); var mapper = new DenseVectorFieldMapper.Builder("field", IndexVersion.current(), true, false, List.of()).indexOptions( testIndexOptions diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 2349b7ed50d75..7950da45d9ca1 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -66,7 +66,7 @@ private static DenseVectorFieldMapper.RescoreVector randomRescoreVector() { private static DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsNonQuantized() { return randomFrom( - new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), + new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), -1), new DenseVectorFieldMapper.FlatIndexOptions() ); } @@ -74,26 +74,20 @@ private static DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptions public static DenseVectorFieldMapper.DenseVectorIndexOptions randomFlatIndexOptions() { return randomFrom( new DenseVectorFieldMapper.FlatIndexOptions(), - new 
DenseVectorFieldMapper.Int8FlatIndexOptions( - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), - randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) - ), - new DenseVectorFieldMapper.Int4FlatIndexOptions( - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), - randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) - ) + new DenseVectorFieldMapper.Int8FlatIndexOptions(randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector())), + new DenseVectorFieldMapper.Int4FlatIndexOptions(randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector())) ); } public static DenseVectorFieldMapper.DenseVectorIndexOptions randomGpuSupportedIndexOptions() { return randomFrom( - new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 3199)), + new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 3199), -1), new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 3199), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), randomBoolean(), - randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()), + -1 ) ); } @@ -101,35 +95,34 @@ public static DenseVectorFieldMapper.DenseVectorIndexOptions randomGpuSupportedI public static DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsAll() { List options = new ArrayList<>( Arrays.asList( - new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), + new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), -1), new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, 
true)), randomBoolean(), - randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()), + -1 ), new DenseVectorFieldMapper.Int4HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), randomBoolean(), - randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()), + -1 ), new DenseVectorFieldMapper.FlatIndexOptions(), new DenseVectorFieldMapper.Int8FlatIndexOptions( - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.Int4FlatIndexOptions( - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.BBQHnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomBoolean(), - randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()), + -1 ), new DenseVectorFieldMapper.BBQFlatIndexOptions( randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) @@ -165,22 +158,23 @@ private DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsHnswQua new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), randomBoolean(), - rescoreVector + rescoreVector, + -1 ), new DenseVectorFieldMapper.Int4HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), randomBoolean(), - 
rescoreVector + rescoreVector, + -1 ), new DenseVectorFieldMapper.BBQHnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomBoolean(), - rescoreVector + rescoreVector, + -1 ) ); } diff --git a/server/src/test/java/org/elasticsearch/lucene/comparators/XSkipBlockRangeIteratorTests.java b/server/src/test/java/org/elasticsearch/lucene/comparators/XSkipBlockRangeIteratorTests.java deleted file mode 100644 index e8b4716fafbbb..0000000000000 --- a/server/src/test/java/org/elasticsearch/lucene/comparators/XSkipBlockRangeIteratorTests.java +++ /dev/null @@ -1,223 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". 
- */ - -package org.elasticsearch.lucene.comparators; - -import org.apache.lucene.index.DocValuesSkipper; -import org.apache.lucene.index.NumericDocValues; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.util.FixedBitSet; -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.test.ESTestCase; - -import java.io.IOException; - -public class XSkipBlockRangeIteratorTests extends ESTestCase { - - public void testLuceneVersion() { - assertFalse( - "Remove this class after upgrading to Lucene 10.4", - IndexVersion.current().luceneVersion().onOrAfter(org.apache.lucene.util.Version.fromBits(10, 4, 0)) - ); - } - - public void testSkipBlockRangeIterator() throws Exception { - - DocValuesSkipper skipper = docValuesSkipper(10, 20, true); - XSkipBlockRangeIterator it = new XSkipBlockRangeIterator(skipper, 10, 20); - - assertEquals(0, it.nextDoc()); - assertEquals(256, it.docIDRunEnd()); - assertEquals(100, it.advance(100)); - assertEquals(768, it.advance(300)); - assertEquals(1024, it.docIDRunEnd()); - assertEquals(1100, it.advance(1100)); - assertEquals(1280, it.docIDRunEnd()); - assertEquals(1792, it.advance(1500)); - assertEquals(2048, it.docIDRunEnd()); - assertEquals(DocIdSetIterator.NO_MORE_DOCS, it.advance(2050)); - } - - public void testIntoBitSet() throws Exception { - DocValuesSkipper skipper = docValuesSkipper(10, 20, true); - XSkipBlockRangeIterator it = new XSkipBlockRangeIterator(skipper, 10, 20); - assertEquals(768, it.advance(300)); - FixedBitSet bitSet = new FixedBitSet(2048); - it.intoBitSet(1500, bitSet, 768); - - FixedBitSet expected = new FixedBitSet(2048); - expected.set(0, 512); - - assertEquals(expected, bitSet); - } - - protected static NumericDocValues docValues(long queryMin, long queryMax) { - return new NumericDocValues() { - - int doc = -1; - - @Override - public boolean advanceExact(int target) throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public int docID() { - 
return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - if (target < 1024) { - // dense up to 1024 - return doc = target; - } else if (doc < 2047) { - // 50% docs have a value up to 2048 - return doc = target + (target & 1); - } else { - return doc = DocIdSetIterator.NO_MORE_DOCS; - } - } - - @Override - public long longValue() throws IOException { - int d = doc % 1024; - if (d < 128) { - return (queryMin + queryMax) >> 1; - } else if (d < 256) { - return queryMax + 1; - } else if (d < 512) { - return queryMin - 1; - } else { - return switch ((d / 2) % 3) { - case 0 -> queryMin - 1; - case 1 -> queryMax + 1; - case 2 -> (queryMin + queryMax) >> 1; - default -> throw new AssertionError(); - }; - } - } - - @Override - public long cost() { - return 42; - } - }; - } - - /** - * Fake skipper over a NumericDocValues field built by an equivalent call to {@link - * #docValues(long, long)} - */ - protected static DocValuesSkipper docValuesSkipper(long queryMin, long queryMax, boolean doLevels) { - return new DocValuesSkipper() { - - int doc = -1; - - @Override - public void advance(int target) throws IOException { - doc = target; - } - - @Override - public int numLevels() { - return doLevels ? 
3 : 1; - } - - @Override - public int minDocID(int level) { - int rangeLog = 9 - numLevels() + level; - - // the level is the log2 of the interval - if (doc < 0) { - return -1; - } else if (doc >= 2048) { - return DocIdSetIterator.NO_MORE_DOCS; - } else { - int mask = (1 << rangeLog) - 1; - // prior multiple of 2^level - return doc & ~mask; - } - } - - @Override - public int maxDocID(int level) { - int rangeLog = 9 - numLevels() + level; - - int minDocID = minDocID(level); - return switch (minDocID) { - case -1 -> -1; - case DocIdSetIterator.NO_MORE_DOCS -> DocIdSetIterator.NO_MORE_DOCS; - default -> minDocID + (1 << rangeLog) - 1; - }; - } - - @Override - @SuppressWarnings("DuplicateBranches") - public long minValue(int level) { - int d = doc % 1024; - if (d < 128) { - return queryMin; - } else if (d < 256) { - return queryMax + 1; - } else if (d < 768) { - return queryMin - 1; - } else { - return queryMin - 1; - } - } - - @Override - public long maxValue(int level) { - int d = doc % 1024; - if (d < 128) { - return queryMax; - } else if (d < 256) { - return queryMax + 1; - } else if (d < 768) { - return queryMin - 1; - } else { - return queryMax + 1; - } - } - - @Override - public int docCount(int level) { - int rangeLog = 9 - numLevels() + level; - - if (doc < 1024) { - return 1 << rangeLog; - } else { - // half docs have a value - return 1 << rangeLog >> 1; - } - } - - @Override - public long minValue() { - return Long.MIN_VALUE; - } - - @Override - public long maxValue() { - return Long.MAX_VALUE; - } - - @Override - public int docCount() { - return 1024 + 1024 / 2; - } - }; - } -} diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 410d2eb2ec28c..8afccdef58022 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ 
b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -24,6 +24,8 @@ import org.apache.lucene.queries.function.FunctionScoreQuery; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.ConjunctionUtils; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.FullPrecisionFloatVectorSimilarityValuesSource; @@ -39,7 +41,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.util.Bits; import org.elasticsearch.common.lucene.search.Queries; -import org.elasticsearch.index.codec.Elasticsearch92Lucene103Codec; +import org.elasticsearch.index.codec.Elasticsearch93Lucene104Codec; import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; @@ -340,7 +342,7 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th randomBoolean() ) ); - iwc.setCodec(new Elasticsearch92Lucene103Codec(randomFrom(Zstd814StoredFieldsFormat.Mode.values())) { + iwc.setCodec(new Elasticsearch93Lucene104Codec(randomFrom(Zstd814StoredFieldsFormat.Mode.values())) { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { return format; @@ -390,7 +392,7 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { if (values == null) { return null; } - return new RegularFloatVectorValues(values); + return new SingleFloatVectorValues(values); } }; } @@ -408,17 +410,26 @@ public CacheHelper getReaderCacheHelper() { } } - private static final class RegularFloatVectorValues extends FloatVectorValues { + /** + * A wrapper around FloatVectorValues that ensures that the bulk scoring path uses the single 
scoring method. + * Used to test that the single and bulk scoring paths return the same scores. + */ + private static final class SingleFloatVectorValues extends FloatVectorValues { private final FloatVectorValues in; - RegularFloatVectorValues(FloatVectorValues in) { + SingleFloatVectorValues(FloatVectorValues in) { this.in = in; } @Override public VectorScorer scorer(float[] target) throws IOException { - return in.scorer(target); + return new SingleVectorScorer(in.scorer(target)); + } + + @Override + public VectorScorer rescorer(float[] target) throws IOException { + return new SingleVectorScorer(in.rescorer(target)); } @Override @@ -443,7 +454,7 @@ public float[] vectorValue(int ord) throws IOException { @Override public FloatVectorValues copy() throws IOException { - return new RegularFloatVectorValues(in.copy()); + return new SingleFloatVectorValues(in.copy()); } @Override @@ -456,4 +467,48 @@ public int size() { return in.size(); } } + + private static final class SingleVectorScorer implements VectorScorer { + private final VectorScorer in; + + SingleVectorScorer(VectorScorer in) { + this.in = in; + } + + @Override + public float score() throws IOException { + return in.score(); + } + + @Override + public DocIdSetIterator iterator() { + return in.iterator(); + } + + @Override + public VectorScorer.Bulk bulk(DocIdSetIterator matchingDocs) throws IOException { + final DocIdSetIterator iterator = matchingDocs == null + ? 
iterator() + : ConjunctionUtils.createConjunction(List.of(matchingDocs, iterator()), List.of()); + if (iterator.docID() == -1) { + iterator.nextDoc(); + } + return (upTo, liveDocs, buffer) -> { + assert upTo > 0; + buffer.growNoCopy(VectorScorer.DEFAULT_BULK_BATCH_SIZE); + int size = 0; + float maxScore = Float.NEGATIVE_INFINITY; + for (int doc = iterator.docID(); doc < upTo && size < VectorScorer.DEFAULT_BULK_BATCH_SIZE; doc = iterator.nextDoc()) { + if (liveDocs == null || liveDocs.get(doc)) { + buffer.docs[size] = doc; + buffer.features[size] = score(); + maxScore = Math.max(maxScore, buffer.features[size]); + ++size; + } + } + buffer.size = size; + return maxScore; + }; + } + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/index/engine/frozen/FrozenEngine.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/index/engine/frozen/FrozenEngine.java index 965ec067a7ee5..3f5eb13a1d95c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/index/engine/frozen/FrozenEngine.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/index/engine/frozen/FrozenEngine.java @@ -41,6 +41,7 @@ import java.util.Set; import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; import java.util.function.Function; /** @@ -113,16 +114,31 @@ protected DirectoryReader doOpenIfChanged() { return null; } + @Override + protected DirectoryReader doOpenIfChanged(ExecutorService executorService) { + return null; + } + @Override protected DirectoryReader doOpenIfChanged(IndexCommit commit) { return null; } + @Override + protected DirectoryReader doOpenIfChanged(IndexCommit commit, ExecutorService executorService) { + return null; + } + @Override protected DirectoryReader doOpenIfChanged(IndexWriter writer, boolean applyAllDeletes) { return null; } + @Override + protected DirectoryReader doOpenIfChanged(IndexWriter writer, boolean applyAllDeletes, ExecutorService 
executorService) { + return null; + } + @Override public long getVersion() { return 0; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/index/engine/frozen/RewriteCachingDirectoryReader.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/index/engine/frozen/RewriteCachingDirectoryReader.java index e26c7577672f5..ca435a1a74aab 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/index/engine/frozen/RewriteCachingDirectoryReader.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/index/engine/frozen/RewriteCachingDirectoryReader.java @@ -36,6 +36,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutorService; /** * This special DirectoryReader is used to handle can_match requests against frozen indices. @@ -65,16 +66,31 @@ protected DirectoryReader doOpenIfChanged() { throw new UnsupportedOperationException(); } + @Override + protected DirectoryReader doOpenIfChanged(ExecutorService executorService) { + throw new UnsupportedOperationException(); + } + @Override protected DirectoryReader doOpenIfChanged(IndexCommit commit) { throw new UnsupportedOperationException(); } + @Override + protected DirectoryReader doOpenIfChanged(IndexCommit commit, ExecutorService executorService) { + throw new UnsupportedOperationException(); + } + @Override protected DirectoryReader doOpenIfChanged(IndexWriter writer, boolean applyAllDeletes) { throw new UnsupportedOperationException(); } + @Override + protected DirectoryReader doOpenIfChanged(IndexWriter writer, boolean applyAllDeletes, ExecutorService executorService) { + throw new UnsupportedOperationException(); + } + @Override public long getVersion() { throw new UnsupportedOperationException(); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec index dfcb6f5e18bd8..7e7ec5e374269 100644 --- 
a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec @@ -85,7 +85,7 @@ be excellent to each other all we have to decide is what to do with the time that is given to us ; -text_embedding with multiple knn queries in fork +text_embedding with multiple knn queries in fork#[skip:-9.4.99,reason:vector scoring order differs on older versions] required_capability: text_embedding_function required_capability: dense_vector_field_type_released required_capability: knn_function_v5 @@ -100,8 +100,8 @@ FROM dense_vector_text METADATA _score ; text_field:text | query_embedding:dense_vector | _fork:keyword -be excellent to each other | [45.0, 55.0, 54.0] | fork1 live long and prosper | [50.0, 57.0, 56.0] | fork2 +be excellent to each other | [45.0, 55.0, 54.0] | fork1 live long and prosper | [45.0, 55.0, 54.0] | fork1 be excellent to each other | [50.0, 57.0, 56.0] | fork2 all we have to decide is what to do with the time that is given to us | [45.0, 55.0, 54.0] | fork1 diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/stats/SearchContextStats.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/stats/SearchContextStats.java index 95f03cab47e3c..b9f44d14e571c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/stats/SearchContextStats.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/stats/SearchContextStats.java @@ -18,7 +18,6 @@ import org.apache.lucene.index.PointValues; import org.apache.lucene.index.Term; import org.apache.lucene.index.Terms; -import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.NumericUtils; import org.elasticsearch.cluster.metadata.IndexMetadata; @@ -305,12 +304,12 @@ private static Long nullableMax(final Long a, final Long b) { // TODO: replace these helpers with a unified Lucene min/max API once 
https://github.com/apache/lucene/issues/15740 is resolved private static Long docValuesSkipperMinValue(final LeafReaderContext leafContext, final String field) throws IOException { - long value = DocValuesSkipper.globalMinValue(new IndexSearcher(leafContext.reader()), field); + long value = DocValuesSkipper.globalMinValue(leafContext.reader(), field); return (value == Long.MAX_VALUE || value == Long.MIN_VALUE) ? null : value; } private static Long docValuesSkipperMaxValue(final LeafReaderContext leafContext, final String field) throws IOException { - long value = DocValuesSkipper.globalMaxValue(new IndexSearcher(leafContext.reader()), field); + long value = DocValuesSkipper.globalMaxValue(leafContext.reader(), field); return (value == Long.MAX_VALUE || value == Long.MIN_VALUE) ? null : value; } diff --git a/x-pack/plugin/frozen-indices/src/test/java/org/elasticsearch/index/engine/frozen/RewriteCachingDirectoryReaderTests.java b/x-pack/plugin/frozen-indices/src/test/java/org/elasticsearch/index/engine/frozen/RewriteCachingDirectoryReaderTests.java index 7182d46c35bb0..6c0e0318f5284 100644 --- a/x-pack/plugin/frozen-indices/src/test/java/org/elasticsearch/index/engine/frozen/RewriteCachingDirectoryReaderTests.java +++ b/x-pack/plugin/frozen-indices/src/test/java/org/elasticsearch/index/engine/frozen/RewriteCachingDirectoryReaderTests.java @@ -14,7 +14,6 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.PointValues; -import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.store.Directory; import org.elasticsearch.index.mapper.DateFieldMapper; import org.elasticsearch.index.mapper.IndexType; @@ -49,8 +48,6 @@ public void testGetMinMaxPackedValue() throws IOException { } try (DirectoryReader reader = DirectoryReader.open(writer)) { RewriteCachingDirectoryReader cachingDirectoryReader = new RewriteCachingDirectoryReader(dir, reader.leaves(), null); - IndexSearcher searcher = new 
IndexSearcher(reader); - IndexSearcher cachingSearcher = new IndexSearcher(cachingDirectoryReader); if (rarely) { assertArrayEquals( @@ -63,16 +60,16 @@ public void testGetMinMaxPackedValue() throws IOException { ); assertEquals(PointValues.size(reader, "rarely"), PointValues.size(cachingDirectoryReader, "rarely")); assertEquals( - DocValuesSkipper.globalDocCount(searcher, "rarely_skipper"), - DocValuesSkipper.globalDocCount(cachingSearcher, "rarely_skipper") + DocValuesSkipper.globalDocCount(reader, "rarely_skipper"), + DocValuesSkipper.globalDocCount(cachingDirectoryReader, "rarely_skipper") ); assertEquals( - DocValuesSkipper.globalMaxValue(searcher, "rarely_skipper"), - DocValuesSkipper.globalMaxValue(cachingSearcher, "rarely_skipper") + DocValuesSkipper.globalMaxValue(reader, "rarely_skipper"), + DocValuesSkipper.globalMaxValue(cachingDirectoryReader, "rarely_skipper") ); assertEquals( - DocValuesSkipper.globalMinValue(searcher, "rarely_skipper"), - DocValuesSkipper.globalMinValue(cachingSearcher, "rarely_skipper") + DocValuesSkipper.globalMinValue(reader, "rarely_skipper"), + DocValuesSkipper.globalMinValue(cachingDirectoryReader, "rarely_skipper") ); } assertArrayEquals( @@ -97,16 +94,16 @@ public void testGetMinMaxPackedValue() throws IOException { assertEquals(PointValues.size(reader, "test_const"), PointValues.size(cachingDirectoryReader, "test_const")); assertEquals( - DocValuesSkipper.globalDocCount(searcher, "skipper"), - DocValuesSkipper.globalDocCount(cachingSearcher, "skipper") + DocValuesSkipper.globalDocCount(reader, "skipper"), + DocValuesSkipper.globalDocCount(cachingDirectoryReader, "skipper") ); assertEquals( - DocValuesSkipper.globalMinValue(searcher, "skipper"), - DocValuesSkipper.globalMinValue(cachingSearcher, "skipper") + DocValuesSkipper.globalMinValue(reader, "skipper"), + DocValuesSkipper.globalMinValue(cachingDirectoryReader, "skipper") ); assertEquals( - DocValuesSkipper.globalMaxValue(searcher, "skipper"), - 
DocValuesSkipper.globalMaxValue(cachingSearcher, "skipper") + DocValuesSkipper.globalMaxValue(reader, "skipper"), + DocValuesSkipper.globalMaxValue(cachingDirectoryReader, "skipper") ); } } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java index eb187961225b9..0e088454104b7 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java @@ -119,9 +119,9 @@ public void testValidateIndexOptionsWithBasicLicense() throws Exception { IndexOptions indexOptions = new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), - null, randomBoolean(), - null + null, + -1 ); assertAcked( safeGet(prepareCreate(INDEX_NAME).setMapping(generateMapping(inferenceFieldName, inferenceId, indexOptions)).execute()) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index e1b414a23a17c..937ae6f19b6f4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -1495,7 +1495,7 @@ public static DenseVectorFieldMapper.DenseVectorIndexOptions defaultBbqHnswDense int m = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; int efConstruction = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; DenseVectorFieldMapper.RescoreVector rescoreVector = new 
DenseVectorFieldMapper.RescoreVector(DEFAULT_RESCORE_OVERSAMPLE); - return new DenseVectorFieldMapper.BBQHnswIndexOptions(m, efConstruction, false, rescoreVector); + return new DenseVectorFieldMapper.BBQHnswIndexOptions(m, efConstruction, false, rescoreVector, -1); } static SemanticTextIndexOptions defaultIndexOptions(IndexVersion indexVersionCreated, MinimalServiceSettings modelSettings) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index ca6fbdd8ef976..9e26152a5f758 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -1921,7 +1921,7 @@ private static DenseVectorFieldMapper.DenseVectorIndexOptions defaultDenseVector // These are the default index options for dense_vector fields, and used for semantic_text fields incompatible with BBQ. 
int m = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; int efConstruction = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; - return new DenseVectorFieldMapper.Int8HnswIndexOptions(m, efConstruction, null, false, null); + return new DenseVectorFieldMapper.Int8HnswIndexOptions(m, efConstruction, false, null, -1); } private static SemanticTextIndexOptions defaultDenseVectorSemanticIndexOptions() { @@ -1932,7 +1932,7 @@ private static DenseVectorFieldMapper.DenseVectorIndexOptions defaultBbqHnswDens int m = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; int efConstruction = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; DenseVectorFieldMapper.RescoreVector rescoreVector = new DenseVectorFieldMapper.RescoreVector(DEFAULT_RESCORE_OVERSAMPLE); - return new DenseVectorFieldMapper.BBQHnswIndexOptions(m, efConstruction, false, rescoreVector); + return new DenseVectorFieldMapper.BBQHnswIndexOptions(m, efConstruction, false, rescoreVector, -1); } private static SemanticTextIndexOptions defaultBbqHnswSemanticTextIndexOptions() { @@ -2021,7 +2021,7 @@ public void testDefaultIndexOptions() throws IOException { null, new SemanticTextIndexOptions( SemanticTextIndexOptions.SupportedIndexOptions.DENSE_VECTOR, - new DenseVectorFieldMapper.Int4HnswIndexOptions(25, 100, null, false, null) + new DenseVectorFieldMapper.Int4HnswIndexOptions(25, 100, false, null, -1) ) ); @@ -2130,7 +2130,6 @@ public void testSpecifiedDenseVectorIndexOptions() throws IOException { b.field("type", "int4_hnsw"); b.field("m", 20); b.field("ef_construction", 90); - b.field("confidence_interval", 0.4); b.endObject(); b.endObject(); }), useLegacyFormat, IndexVersions.INFERENCE_METADATA_FIELDS_BACKPORT); @@ -2141,7 +2140,7 @@ public void testSpecifiedDenseVectorIndexOptions() throws IOException { null, new SemanticTextIndexOptions( SemanticTextIndexOptions.SupportedIndexOptions.DENSE_VECTOR, - new DenseVectorFieldMapper.Int4HnswIndexOptions(20, 90, 0.4f, false, null) + new DenseVectorFieldMapper.Int4HnswIndexOptions(20, 
90, false, null, -1) ) ); @@ -2168,7 +2167,7 @@ public void testSpecifiedDenseVectorIndexOptions() throws IOException { null, new SemanticTextIndexOptions( SemanticTextIndexOptions.SupportedIndexOptions.DENSE_VECTOR, - new DenseVectorFieldMapper.Int4HnswIndexOptions(16, 100, 0f, false, null) + new DenseVectorFieldMapper.Int4HnswIndexOptions(16, 100, false, null, -1) ) ); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_field_mapping.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_field_mapping.yml index 3d6143f1d4ba7..59887c58c4030 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_field_mapping.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_field_mapping.yml @@ -600,7 +600,6 @@ setup: type: int8_hnsw m: 20 ef_construction: 100 - confidence_interval: 1.0 - do: indices.get_mapping: @@ -609,7 +608,6 @@ setup: - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.type": "int8_hnsw" } - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.m": 20 } - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.ef_construction": 100 } - - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.confidence_interval": 1.0 } - do: index: @@ -641,7 +639,6 @@ setup: - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.type": int8_hnsw } - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.m": 20 } - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.ef_construction": 100 } - - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.confidence_interval": 
1.0 } --- "Specifying incompatible dense vector index options will fail": @@ -811,7 +808,6 @@ setup: type: int8_hnsw m: 16 ef_construction: 100 - confidence_interval: 1.0 - do: indices.get_mapping: @@ -820,7 +816,6 @@ setup: - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.type": "int8_hnsw" } - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.m": 16 } - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.ef_construction": 100 } - - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.confidence_interval": 1.0 } - do: indices.put_mapping: @@ -835,7 +830,6 @@ setup: type: int8_hnsw m: 20 ef_construction: 90 - confidence_interval: 1.0 - do: indices.get_mapping: @@ -844,7 +838,6 @@ setup: - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.type": "int8_hnsw" } - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.m": 20 } - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.ef_construction": 90 } - - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.confidence_interval": 1.0 } - do: catch: /Cannot update parameter \[index_options\]/ diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_field_mapping_bwc.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_field_mapping_bwc.yml index a3fc9e55f85e6..f291416dd54da 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_field_mapping_bwc.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_field_mapping_bwc.yml @@ -500,7 +500,6 @@ setup: type: int8_hnsw m: 20 ef_construction: 100 - 
confidence_interval: 1.0 - do: indices.get_mapping: @@ -509,7 +508,6 @@ setup: - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.type": "int8_hnsw" } - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.m": 20 } - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.ef_construction": 100 } - - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.confidence_interval": 1.0 } - do: index: @@ -530,7 +528,6 @@ setup: type: int8_hnsw m: 20 ef_construction: 100 - confidence_interval: 1.0 chunks: - text: "these are not the droids you're looking for" embeddings: [ 0.04673296958208084, -0.03237321600317955, -0.02543032355606556, 0.056035321205854416 ] @@ -545,7 +542,6 @@ setup: - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.type": "int8_hnsw" } - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.m": 20 } - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.ef_construction": 100 } - - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.confidence_interval": 1.0 } --- "Specifying incompatible dense vector index options will fail": @@ -715,7 +711,6 @@ setup: type: int8_hnsw m: 16 ef_construction: 100 - confidence_interval: 1.0 - do: indices.get_mapping: @@ -724,7 +719,6 @@ setup: - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.type": "int8_hnsw" } - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.m": 16 } - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.ef_construction": 100 } - - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.confidence_interval": 1.0 } - do: indices.put_mapping: @@ 
-739,7 +733,6 @@ setup: type: int8_hnsw m: 20 ef_construction: 90 - confidence_interval: 1.0 - do: indices.get_mapping: @@ -748,7 +741,6 @@ setup: - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.type": "int8_hnsw" } - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.m": 20 } - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.ef_construction": 90 } - - match: { "test-index-options.mappings.properties.semantic_field.index_options.dense_vector.confidence_interval": 1.0 } - do: catch: /Cannot update parameter \[index_options\]/ diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/47_semantic_text_knn.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/47_semantic_text_knn.yml index c7e55aea886c8..1ca181d5d191e 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/47_semantic_text_knn.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/47_semantic_text_knn.yml @@ -600,9 +600,9 @@ setup: - match: { hits.total.value: 2 } - match: { hits.hits.0._id: "doc_1" } - match: { hits.hits.1._id: "doc_3" } - - close_to: { hits.hits.0._score: { value: 0.9990483, error: 1e-4 } } + - close_to: { hits.hits.0._score: { value: 0.9990483, error: 1e-3 } } - not_exists: hits.hits.0.matched_queries - - close_to: { hits.hits.1._score: { value: 0.9439374, error: 1e-4 } } + - close_to: { hits.hits.1._score: { value: 0.9439374, error: 1e-3 } } - not_exists: hits.hits.1.matched_queries - do: @@ -629,7 +629,7 @@ setup: - match: { hits.hits.1._id: "doc_3" } - close_to: { hits.hits.0._score: { value: 4.9952416, error: 1e-3 } } - match: { hits.hits.0.matched_queries: [ "i-like-naming-my-queries" ] } - - close_to: { hits.hits.1._score: { value: 4.719687, error: 1e-3 } } + - close_to: { hits.hits.1._score: { value: 4.719687, 
error: 1e-2 } } - match: { hits.hits.1.matched_queries: [ "i-like-naming-my-queries" ] } --- diff --git a/x-pack/plugin/old-lucene-versions/src/test/java/org/elasticsearch/xpack/lucene/bwc/codecs/lucene50/BlockPostingsFormatTests.java b/x-pack/plugin/old-lucene-versions/src/test/java/org/elasticsearch/xpack/lucene/bwc/codecs/lucene50/BlockPostingsFormatTests.java index 39f78429d1585..b9ad062729d4d 100644 --- a/x-pack/plugin/old-lucene-versions/src/test/java/org/elasticsearch/xpack/lucene/bwc/codecs/lucene50/BlockPostingsFormatTests.java +++ b/x-pack/plugin/old-lucene-versions/src/test/java/org/elasticsearch/xpack/lucene/bwc/codecs/lucene50/BlockPostingsFormatTests.java @@ -27,7 +27,7 @@ import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.Impact; +import org.apache.lucene.index.FreqAndNormBuffer; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.store.ByteArrayDataInput; @@ -41,12 +41,8 @@ import org.elasticsearch.test.GraalVMThreadsFilter; import org.elasticsearch.xpack.lucene.bwc.codecs.lucene40.blocktree.FieldReader; import org.elasticsearch.xpack.lucene.bwc.codecs.lucene40.blocktree.Stats; -import org.elasticsearch.xpack.lucene.bwc.codecs.lucene50.Lucene50ScoreSkipReader.MutableImpactList; import java.io.IOException; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; /** Tests BlockPostingsFormat */ @ThreadLeakFilters(filters = { GraalVMThreadsFilter.class }) @@ -85,47 +81,55 @@ public void testFinalBlock() throws Exception { public void testImpactSerialization() throws IOException { // omit norms and omit freqs - doTestImpactSerialization(Collections.singletonList(new Impact(1, 1L))); + FreqAndNormBuffer freqAndNormBuffer = new FreqAndNormBuffer(); + freqAndNormBuffer.add(1, 1L); + doTestImpactSerialization(freqAndNormBuffer); // omit freqs - 
doTestImpactSerialization(Collections.singletonList(new Impact(1, 42L))); + freqAndNormBuffer = new FreqAndNormBuffer(); + freqAndNormBuffer.add(1, 42L); + doTestImpactSerialization(freqAndNormBuffer); + // omit freqs with very large norms - doTestImpactSerialization(Collections.singletonList(new Impact(1, -100L))); + freqAndNormBuffer = new FreqAndNormBuffer(); + freqAndNormBuffer.add(1, -100); + doTestImpactSerialization(freqAndNormBuffer); // omit norms - doTestImpactSerialization(Collections.singletonList(new Impact(30, 1L))); + freqAndNormBuffer = new FreqAndNormBuffer(); + freqAndNormBuffer.add(30, 1L); + doTestImpactSerialization(freqAndNormBuffer); + // omit norms with large freq - doTestImpactSerialization(Collections.singletonList(new Impact(500, 1L))); + freqAndNormBuffer = new FreqAndNormBuffer(); + freqAndNormBuffer.add(500, 1L); + doTestImpactSerialization(freqAndNormBuffer); // freqs and norms, basic - doTestImpactSerialization( - Arrays.asList( - new Impact(1, 7L), - new Impact(3, 9L), - new Impact(7, 10L), - new Impact(15, 11L), - new Impact(20, 13L), - new Impact(28, 14L) - ) - ); + freqAndNormBuffer = new FreqAndNormBuffer(); + freqAndNormBuffer.add(1, 7L); + freqAndNormBuffer.add(3, 9L); + freqAndNormBuffer.add(7, 10L); + freqAndNormBuffer.add(15, 11L); + freqAndNormBuffer.add(20, 13L); + freqAndNormBuffer.add(28, 14L); + doTestImpactSerialization(freqAndNormBuffer); // freqs and norms, high values - doTestImpactSerialization( - Arrays.asList( - new Impact(2, 2L), - new Impact(10, 10L), - new Impact(12, 50L), - new Impact(50, -100L), - new Impact(1000, -80L), - new Impact(1005, -3L) - ) - ); + freqAndNormBuffer = new FreqAndNormBuffer(); + freqAndNormBuffer.add(2, 2L); + freqAndNormBuffer.add(10, 10L); + freqAndNormBuffer.add(12, 50L); + freqAndNormBuffer.add(50, -100L); + freqAndNormBuffer.add(1000, -80L); + freqAndNormBuffer.add(1005, -3L); + doTestImpactSerialization(freqAndNormBuffer); } - private void doTestImpactSerialization(List impacts) 
throws IOException { + private void doTestImpactSerialization(FreqAndNormBuffer impacts) throws IOException { CompetitiveImpactAccumulator acc = new CompetitiveImpactAccumulator(); - for (Impact impact : impacts) { - acc.add(impact.freq, impact.norm); + for (int i = 0; i < impacts.size; i++) { + acc.add(impacts.freqs[i], impacts.norms[i]); } try (Directory dir = newDirectory()) { try (IndexOutput out = EndiannessReverserUtil.createOutput(dir, "foo", IOContext.DEFAULT)) { @@ -134,8 +138,12 @@ private void doTestImpactSerialization(List impacts) throws IOException try (IndexInput in = EndiannessReverserUtil.openInput(dir, "foo", IOContext.DEFAULT)) { byte[] b = new byte[Math.toIntExact(in.length())]; in.readBytes(b, 0, b.length); - List impacts2 = Lucene50ScoreSkipReader.readImpacts(new ByteArrayDataInput(b), new MutableImpactList()); - assertEquals(impacts, impacts2); + FreqAndNormBuffer impacts2 = Lucene50ScoreSkipReader.readImpacts(new ByteArrayDataInput(b), new FreqAndNormBuffer()); + assertEquals(impacts.size, impacts2.size); + for (int i = 0; i < impacts.size; i++) { + assertEquals(impacts.freqs[i], impacts2.freqs[i]); + assertEquals(impacts.norms[i], impacts2.norms[i]); + } } } } diff --git a/x-pack/plugin/old-lucene-versions/src/test/java/org/elasticsearch/xpack/lucene/bwc/codecs/lucene50/Lucene50ScoreSkipReader.java b/x-pack/plugin/old-lucene-versions/src/test/java/org/elasticsearch/xpack/lucene/bwc/codecs/lucene50/Lucene50ScoreSkipReader.java index e27e95f2601a2..a8740ba418d0e 100644 --- a/x-pack/plugin/old-lucene-versions/src/test/java/org/elasticsearch/xpack/lucene/bwc/codecs/lucene50/Lucene50ScoreSkipReader.java +++ b/x-pack/plugin/old-lucene-versions/src/test/java/org/elasticsearch/xpack/lucene/bwc/codecs/lucene50/Lucene50ScoreSkipReader.java @@ -19,17 +19,14 @@ */ package org.elasticsearch.xpack.lucene.bwc.codecs.lucene50; -import org.apache.lucene.index.Impact; +import org.apache.lucene.index.FreqAndNormBuffer; import org.apache.lucene.index.Impacts; 
import org.apache.lucene.store.ByteArrayDataInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.ArrayUtil; import java.io.IOException; -import java.util.AbstractList; import java.util.Arrays; -import java.util.List; -import java.util.RandomAccess; final class Lucene50ScoreSkipReader extends Lucene50SkipReader { @@ -38,7 +35,7 @@ final class Lucene50ScoreSkipReader extends Lucene50SkipReader { private final ByteArrayDataInput badi = new ByteArrayDataInput(); private final Impacts impacts; private int numLevels = 1; - private final MutableImpactList[] perLevelImpacts; + private final FreqAndNormBuffer[] perLevelImpacts; Lucene50ScoreSkipReader( int version, @@ -55,9 +52,10 @@ final class Lucene50ScoreSkipReader extends Lucene50SkipReader { this.impactData = new byte[maxSkipLevels][]; Arrays.fill(impactData, new byte[0]); this.impactDataLength = new int[maxSkipLevels]; - this.perLevelImpacts = new MutableImpactList[maxSkipLevels]; + this.perLevelImpacts = new FreqAndNormBuffer[maxSkipLevels]; for (int i = 0; i < perLevelImpacts.length; ++i) { - perLevelImpacts[i] = new MutableImpactList(); + perLevelImpacts[i] = new FreqAndNormBuffer(); + perLevelImpacts[i].add(Integer.MAX_VALUE, 1L); } impacts = new Impacts() { @@ -72,7 +70,7 @@ public int getDocIdUpTo(int level) { } @Override - public List getImpacts(int level) { + public FreqAndNormBuffer getImpacts(int level) { assert level < numLevels; if (impactDataLength[level] > 0) { badi.reset(impactData[level], 0, impactDataLength[level]); @@ -93,9 +91,9 @@ public int skipTo(int target) throws IOException { // End of postings don't have skip data anymore, so we fill with dummy data // like SlowImpactsEnum. 
numLevels = 1; - perLevelImpacts[0].length = 1; - perLevelImpacts[0].impacts[0].freq = Integer.MAX_VALUE; - perLevelImpacts[0].impacts[0].norm = 1L; + perLevelImpacts[0].size = 1; + perLevelImpacts[0].freqs[0] = Integer.MAX_VALUE; + perLevelImpacts[0].norms[0] = 1L; impactDataLength[0] = 0; } return result; @@ -115,19 +113,13 @@ protected void readImpacts(int level, IndexInput skipStream) throws IOException impactDataLength[level] = length; } - static MutableImpactList readImpacts(ByteArrayDataInput in, MutableImpactList reuse) { + static FreqAndNormBuffer readImpacts(ByteArrayDataInput in, FreqAndNormBuffer reuse) { int maxNumImpacts = in.length(); // at most one impact per byte - if (reuse.impacts.length < maxNumImpacts) { - int oldLength = reuse.impacts.length; - reuse.impacts = ArrayUtil.grow(reuse.impacts, maxNumImpacts); - for (int i = oldLength; i < reuse.impacts.length; ++i) { - reuse.impacts[i] = new Impact(Integer.MAX_VALUE, 1L); - } - } + reuse.growNoCopy(maxNumImpacts); int freq = 0; long norm = 0; - int length = 0; + int size = 0; while (in.getPosition() < in.length()) { int freqDelta = in.readVInt(); if ((freqDelta & 0x01) != 0) { @@ -141,27 +133,11 @@ static MutableImpactList readImpacts(ByteArrayDataInput in, MutableImpactList re freq += 1 + (freqDelta >>> 1); norm++; } - Impact impact = reuse.impacts[length]; - impact.freq = freq; - impact.norm = norm; - length++; + reuse.freqs[size] = freq; + reuse.norms[size] = norm; + size++; } - reuse.length = length; + reuse.size = size; return reuse; } - - static class MutableImpactList extends AbstractList implements RandomAccess { - int length = 1; - Impact[] impacts = new Impact[] { new Impact(Integer.MAX_VALUE, 1L) }; - - @Override - public Impact get(int index) { - return impacts[index]; - } - - @Override - public int size() { - return length; - } - } } diff --git a/x-pack/plugin/old-lucene-versions/src/test/java/org/elasticsearch/xpack/lucene/bwc/codecs/lucene50/Lucene50SkipWriter.java 
b/x-pack/plugin/old-lucene-versions/src/test/java/org/elasticsearch/xpack/lucene/bwc/codecs/lucene50/Lucene50SkipWriter.java index 9555f266e0611..6fd7f47a5cb26 100644 --- a/x-pack/plugin/old-lucene-versions/src/test/java/org/elasticsearch/xpack/lucene/bwc/codecs/lucene50/Lucene50SkipWriter.java +++ b/x-pack/plugin/old-lucene-versions/src/test/java/org/elasticsearch/xpack/lucene/bwc/codecs/lucene50/Lucene50SkipWriter.java @@ -20,8 +20,8 @@ package org.elasticsearch.xpack.lucene.bwc.codecs.lucene50; import org.apache.lucene.codecs.CompetitiveImpactAccumulator; +import org.apache.lucene.codecs.Impact; import org.apache.lucene.codecs.MultiLevelSkipListWriter; -import org.apache.lucene.index.Impact; import org.apache.lucene.store.ByteBuffersDataOutput; import org.apache.lucene.store.DataOutput; import org.apache.lucene.store.IndexOutput;