diff --git a/docs/changelog/124581.yaml b/docs/changelog/124581.yaml new file mode 100644 index 0000000000000..cc978981b1efc --- /dev/null +++ b/docs/changelog/124581.yaml @@ -0,0 +1,5 @@ +pr: 124581 +summary: New `vector_rescore` parameter as a quantized index type option +area: Vector Search +type: enhancement +issues: [] diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index 4c9d1ef881c6d..9747644a5ba6c 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -250,3 +250,93 @@ setup: index: dynamic_dim_bbq_hnsw body: vector: [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0] +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_rescore_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_hnsw + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml index 229d705bc317c..fb45521cb47c6 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml @@ -611,3 +611,92 @@ setup: - match: { hits.hits.0._id: "1"} - match: { hits.hits.1._id: "2"} - match: { hits.hits.2._id: "3"} +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: int8_rescore_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: int8_hnsw + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: int8_rescore_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int8_rescore_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml index baf568762dd17..c2fe78ddbd532 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml @@ -645,3 +645,92 @@ setup: index: dynamic_dim_hnsw_quantized body: vector: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: int4_rescore_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: int4_hnsw + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: int4_rescore_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int4_rescore_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml index 0bc111576c2a9..8374b636f1dd6 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml @@ -252,3 +252,93 @@ setup: index: dynamic_dim_bbq_flat body: vector: [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0] +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_flat + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_flat + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_rescore_flat + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_flat + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml index 71865de6e0a1c..6dad9ddd26214 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml @@ -408,3 +408,93 @@ setup: - match: { hits.hits.0._score: $rescore_score0 } - match: { hits.hits.1._score: $rescore_score1 } - match: { hits.hits.2._score: $rescore_score2 } +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: int4_rescore_flat + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: int4_flat + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: int4_rescore_flat + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int4_rescore_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int4_rescore_flat + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml index 6b59b8f641ee9..1087b5b264cf8 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml @@ -346,3 +346,93 @@ setup: index: true index_options: type: int8_flat +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: int8_rescore_flat + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: int8_flat + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: int8_rescore_flat + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int8_rescore_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int8_rescore_flat + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/server/src/main/java/org/elasticsearch/index/IndexVersions.java b/server/src/main/java/org/elasticsearch/index/IndexVersions.java index 455255dc15c0f..0080572ec948c 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexVersions.java +++ b/server/src/main/java/org/elasticsearch/index/IndexVersions.java @@ -128,6 +128,7 @@ private static IndexVersion def(int id, Version luceneVersion) { public static final IndexVersion LOGSB_OPTIONAL_SORTING_ON_HOST_NAME_BACKPORT = def(8_525_0_00, Version.LUCENE_9_12_1); public static final IndexVersion USE_SYNTHETIC_SOURCE_FOR_RECOVERY_BY_DEFAULT_BACKPORT = def(8_526_0_00, Version.LUCENE_9_12_1); public static final IndexVersion SYNTHETIC_SOURCE_STORE_ARRAYS_NATIVELY = def(8_527_0_00, Version.LUCENE_9_12_1); + public static final IndexVersion ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS = def(8_528_0_00, Version.LUCENE_9_12_1); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index f1f4f9b8faecd..17c4a62a9898f 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -17,6 +17,8 @@ import java.util.Set; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.RESCORE_VECTOR_QUANTIZED_VECTOR_MAPPING; + /** * Spec for mapper-related features. */ @@ -90,7 +92,8 @@ public Set getTestFeatures() { ObjectMapper.SUBOBJECTS_FALSE_MAPPING_UPDATE_FIX, UKNOWN_FIELD_MAPPING_UPDATE_ERROR_MESSAGE, DateFieldMapper.INVALID_DATE_FIX, - NPE_ON_DIMS_UPDATE_FIX + NPE_ON_DIMS_UPDATE_FIX, + RESCORE_VECTOR_QUANTIZED_VECTOR_MAPPING ); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 251998c84b8b7..2c4c843e429cb 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -73,6 +73,7 @@ import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.search.vectors.VectorSimilarityQuery; import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParser.Token; @@ -111,12 +112,18 @@ public static boolean isNotUnitVector(float magnitude) { public static final NodeFeature BIT_VECTORS = new NodeFeature("mapper.vectors.bit_vectors", true); public static final NodeFeature BBQ_FORMAT = new NodeFeature("mapper.vectors.bbq", true); + private static boolean hasRescoreIndexVersion(IndexVersion version) { + return version.onOrAfter(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS); + } + public static final IndexVersion MAGNITUDE_STORED_INDEX_VERSION = IndexVersions.V_7_5_0; public static final IndexVersion INDEXED_BY_DEFAULT_INDEX_VERSION = IndexVersions.FIRST_DETACHED_INDEX_VERSION; public static final IndexVersion NORMALIZE_COSINE = IndexVersions.NORMALIZED_VECTOR_COSINE; public static final IndexVersion DEFAULT_TO_INT8 = DEFAULT_DENSE_VECTOR_TO_INT8_HNSW; public static final IndexVersion LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION = IndexVersions.V_8_9_0; + public static final NodeFeature RESCORE_VECTOR_QUANTIZED_VECTOR_MAPPING = new NodeFeature("mapper.dense_vector.rescore_vector"); + public static final String CONTENT_TYPE = "dense_vector"; public static short MAX_DIMS_COUNT = 4096; // maximum allowed number of dimensions public static int MAX_DIMS_COUNT_BIT = 4096 * Byte.SIZE; // maximum allowed number of dimensions @@ -213,10 +220,11 @@ public Builder(String name, IndexVersion indexVersionCreated) { ? new Int8HnswIndexOptions( Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + null, null ) : null, - (n, c, o) -> o == null ? null : parseIndexOptions(n, o), + (n, c, o) -> o == null ? null : parseIndexOptions(n, o, indexVersionCreated), m -> toType(m).indexOptions, (b, n, v) -> { if (v != null) { @@ -1228,10 +1236,19 @@ public final int hashCode() { } } + abstract static class QuantizedIndexOptions extends IndexOptions { + final RescoreVector rescoreVector; + + QuantizedIndexOptions(VectorIndexType type, RescoreVector rescoreVector) { + super(type); + this.rescoreVector = rescoreVector; + } + } + public enum VectorIndexType { HNSW("hnsw", false) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); if (mNode == null) { @@ -1258,7 +1275,7 @@ public boolean supportsDimension(int dims) { }, INT8_HNSW("int8_hnsw", true) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); @@ -1274,8 +1291,12 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti if (confidenceIntervalNode != null) { confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } + RescoreVector rescoreVector = null; + if (hasRescoreIndexVersion(indexVersion)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int8HnswIndexOptions(m, efConstruction, confidenceInterval); + return new Int8HnswIndexOptions(m, efConstruction, confidenceInterval, rescoreVector); } @Override @@ -1289,7 +1310,7 @@ public boolean supportsDimension(int dims) { } }, INT4_HNSW("int4_hnsw", true) { - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); @@ -1305,8 +1326,12 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti if (confidenceIntervalNode != null) { confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } + RescoreVector rescoreVector = null; + if (hasRescoreIndexVersion(indexVersion)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int4HnswIndexOptions(m, efConstruction, confidenceInterval); + return new Int4HnswIndexOptions(m, efConstruction, confidenceInterval, rescoreVector); } @Override @@ -1321,7 +1346,7 @@ public boolean supportsDimension(int dims) { }, FLAT("flat", false) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); return new FlatIndexOptions(); } @@ -1338,14 +1363,18 @@ public boolean supportsDimension(int dims) { }, INT8_FLAT("int8_flat", true) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); Float confidenceInterval = null; if (confidenceIntervalNode != null) { confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } + RescoreVector rescoreVector = null; + if (hasRescoreIndexVersion(indexVersion)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int8FlatIndexOptions(confidenceInterval); + return new Int8FlatIndexOptions(confidenceInterval, rescoreVector); } @Override @@ -1360,14 +1389,18 @@ public boolean supportsDimension(int dims) { }, INT4_FLAT("int4_flat", true) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); Float confidenceInterval = null; if (confidenceIntervalNode != null) { confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } + RescoreVector rescoreVector = null; + if (hasRescoreIndexVersion(indexVersion)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int4FlatIndexOptions(confidenceInterval); + return new Int4FlatIndexOptions(confidenceInterval, rescoreVector); } @Override @@ -1382,7 +1415,7 @@ public boolean supportsDimension(int dims) { }, BBQ_HNSW("bbq_hnsw", true) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); if (mNode == null) { @@ -1393,8 +1426,12 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti } int m = XContentMapValues.nodeIntegerValue(mNode); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode); + RescoreVector rescoreVector = null; + if (hasRescoreIndexVersion(indexVersion)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new BBQHnswIndexOptions(m, efConstruction); + return new BBQHnswIndexOptions(m, efConstruction, rescoreVector); } @Override @@ -1409,9 +1446,13 @@ public boolean supportsDimension(int dims) { }, BBQ_FLAT("bbq_flat", true) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { + RescoreVector rescoreVector = null; + if (hasRescoreIndexVersion(indexVersion)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new BBQFlatIndexOptions(); + return new BBQFlatIndexOptions(rescoreVector); } @Override @@ -1437,7 +1478,7 @@ static Optional fromString(String type) { this.quantized = quantized; } - abstract IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap); + abstract IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion); public abstract boolean supportsElementType(ElementType elementType); @@ -1453,11 +1494,11 @@ public String toString() { } } - static class Int8FlatIndexOptions extends IndexOptions { + static class Int8FlatIndexOptions extends QuantizedIndexOptions { private final Float confidenceInterval; - Int8FlatIndexOptions(Float confidenceInterval) { - super(VectorIndexType.INT8_FLAT); + Int8FlatIndexOptions(Float confidenceInterval, RescoreVector rescoreVector) { + super(VectorIndexType.INT8_FLAT, rescoreVector); this.confidenceInterval = confidenceInterval; } @@ -1468,6 +1509,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (confidenceInterval != null) { builder.field("confidence_interval", confidenceInterval); } + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1481,12 +1525,12 @@ KnnVectorsFormat getVectorsFormat(ElementType elementType) { @Override boolean doEquals(IndexOptions o) { Int8FlatIndexOptions that = (Int8FlatIndexOptions) o; - return Objects.equals(confidenceInterval, that.confidenceInterval); + return Objects.equals(confidenceInterval, that.confidenceInterval) && Objects.equals(rescoreVector, that.rescoreVector); } @Override int doHashCode() { - return Objects.hash(confidenceInterval); + return Objects.hash(confidenceInterval, rescoreVector); } @Override @@ -1537,13 +1581,13 @@ public int doHashCode() { } } - static class Int4HnswIndexOptions extends IndexOptions { + static class Int4HnswIndexOptions extends QuantizedIndexOptions { private final int m; private final int efConstruction; private final float confidenceInterval; - Int4HnswIndexOptions(int m, int efConstruction, Float confidenceInterval) { - super(VectorIndexType.INT4_HNSW); + Int4HnswIndexOptions(int m, int efConstruction, Float confidenceInterval, RescoreVector rescoreVector) { + super(VectorIndexType.INT4_HNSW, rescoreVector); this.m = m; this.efConstruction = efConstruction; // The default confidence interval for int4 is dynamic quantiles, this provides the best relevancy and is @@ -1564,6 +1608,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field("m", m); builder.field("ef_construction", efConstruction); builder.field("confidence_interval", confidenceInterval); + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1571,12 +1618,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public boolean doEquals(IndexOptions o) { Int4HnswIndexOptions that = (Int4HnswIndexOptions) o; - return m == that.m && efConstruction == that.efConstruction && Objects.equals(confidenceInterval, that.confidenceInterval); + return m == that.m + && efConstruction == that.efConstruction + && Objects.equals(confidenceInterval, that.confidenceInterval) + && Objects.equals(rescoreVector, that.rescoreVector); } @Override public int doHashCode() { - return Objects.hash(m, efConstruction, confidenceInterval); + return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector); } @Override @@ -1589,6 +1639,8 @@ public String toString() { + efConstruction + ", confidence_interval=" + confidenceInterval + + ", rescore_vector=" + + (rescoreVector == null ? "none" : rescoreVector) + "}"; } @@ -1605,11 +1657,11 @@ boolean updatableTo(IndexOptions update) { } } - static class Int4FlatIndexOptions extends IndexOptions { + static class Int4FlatIndexOptions extends QuantizedIndexOptions { private final float confidenceInterval; - Int4FlatIndexOptions(Float confidenceInterval) { - super(VectorIndexType.INT4_FLAT); + Int4FlatIndexOptions(Float confidenceInterval, RescoreVector rescoreVector) { + super(VectorIndexType.INT4_FLAT, rescoreVector); // The default confidence interval for int4 is dynamic quantiles, this provides the best relevancy and is // effectively required for int4 to behave well across a wide range of data. this.confidenceInterval = confidenceInterval == null ? 0f : confidenceInterval; @@ -1626,6 +1678,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field("type", type); builder.field("confidence_interval", confidenceInterval); + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1635,17 +1690,17 @@ public boolean doEquals(IndexOptions o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Int4FlatIndexOptions that = (Int4FlatIndexOptions) o; - return Objects.equals(confidenceInterval, that.confidenceInterval); + return Objects.equals(confidenceInterval, that.confidenceInterval) && Objects.equals(rescoreVector, that.rescoreVector); } @Override public int doHashCode() { - return Objects.hash(confidenceInterval); + return Objects.hash(confidenceInterval, rescoreVector); } @Override public String toString() { - return "{type=" + type + ", confidence_interval=" + confidenceInterval + "}"; + return "{type=" + type + ", confidence_interval=" + confidenceInterval + ", rescore_vector=" + rescoreVector + "}"; } @Override @@ -1659,13 +1714,13 @@ boolean updatableTo(IndexOptions update) { } - static class Int8HnswIndexOptions extends IndexOptions { + static class Int8HnswIndexOptions extends QuantizedIndexOptions { private final int m; private final int efConstruction; private final Float confidenceInterval; - Int8HnswIndexOptions(int m, int efConstruction, Float confidenceInterval) { - super(VectorIndexType.INT8_HNSW); + Int8HnswIndexOptions(int m, int efConstruction, Float confidenceInterval, RescoreVector rescoreVector) { + super(VectorIndexType.INT8_HNSW, rescoreVector); this.m = m; this.efConstruction = efConstruction; this.confidenceInterval = confidenceInterval; @@ -1686,6 +1741,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (confidenceInterval != null) { builder.field("confidence_interval", confidenceInterval); } + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1695,12 +1753,15 @@ public boolean doEquals(IndexOptions o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Int8HnswIndexOptions that = (Int8HnswIndexOptions) o; - return m == that.m && efConstruction == that.efConstruction && Objects.equals(confidenceInterval, that.confidenceInterval); + return m == that.m + && efConstruction == that.efConstruction + && Objects.equals(confidenceInterval, that.confidenceInterval) + && Objects.equals(rescoreVector, that.rescoreVector); } @Override public int doHashCode() { - return Objects.hash(m, efConstruction, confidenceInterval); + return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector); } @Override @@ -1713,6 +1774,8 @@ public String toString() { + efConstruction + ", confidence_interval=" + confidenceInterval + + ", rescore_vector=" + + (rescoreVector == null ? "none" : rescoreVector) + "}"; } @@ -1794,12 +1857,12 @@ public String toString() { } } - static class BBQHnswIndexOptions extends IndexOptions { + static class BBQHnswIndexOptions extends QuantizedIndexOptions { private final int m; private final int efConstruction; - BBQHnswIndexOptions(int m, int efConstruction) { - super(VectorIndexType.BBQ_HNSW); + BBQHnswIndexOptions(int m, int efConstruction, RescoreVector rescoreVector) { + super(VectorIndexType.BBQ_HNSW, rescoreVector); this.m = m; this.efConstruction = efConstruction; } @@ -1818,12 +1881,12 @@ boolean updatableTo(IndexOptions update) { @Override boolean doEquals(IndexOptions other) { BBQHnswIndexOptions that = (BBQHnswIndexOptions) other; - return m == that.m && efConstruction == that.efConstruction; + return m == that.m && efConstruction == that.efConstruction && Objects.equals(rescoreVector, that.rescoreVector); } @Override int doHashCode() { - return Objects.hash(m, efConstruction); + return Objects.hash(m, efConstruction, rescoreVector); } @Override @@ -1832,6 +1895,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field("type", type); builder.field("m", m); builder.field("ef_construction", efConstruction); + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1845,11 +1911,11 @@ public void validateDimension(int dim) { } } - static class BBQFlatIndexOptions extends IndexOptions { + static class BBQFlatIndexOptions extends QuantizedIndexOptions { private final int CLASS_NAME_HASH = this.getClass().getName().hashCode(); - BBQFlatIndexOptions() { - super(VectorIndexType.BBQ_FLAT); + BBQFlatIndexOptions(RescoreVector rescoreVector) { + super(VectorIndexType.BBQ_FLAT, rescoreVector); } @Override @@ -1877,6 +1943,9 @@ int doHashCode() { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field("type", type); + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1890,6 +1959,41 @@ public void validateDimension(int dim) { } } + record RescoreVector(float oversample) implements ToXContentObject { + static final String NAME = "rescore_vector"; + static final String OVERSAMPLE = "oversample"; + + static RescoreVector fromIndexOptions(Map indexOptionsMap) { + Object rescoreVectorNode = indexOptionsMap.remove(NAME); + if (rescoreVectorNode == null) { + return null; + } + Map mappedNode = XContentMapValues.nodeMapValue(rescoreVectorNode, NAME); + Object oversampleNode = mappedNode.get(OVERSAMPLE); + if (oversampleNode == null) { + throw new IllegalArgumentException("Invalid rescore_vector value. Missing required field " + OVERSAMPLE); + } + return new RescoreVector((float) XContentMapValues.nodeDoubleValue(oversampleNode)); + } + + RescoreVector { + if (oversample < 1) { + throw new IllegalArgumentException("oversample must be greater than 1"); + } + if (oversample > 10) { + throw new IllegalArgumentException("oversample must be less than or equal to 10"); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(NAME); + builder.field(OVERSAMPLE, oversample); + builder.endObject(); + return builder; + } + } + public static final TypeParser PARSER = new TypeParser( (n, c) -> new Builder(n, c.indexVersionCreated()), notInMultiFields(CONTENT_TYPE) @@ -2105,7 +2209,7 @@ private Query createKnnFloatQuery( float[] queryVector, int k, int numCands, - Float oversample, + Float queryOversample, Query filter, Float similarityThreshold, BitSetProducer parentFilter @@ -2127,6 +2231,14 @@ && isNotUnitVector(squaredMagnitude)) { } int adjustedK = k; + // By default utilize the quantized oversample is configured + // allow the user provided at query time overwrite + Float oversample = queryOversample; + if (oversample == null + && indexOptions instanceof QuantizedIndexOptions quantizedIndexOptions + && quantizedIndexOptions.rescoreVector != null) { + oversample = quantizedIndexOptions.rescoreVector.oversample; + } boolean rescore = needsRescore(oversample); if (rescore) { // Will get k * oversample for rescoring, and get the top k @@ -2322,7 +2434,7 @@ public FieldMapper.Builder getMergeBuilder() { return new Builder(leafName(), indexCreatedVersion).init(this); } - private static IndexOptions parseIndexOptions(String fieldName, Object propNode) { + private static IndexOptions parseIndexOptions(String fieldName, Object propNode, IndexVersion indexVersion) { @SuppressWarnings("unchecked") Map indexOptionsMap = (Map) propNode; Object typeNode = indexOptionsMap.remove("type"); @@ -2335,7 +2447,7 @@ private static IndexOptions parseIndexOptions(String fieldName, Object propNode) throw new MapperParsingException("Unknown vector index options type [" + type + "] for field [" + fieldName + "]"); } VectorIndexType parsedType = vectorIndexType.get(); - return parsedType.parseIndexOptions(fieldName, indexOptionsMap); + return parsedType.parseIndexOptions(fieldName, indexOptionsMap, indexVersion); } /** diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 6e13faa99b4b5..27c64de80b43f 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -59,6 +59,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Set; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; @@ -883,6 +884,120 @@ protected void assertExistsQuery(MappedFieldType fieldType, Query query, LuceneD @Override public void testAggregatableConsistency() {} + public void testRescoreVectorForNonQuantized() { + for (String indexType : List.of("hnsw", "flat")) { + Exception e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("oversample", 1.5f)) + .endObject() + ) + ) + ); + e.getMessage().contains("Mapping definition for [field] has unsupported parameters:"); + } + } + + public void tesetRescoreVectorOldIndexVersion() { + IndexVersion incompatibleVersion = IndexVersionUtils.randomVersionBetween( + random(), + IndexVersions.V_8_0_0, + IndexVersionUtils.getPreviousVersion(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS) + ); + for (String indexType : List.of("int8_hnsw", "int8_flat", "int4_hnsw", "int4_flat", "bbq_hnsw", "bbq_flat")) { + expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + incompatibleVersion, + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("oversample", 1.5f)) + .endObject() + ) + ) + ); + } + } + + public void testInvalidRescoreVector() { + for (String indexType : List.of("int8_hnsw", "int8_flat", "int4_hnsw", "int4_flat", "bbq_hnsw", "bbq_flat")) { + Exception e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("foo", 1.5f)) + .endObject() + ) + ) + ); + e.getMessage().contains("Invalid rescore_vector value. Missing required field oversample"); + e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("oversample", "foo")) + .endObject() + ) + ) + ); + e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("oversample", 0.1f)) + .endObject() + ) + ) + ); + e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of()) + .endObject() + ) + ) + ); + e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("oversample", 10.1f)) + .endObject() + ) + ) + ); + } + } + public void testDims() { { Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 5c067cb2d0a27..e98038b7a0759 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -49,6 +49,10 @@ public DenseVectorFieldTypeTests() { this.indexed = randomBoolean(); } + private static DenseVectorFieldMapper.RescoreVector randomRescoreVector() { + return new DenseVectorFieldMapper.RescoreVector(randomFloatBetween(1.0F, 10.0F, false)); + } + private DenseVectorFieldMapper.IndexOptions randomIndexOptionsNonQuantized() { return randomFrom( new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), @@ -62,18 +66,30 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsAll() { new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.Int4HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.FlatIndexOptions(), - new DenseVectorFieldMapper.Int8FlatIndexOptions(randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true))), - new DenseVectorFieldMapper.Int4FlatIndexOptions(randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true))), - new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), - new DenseVectorFieldMapper.BBQFlatIndexOptions() + new DenseVectorFieldMapper.Int8FlatIndexOptions( + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ), + new DenseVectorFieldMapper.Int4FlatIndexOptions( + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ), + new DenseVectorFieldMapper.BBQHnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ), + new DenseVectorFieldMapper.BBQFlatIndexOptions(randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector())) ); } @@ -82,14 +98,20 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsHnswQuantized() { new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.Int4HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), - new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)) + new DenseVectorFieldMapper.BBQHnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ) ); } @@ -195,6 +217,9 @@ public void testCreateNestedKnnQuery() { queryVector[i] = randomFloat(); } Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, producer); + if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) { + query = rescoreKnnVectorQuery.innerQuery(); + } assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class)); } { @@ -346,6 +371,9 @@ public void testCreateKnnQueryMaxDims() { queryVector[i] = randomFloat(); } Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null); + if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) { + query = rescoreKnnVectorQuery.innerQuery(); + } assertThat(query, instanceOf(KnnFloatVectorQuery.class)); }