From e7f28e3d6ecfd5c25697b9db8bb20cd78b53c1ad Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 13 Mar 2025 09:40:08 -0400 Subject: [PATCH 1/3] New `vector_rescore` parameter as a quantized index type option (#124581) This adds a new parameter to the quantized index mapping that allows default oversampling and rescoring to occur. This doesn't adjust any of the defaults. It allows it to be configured. When the user provides `rescore_vector: {oversample: }` in the query it will overwrite it. For example, here is how to use it with bbq: ``` PUT rescored_bbq { "mappings": { "properties": { "vector": { "type": "dense_vector", "index_options": { "type": "bbq_hnsw", "rescore_vector": {"oversample": 3.0} } } } } } ``` Then, when querying, it will auto oversample the `k` by `3x` and rerank with the raw vectors. ``` POST _search { "knn": { "query_vector": [...], "field": "vector" } } ``` (cherry picked from commit b2c1c4e0f0b01d475edcdad8199bf4ebf7de73cc) --- docs/changelog/124581.yaml | 5 + .../search.vectors/41_knn_search_bbq_hnsw.yml | 90 ++++++++ .../41_knn_search_byte_quantized.yml | 89 ++++++++ .../41_knn_search_half_byte_quantized.yml | 89 ++++++++ .../search.vectors/42_knn_search_bbq_flat.yml | 90 ++++++++ .../42_knn_search_int4_flat.yml | 90 ++++++++ .../42_knn_search_int8_flat.yml | 90 ++++++++ .../index/mapper/MapperFeatures.java | 5 +- .../vectors/DenseVectorFieldMapper.java | 205 ++++++++++++++---- .../vectors/DenseVectorFieldMapperTests.java | 115 ++++++++++ .../vectors/DenseVectorFieldTypeTests.java | 46 +++- 11 files changed, 856 insertions(+), 58 deletions(-) create mode 100644 docs/changelog/124581.yaml diff --git a/docs/changelog/124581.yaml b/docs/changelog/124581.yaml new file mode 100644 index 0000000000000..cc978981b1efc --- /dev/null +++ b/docs/changelog/124581.yaml @@ -0,0 +1,5 @@ +pr: 124581 +summary: New `vector_rescore` parameter as a quantized index type option +area: Vector Search +type: enhancement +issues: [] diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index 4c9d1ef881c6d..9747644a5ba6c 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -250,3 +250,93 @@ setup: index: dynamic_dim_bbq_hnsw body: vector: [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0] +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_rescore_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_hnsw + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml index 229d705bc317c..fb45521cb47c6 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml @@ -611,3 +611,92 @@ setup: - match: { hits.hits.0._id: "1"} - match: { hits.hits.1._id: "2"} - match: { hits.hits.2._id: "3"} +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: int8_rescore_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: int8_hnsw + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: int8_rescore_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int8_rescore_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml index baf568762dd17..c2fe78ddbd532 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml @@ -645,3 +645,92 @@ setup: index: dynamic_dim_hnsw_quantized body: vector: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: int4_rescore_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: int4_hnsw + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: int4_rescore_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int4_rescore_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml index 0bc111576c2a9..8374b636f1dd6 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml @@ -252,3 +252,93 @@ setup: index: dynamic_dim_bbq_flat body: vector: [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0] +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_flat + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_flat + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_rescore_flat + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_flat + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml index 71865de6e0a1c..6dad9ddd26214 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml @@ -408,3 +408,93 @@ setup: - match: { hits.hits.0._score: $rescore_score0 } - match: { hits.hits.1._score: $rescore_score1 } - match: { hits.hits.2._score: $rescore_score2 } +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: int4_rescore_flat + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: int4_flat + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: int4_rescore_flat + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int4_rescore_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int4_rescore_flat + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml index 6b59b8f641ee9..1087b5b264cf8 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml @@ -346,3 +346,93 @@ setup: index: true index_options: type: int8_flat +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: int8_rescore_flat + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: int8_flat + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: int8_rescore_flat + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int8_rescore_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int8_rescore_flat + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index f1f4f9b8faecd..17c4a62a9898f 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -17,6 +17,8 @@ import java.util.Set; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.RESCORE_VECTOR_QUANTIZED_VECTOR_MAPPING; + /** * Spec for mapper-related features. */ @@ -90,7 +92,8 @@ public Set getTestFeatures() { ObjectMapper.SUBOBJECTS_FALSE_MAPPING_UPDATE_FIX, UKNOWN_FIELD_MAPPING_UPDATE_ERROR_MESSAGE, DateFieldMapper.INVALID_DATE_FIX, - NPE_ON_DIMS_UPDATE_FIX + NPE_ON_DIMS_UPDATE_FIX, + RESCORE_VECTOR_QUANTIZED_VECTOR_MAPPING ); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 251998c84b8b7..dd3c5b1373a3d 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -73,6 +73,7 @@ import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.search.vectors.VectorSimilarityQuery; import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParser.Token; @@ -116,6 +117,9 @@ public static boolean isNotUnitVector(float magnitude) { public static final IndexVersion NORMALIZE_COSINE = IndexVersions.NORMALIZED_VECTOR_COSINE; public static final IndexVersion DEFAULT_TO_INT8 = DEFAULT_DENSE_VECTOR_TO_INT8_HNSW; public static final IndexVersion LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION = IndexVersions.V_8_9_0; + public static final IndexVersion ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS = IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS; + + public static final NodeFeature RESCORE_VECTOR_QUANTIZED_VECTOR_MAPPING = new NodeFeature("mapper.dense_vector.rescore_vector"); public static final String CONTENT_TYPE = "dense_vector"; public static short MAX_DIMS_COUNT = 4096; // maximum allowed number of dimensions @@ -213,10 +217,11 @@ public Builder(String name, IndexVersion indexVersionCreated) { ? new Int8HnswIndexOptions( Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + null, null ) : null, - (n, c, o) -> o == null ? null : parseIndexOptions(n, o), + (n, c, o) -> o == null ? null : parseIndexOptions(n, o, indexVersionCreated), m -> toType(m).indexOptions, (b, n, v) -> { if (v != null) { @@ -1228,10 +1233,19 @@ public final int hashCode() { } } + abstract static class QuantizedIndexOptions extends IndexOptions { + final RescoreVector rescoreVector; + + QuantizedIndexOptions(VectorIndexType type, RescoreVector rescoreVector) { + super(type); + this.rescoreVector = rescoreVector; + } + } + public enum VectorIndexType { HNSW("hnsw", false) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); if (mNode == null) { @@ -1258,7 +1272,7 @@ public boolean supportsDimension(int dims) { }, INT8_HNSW("int8_hnsw", true) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); @@ -1274,8 +1288,12 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti if (confidenceIntervalNode != null) { confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } + RescoreVector rescoreVector = null; + if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int8HnswIndexOptions(m, efConstruction, confidenceInterval); + return new Int8HnswIndexOptions(m, efConstruction, confidenceInterval, rescoreVector); } @Override @@ -1289,7 +1307,7 @@ public boolean supportsDimension(int dims) { } }, INT4_HNSW("int4_hnsw", true) { - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); @@ -1305,8 +1323,12 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti if (confidenceIntervalNode != null) { confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } + RescoreVector rescoreVector = null; + if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int4HnswIndexOptions(m, efConstruction, confidenceInterval); + return new Int4HnswIndexOptions(m, efConstruction, confidenceInterval, rescoreVector); } @Override @@ -1321,7 +1343,7 @@ public boolean supportsDimension(int dims) { }, FLAT("flat", false) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); return new FlatIndexOptions(); } @@ -1338,14 +1360,18 @@ public boolean supportsDimension(int dims) { }, INT8_FLAT("int8_flat", true) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); Float confidenceInterval = null; if (confidenceIntervalNode != null) { confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } + RescoreVector rescoreVector = null; + if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int8FlatIndexOptions(confidenceInterval); + return new Int8FlatIndexOptions(confidenceInterval, rescoreVector); } @Override @@ -1360,14 +1386,18 @@ public boolean supportsDimension(int dims) { }, INT4_FLAT("int4_flat", true) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); Float confidenceInterval = null; if (confidenceIntervalNode != null) { confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } + RescoreVector rescoreVector = null; + if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int4FlatIndexOptions(confidenceInterval); + return new Int4FlatIndexOptions(confidenceInterval, rescoreVector); } @Override @@ -1382,7 +1412,7 @@ public boolean supportsDimension(int dims) { }, BBQ_HNSW("bbq_hnsw", true) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); if (mNode == null) { @@ -1393,8 +1423,12 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti } int m = XContentMapValues.nodeIntegerValue(mNode); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode); + RescoreVector rescoreVector = null; + if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new BBQHnswIndexOptions(m, efConstruction); + return new BBQHnswIndexOptions(m, efConstruction, rescoreVector); } @Override @@ -1409,9 +1443,13 @@ public boolean supportsDimension(int dims) { }, BBQ_FLAT("bbq_flat", true) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { + RescoreVector rescoreVector = null; + if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new BBQFlatIndexOptions(); + return new BBQFlatIndexOptions(rescoreVector); } @Override @@ -1437,7 +1475,7 @@ static Optional fromString(String type) { this.quantized = quantized; } - abstract IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap); + abstract IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion); public abstract boolean supportsElementType(ElementType elementType); @@ -1453,11 +1491,11 @@ public String toString() { } } - static class Int8FlatIndexOptions extends IndexOptions { + static class Int8FlatIndexOptions extends QuantizedIndexOptions { private final Float confidenceInterval; - Int8FlatIndexOptions(Float confidenceInterval) { - super(VectorIndexType.INT8_FLAT); + Int8FlatIndexOptions(Float confidenceInterval, RescoreVector rescoreVector) { + super(VectorIndexType.INT8_FLAT, rescoreVector); this.confidenceInterval = confidenceInterval; } @@ -1468,6 +1506,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (confidenceInterval != null) { builder.field("confidence_interval", confidenceInterval); } + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1481,12 +1522,12 @@ KnnVectorsFormat getVectorsFormat(ElementType elementType) { @Override boolean doEquals(IndexOptions o) { Int8FlatIndexOptions that = (Int8FlatIndexOptions) o; - return Objects.equals(confidenceInterval, that.confidenceInterval); + return Objects.equals(confidenceInterval, that.confidenceInterval) && Objects.equals(rescoreVector, that.rescoreVector); } @Override int doHashCode() { - return Objects.hash(confidenceInterval); + return Objects.hash(confidenceInterval, rescoreVector); } @Override @@ -1537,13 +1578,13 @@ public int doHashCode() { } } - static class Int4HnswIndexOptions extends IndexOptions { + static class Int4HnswIndexOptions extends QuantizedIndexOptions { private final int m; private final int efConstruction; private final float confidenceInterval; - Int4HnswIndexOptions(int m, int efConstruction, Float confidenceInterval) { - super(VectorIndexType.INT4_HNSW); + Int4HnswIndexOptions(int m, int efConstruction, Float confidenceInterval, RescoreVector rescoreVector) { + super(VectorIndexType.INT4_HNSW, rescoreVector); this.m = m; this.efConstruction = efConstruction; // The default confidence interval for int4 is dynamic quantiles, this provides the best relevancy and is @@ -1564,6 +1605,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field("m", m); builder.field("ef_construction", efConstruction); builder.field("confidence_interval", confidenceInterval); + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1571,12 +1615,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public boolean doEquals(IndexOptions o) { Int4HnswIndexOptions that = (Int4HnswIndexOptions) o; - return m == that.m && efConstruction == that.efConstruction && Objects.equals(confidenceInterval, that.confidenceInterval); + return m == that.m + && efConstruction == that.efConstruction + && Objects.equals(confidenceInterval, that.confidenceInterval) + && Objects.equals(rescoreVector, that.rescoreVector); } @Override public int doHashCode() { - return Objects.hash(m, efConstruction, confidenceInterval); + return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector); } @Override @@ -1589,6 +1636,8 @@ public String toString() { + efConstruction + ", confidence_interval=" + confidenceInterval + + ", rescore_vector=" + + (rescoreVector == null ? "none" : rescoreVector) + "}"; } @@ -1605,11 +1654,11 @@ boolean updatableTo(IndexOptions update) { } } - static class Int4FlatIndexOptions extends IndexOptions { + static class Int4FlatIndexOptions extends QuantizedIndexOptions { private final float confidenceInterval; - Int4FlatIndexOptions(Float confidenceInterval) { - super(VectorIndexType.INT4_FLAT); + Int4FlatIndexOptions(Float confidenceInterval, RescoreVector rescoreVector) { + super(VectorIndexType.INT4_FLAT, rescoreVector); // The default confidence interval for int4 is dynamic quantiles, this provides the best relevancy and is // effectively required for int4 to behave well across a wide range of data. this.confidenceInterval = confidenceInterval == null ? 0f : confidenceInterval; @@ -1626,6 +1675,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field("type", type); builder.field("confidence_interval", confidenceInterval); + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1635,17 +1687,17 @@ public boolean doEquals(IndexOptions o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Int4FlatIndexOptions that = (Int4FlatIndexOptions) o; - return Objects.equals(confidenceInterval, that.confidenceInterval); + return Objects.equals(confidenceInterval, that.confidenceInterval) && Objects.equals(rescoreVector, that.rescoreVector); } @Override public int doHashCode() { - return Objects.hash(confidenceInterval); + return Objects.hash(confidenceInterval, rescoreVector); } @Override public String toString() { - return "{type=" + type + ", confidence_interval=" + confidenceInterval + "}"; + return "{type=" + type + ", confidence_interval=" + confidenceInterval + ", rescore_vector=" + rescoreVector + "}"; } @Override @@ -1659,13 +1711,13 @@ boolean updatableTo(IndexOptions update) { } - static class Int8HnswIndexOptions extends IndexOptions { + static class Int8HnswIndexOptions extends QuantizedIndexOptions { private final int m; private final int efConstruction; private final Float confidenceInterval; - Int8HnswIndexOptions(int m, int efConstruction, Float confidenceInterval) { - super(VectorIndexType.INT8_HNSW); + Int8HnswIndexOptions(int m, int efConstruction, Float confidenceInterval, RescoreVector rescoreVector) { + super(VectorIndexType.INT8_HNSW, rescoreVector); this.m = m; this.efConstruction = efConstruction; this.confidenceInterval = confidenceInterval; @@ -1686,6 +1738,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (confidenceInterval != null) { builder.field("confidence_interval", confidenceInterval); } + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1695,12 +1750,15 @@ public boolean doEquals(IndexOptions o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Int8HnswIndexOptions that = (Int8HnswIndexOptions) o; - return m == that.m && efConstruction == that.efConstruction && Objects.equals(confidenceInterval, that.confidenceInterval); + return m == that.m + && efConstruction == that.efConstruction + && Objects.equals(confidenceInterval, that.confidenceInterval) + && Objects.equals(rescoreVector, that.rescoreVector); } @Override public int doHashCode() { - return Objects.hash(m, efConstruction, confidenceInterval); + return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector); } @Override @@ -1713,6 +1771,8 @@ public String toString() { + efConstruction + ", confidence_interval=" + confidenceInterval + + ", rescore_vector=" + + (rescoreVector == null ? "none" : rescoreVector) + "}"; } @@ -1794,12 +1854,12 @@ public String toString() { } } - static class BBQHnswIndexOptions extends IndexOptions { + static class BBQHnswIndexOptions extends QuantizedIndexOptions { private final int m; private final int efConstruction; - BBQHnswIndexOptions(int m, int efConstruction) { - super(VectorIndexType.BBQ_HNSW); + BBQHnswIndexOptions(int m, int efConstruction, RescoreVector rescoreVector) { + super(VectorIndexType.BBQ_HNSW, rescoreVector); this.m = m; this.efConstruction = efConstruction; } @@ -1818,12 +1878,12 @@ boolean updatableTo(IndexOptions update) { @Override boolean doEquals(IndexOptions other) { BBQHnswIndexOptions that = (BBQHnswIndexOptions) other; - return m == that.m && efConstruction == that.efConstruction; + return m == that.m && efConstruction == that.efConstruction && Objects.equals(rescoreVector, that.rescoreVector); } @Override int doHashCode() { - return Objects.hash(m, efConstruction); + return Objects.hash(m, efConstruction, rescoreVector); } @Override @@ -1832,6 +1892,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field("type", type); builder.field("m", m); builder.field("ef_construction", efConstruction); + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1845,11 +1908,11 @@ public void validateDimension(int dim) { } } - static class BBQFlatIndexOptions extends IndexOptions { + static class BBQFlatIndexOptions extends QuantizedIndexOptions { private final int CLASS_NAME_HASH = this.getClass().getName().hashCode(); - BBQFlatIndexOptions() { - super(VectorIndexType.BBQ_FLAT); + BBQFlatIndexOptions(RescoreVector rescoreVector) { + super(VectorIndexType.BBQ_FLAT, rescoreVector); } @Override @@ -1877,6 +1940,9 @@ int doHashCode() { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field("type", type); + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1890,6 +1956,41 @@ public void validateDimension(int dim) { } } + record RescoreVector(float oversample) implements ToXContentObject { + static final String NAME = "rescore_vector"; + static final String OVERSAMPLE = "oversample"; + + static RescoreVector fromIndexOptions(Map indexOptionsMap) { + Object rescoreVectorNode = indexOptionsMap.remove(NAME); + if (rescoreVectorNode == null) { + return null; + } + Map mappedNode = XContentMapValues.nodeMapValue(rescoreVectorNode, NAME); + Object oversampleNode = mappedNode.get(OVERSAMPLE); + if (oversampleNode == null) { + throw new IllegalArgumentException("Invalid rescore_vector value. Missing required field " + OVERSAMPLE); + } + return new RescoreVector((float) XContentMapValues.nodeDoubleValue(oversampleNode)); + } + + RescoreVector { + if (oversample < 1) { + throw new IllegalArgumentException("oversample must be greater than 1"); + } + if (oversample > 10) { + throw new IllegalArgumentException("oversample must be less than or equal to 10"); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(NAME); + builder.field(OVERSAMPLE, oversample); + builder.endObject(); + return builder; + } + } + public static final TypeParser PARSER = new TypeParser( (n, c) -> new Builder(n, c.indexVersionCreated()), notInMultiFields(CONTENT_TYPE) @@ -2105,7 +2206,7 @@ private Query createKnnFloatQuery( float[] queryVector, int k, int numCands, - Float oversample, + Float queryOversample, Query filter, Float similarityThreshold, BitSetProducer parentFilter @@ -2127,6 +2228,14 @@ && isNotUnitVector(squaredMagnitude)) { } int adjustedK = k; + // By default utilize the quantized oversample is configured + // allow the user provided at query time overwrite + Float oversample = queryOversample; + if (oversample == null + && indexOptions instanceof QuantizedIndexOptions quantizedIndexOptions + && quantizedIndexOptions.rescoreVector != null) { + oversample = quantizedIndexOptions.rescoreVector.oversample; + } boolean rescore = needsRescore(oversample); if (rescore) { // Will get k * oversample for rescoring, and get the top k @@ -2322,7 +2431,7 @@ public FieldMapper.Builder getMergeBuilder() { return new Builder(leafName(), indexCreatedVersion).init(this); } - private static IndexOptions parseIndexOptions(String fieldName, Object propNode) { + private static IndexOptions parseIndexOptions(String fieldName, Object propNode, IndexVersion indexVersion) { @SuppressWarnings("unchecked") Map indexOptionsMap = (Map) propNode; Object typeNode = indexOptionsMap.remove("type"); @@ -2335,7 +2444,7 @@ private static IndexOptions parseIndexOptions(String fieldName, Object propNode) throw new MapperParsingException("Unknown vector index options type [" + type + "] for field [" + fieldName + "]"); } VectorIndexType parsedType = vectorIndexType.get(); - return parsedType.parseIndexOptions(fieldName, indexOptionsMap); + return parsedType.parseIndexOptions(fieldName, indexOptionsMap, indexVersion); } /** diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 6e13faa99b4b5..898c23c90d305 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -59,6 +59,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Set; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; @@ -883,6 +884,120 @@ protected void assertExistsQuery(MappedFieldType fieldType, Query query, LuceneD @Override public void testAggregatableConsistency() {} + public void testRescoreVectorForNonQuantized() { + for (String indexType : List.of("hnsw", "flat")) { + Exception e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("oversample", 1.5f)) + .endObject() + ) + ) + ); + e.getMessage().contains("Mapping definition for [field] has unsupported parameters:"); + } + } + + public void tesetRescoreVectorOldIndexVersion() { + IndexVersion incompatibleVersion = IndexVersionUtils.randomVersionBetween( + random(), + IndexVersionUtils.getLowestReadCompatibleVersion(), + IndexVersionUtils.getPreviousVersion(DenseVectorFieldMapper.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS) + ); + for (String indexType : List.of("int8_hnsw", "int8_flat", "int4_hnsw", "int4_flat", "bbq_hnsw", "bbq_flat")) { + expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + incompatibleVersion, + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("oversample", 1.5f)) + .endObject() + ) + ) + ); + } + } + + public void testInvalidRescoreVector() { + for (String indexType : List.of("int8_hnsw", "int8_flat", "int4_hnsw", "int4_flat", "bbq_hnsw", "bbq_flat")) { + Exception e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("foo", 1.5f)) + .endObject() + ) + ) + ); + e.getMessage().contains("Invalid rescore_vector value. Missing required field oversample"); + e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("oversample", "foo")) + .endObject() + ) + ) + ); + e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("oversample", 0.1f)) + .endObject() + ) + ) + ); + e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of()) + .endObject() + ) + ) + ); + e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("oversample", 10.1f)) + .endObject() + ) + ) + ); + } + } + public void testDims() { { Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 5c067cb2d0a27..e98038b7a0759 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -49,6 +49,10 @@ public DenseVectorFieldTypeTests() { this.indexed = randomBoolean(); } + private static DenseVectorFieldMapper.RescoreVector randomRescoreVector() { + return new DenseVectorFieldMapper.RescoreVector(randomFloatBetween(1.0F, 10.0F, false)); + } + private DenseVectorFieldMapper.IndexOptions randomIndexOptionsNonQuantized() { return randomFrom( new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), @@ -62,18 +66,30 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsAll() { new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.Int4HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.FlatIndexOptions(), - new DenseVectorFieldMapper.Int8FlatIndexOptions(randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true))), - new DenseVectorFieldMapper.Int4FlatIndexOptions(randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true))), - new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), - new DenseVectorFieldMapper.BBQFlatIndexOptions() + new DenseVectorFieldMapper.Int8FlatIndexOptions( + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ), + new DenseVectorFieldMapper.Int4FlatIndexOptions( + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ), + new DenseVectorFieldMapper.BBQHnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ), + new DenseVectorFieldMapper.BBQFlatIndexOptions(randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector())) ); } @@ -82,14 +98,20 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsHnswQuantized() { new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.Int4HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), - new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)) + new DenseVectorFieldMapper.BBQHnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ) ); } @@ -195,6 +217,9 @@ public void testCreateNestedKnnQuery() { queryVector[i] = randomFloat(); } Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, producer); + if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) { + query = rescoreKnnVectorQuery.innerQuery(); + } assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class)); } { @@ -346,6 +371,9 @@ public void testCreateKnnQueryMaxDims() { queryVector[i] = randomFloat(); } Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null); + if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) { + query = rescoreKnnVectorQuery.innerQuery(); + } assertThat(query, instanceOf(KnnFloatVectorQuery.class)); } From 17084ddd8fb3ba9cb9690b776c7ec76c34d9a995 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 2 May 2025 13:51:20 -0400 Subject: [PATCH 2/3] Adds new BWC version for 8.19 backport of (#124581) (#127647) --- .../org/elasticsearch/index/IndexVersions.java | 1 + .../mapper/vectors/DenseVectorFieldMapper.java | 16 +++++++++------- .../vectors/DenseVectorFieldMapperTests.java | 4 ++-- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/IndexVersions.java b/server/src/main/java/org/elasticsearch/index/IndexVersions.java index 455255dc15c0f..0080572ec948c 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexVersions.java +++ b/server/src/main/java/org/elasticsearch/index/IndexVersions.java @@ -128,6 +128,7 @@ private static IndexVersion def(int id, Version luceneVersion) { public static final IndexVersion LOGSB_OPTIONAL_SORTING_ON_HOST_NAME_BACKPORT = def(8_525_0_00, Version.LUCENE_9_12_1); public static final IndexVersion USE_SYNTHETIC_SOURCE_FOR_RECOVERY_BY_DEFAULT_BACKPORT = def(8_526_0_00, Version.LUCENE_9_12_1); public static final IndexVersion SYNTHETIC_SOURCE_STORE_ARRAYS_NATIVELY = def(8_527_0_00, Version.LUCENE_9_12_1); + public static final IndexVersion ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS = def(8_528_0_00, Version.LUCENE_9_12_1); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index dd3c5b1373a3d..e6cd3d15d01b3 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -111,13 +111,15 @@ public static boolean isNotUnitVector(float magnitude) { public static final NodeFeature INT4_QUANTIZATION = new NodeFeature("mapper.vectors.int4_quantization", true); public static final NodeFeature BIT_VECTORS = new NodeFeature("mapper.vectors.bit_vectors", true); public static final NodeFeature BBQ_FORMAT = new NodeFeature("mapper.vectors.bbq", true); + private static boolean hasRescoreIndexVersion(IndexVersion version) { + return version.onOrAfter(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS); + } public static final IndexVersion MAGNITUDE_STORED_INDEX_VERSION = IndexVersions.V_7_5_0; public static final IndexVersion INDEXED_BY_DEFAULT_INDEX_VERSION = IndexVersions.FIRST_DETACHED_INDEX_VERSION; public static final IndexVersion NORMALIZE_COSINE = IndexVersions.NORMALIZED_VECTOR_COSINE; public static final IndexVersion DEFAULT_TO_INT8 = DEFAULT_DENSE_VECTOR_TO_INT8_HNSW; public static final IndexVersion LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION = IndexVersions.V_8_9_0; - public static final IndexVersion ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS = IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS; public static final NodeFeature RESCORE_VECTOR_QUANTIZED_VECTOR_MAPPING = new NodeFeature("mapper.dense_vector.rescore_vector"); @@ -1289,7 +1291,7 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } RescoreVector rescoreVector = null; - if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + if (hasRescoreIndexVersion(indexVersion)) { rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); @@ -1324,7 +1326,7 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } RescoreVector rescoreVector = null; - if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + if (hasRescoreIndexVersion(indexVersion)) { rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); @@ -1367,7 +1369,7 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } RescoreVector rescoreVector = null; - if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + if (hasRescoreIndexVersion(indexVersion)) { rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); @@ -1393,7 +1395,7 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } RescoreVector rescoreVector = null; - if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + if (hasRescoreIndexVersion(indexVersion)) { rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); @@ -1424,7 +1426,7 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti int m = XContentMapValues.nodeIntegerValue(mNode); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode); RescoreVector rescoreVector = null; - if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + if (hasRescoreIndexVersion(indexVersion)) { rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); @@ -1445,7 +1447,7 @@ public boolean supportsDimension(int dims) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { RescoreVector rescoreVector = null; - if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + if (hasRescoreIndexVersion(indexVersion)) { rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 898c23c90d305..27c64de80b43f 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -906,8 +906,8 @@ public void testRescoreVectorForNonQuantized() { public void tesetRescoreVectorOldIndexVersion() { IndexVersion incompatibleVersion = IndexVersionUtils.randomVersionBetween( random(), - IndexVersionUtils.getLowestReadCompatibleVersion(), - IndexVersionUtils.getPreviousVersion(DenseVectorFieldMapper.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS) + IndexVersions.V_8_0_0, + IndexVersionUtils.getPreviousVersion(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS) ); for (String indexType : List.of("int8_hnsw", "int8_flat", "int4_hnsw", "int4_flat", "bbq_hnsw", "bbq_flat")) { expectThrows( From 8fb9131fdae1c7b36cb02dcc324374227e072e87 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 2 May 2025 18:07:23 +0000 Subject: [PATCH 3/3] [CI] Auto commit changes from spotless --- .../index/mapper/vectors/DenseVectorFieldMapper.java | 1 + 1 file changed, 1 insertion(+) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index e6cd3d15d01b3..2c4c843e429cb 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -111,6 +111,7 @@ public static boolean isNotUnitVector(float magnitude) { public static final NodeFeature INT4_QUANTIZATION = new NodeFeature("mapper.vectors.int4_quantization", true); public static final NodeFeature BIT_VECTORS = new NodeFeature("mapper.vectors.bit_vectors", true); public static final NodeFeature BBQ_FORMAT = new NodeFeature("mapper.vectors.bbq", true); + private static boolean hasRescoreIndexVersion(IndexVersion version) { return version.onOrAfter(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS); }