From d131713ca923aa286f230de7a54340cdcc0176cb Mon Sep 17 00:00:00 2001 From: Danyil Date: Wed, 11 Oct 2023 01:09:15 -0400 Subject: [PATCH 1/2] added another search function that takes a custom expansion factor ef for the search --- c/lib.cpp | 24 ++++++++++++++++++------ c/usearch.h | 4 ++++ include/usearch/index_punned_dense.hpp | 14 +++++++------- 3 files changed, 29 insertions(+), 13 deletions(-) diff --git a/c/lib.cpp b/c/lib.cpp index b22c8f1e..0b33b715 100644 --- a/c/lib.cpp +++ b/c/lib.cpp @@ -90,13 +90,13 @@ bool get_(index_t* index, label_t label, void* vector, scalar_kind_t kind) { } #endif -search_result_t search_(index_t* index, void const* vector, scalar_kind_t kind, size_t n) { +search_result_t search_(index_t* index, void const* vector, scalar_kind_t kind, size_t n, size_t ef = 0) { switch (kind) { - case scalar_kind_t::f32_k: return index->search((f32_t const*)vector, n); - case scalar_kind_t::f64_k: return index->search((f64_t const*)vector, n); - case scalar_kind_t::f16_k: return index->search((f16_t const*)vector, n); - case scalar_kind_t::f8_k: return index->search((f8_bits_t const*)vector, n); - case scalar_kind_t::b1x8_k: return index->search((b1x8_t const*)vector, n); + case scalar_kind_t::f32_k: return index->search((f32_t const*)vector, n, ef); + case scalar_kind_t::f64_k: return index->search((f64_t const*)vector, n, ef); + case scalar_kind_t::f16_k: return index->search((f16_t const*)vector, n, ef); + case scalar_kind_t::f8_k: return index->search((f8_bits_t const*)vector, n, ef); + case scalar_kind_t::b1x8_k: return index->search((b1x8_t const*)vector, n, ef); default: return index->empty_search_result().failed("Unknown scalar kind!"); } } @@ -281,6 +281,18 @@ USEARCH_EXPORT size_t usearch_search( return result.dump_to(found_labels, found_distances); } +USEARCH_EXPORT size_t usearch_search_custom_ef( // + usearch_index_t index, void const* vector, usearch_scalar_kind_t kind, size_t results_limit, size_t ef, // + usearch_label_t* 
found_labels, usearch_distance_t* found_distances, usearch_error_t* error) { + search_result_t result = search_(reinterpret_cast(index), vector, to_native_scalar(kind), results_limit, ef); + if (!result) { + *error = result.error.what(); + return 0; + } + + return result.dump_to(found_labels, found_distances); +} + #if USEARCH_LOOKUP_LABEL USEARCH_EXPORT bool usearch_get( // usearch_index_t index, usearch_label_t label, // diff --git a/c/usearch.h b/c/usearch.h index c86a6c59..db857125 100644 --- a/c/usearch.h +++ b/c/usearch.h @@ -107,6 +107,10 @@ USEARCH_EXPORT size_t usearch_search( usearch_index_t, void const* query_vector, usearch_scalar_kind_t query_kind, size_t results_limit, // usearch_label_t* found_labels, usearch_distance_t* found_distances, usearch_error_t*); +USEARCH_EXPORT size_t usearch_search_custom_ef( // + usearch_index_t, void const* query_vector, usearch_scalar_kind_t query_kind, size_t results_limit, size_t ef, // + usearch_label_t* found_labels, usearch_distance_t* found_distances, usearch_error_t*); + USEARCH_EXPORT bool usearch_get( // usearch_index_t, usearch_label_t, // void* vector, usearch_scalar_kind_t vector_kind, usearch_error_t*); diff --git a/include/usearch/index_punned_dense.hpp b/include/usearch/index_punned_dense.hpp index 97ca40ab..632bd625 100644 --- a/include/usearch/index_punned_dense.hpp +++ b/include/usearch/index_punned_dense.hpp @@ -361,11 +361,11 @@ class index_punned_dense_gt { add_result_t add(label_t label, f32_t const* vector, add_config_t config) { return add_(label, vector, config, casts_.from_f32); } add_result_t add(label_t label, f64_t const* vector, add_config_t config) { return add_(label, vector, config, casts_.from_f64); } - search_result_t search(b1x8_t const* vector, std::size_t wanted) const { return search_(vector, wanted, casts_.from_b1x8); } - search_result_t search(f8_bits_t const* vector, std::size_t wanted) const { return search_(vector, wanted, casts_.from_f8); } - search_result_t search(f16_t 
const* vector, std::size_t wanted) const { return search_(vector, wanted, casts_.from_f16); } - search_result_t search(f32_t const* vector, std::size_t wanted) const { return search_(vector, wanted, casts_.from_f32); } - search_result_t search(f64_t const* vector, std::size_t wanted) const { return search_(vector, wanted, casts_.from_f64); } + search_result_t search(b1x8_t const* vector, std::size_t wanted, std::size_t ef) const { return search_(vector, wanted, ef, casts_.from_b1x8); } + search_result_t search(f8_bits_t const* vector, std::size_t wanted, std::size_t ef) const { return search_(vector, wanted, ef, casts_.from_f8); } + search_result_t search(f16_t const* vector, std::size_t wanted, std::size_t ef) const { return search_(vector, wanted, ef, casts_.from_f16); } + search_result_t search(f32_t const* vector, std::size_t wanted, std::size_t ef) const { return search_(vector, wanted, ef, casts_.from_f32); } + search_result_t search(f64_t const* vector, std::size_t wanted, std::size_t ef) const { return search_(vector, wanted, ef, casts_.from_f64); } search_result_t search(b1x8_t const* vector, std::size_t wanted, search_config_t config) const { return search_(vector, wanted, config, casts_.from_b1x8); } search_result_t search(f8_bits_t const* vector, std::size_t wanted, search_config_t config) const { return search_(vector, wanted, config, casts_.from_f8); } @@ -915,11 +915,11 @@ class index_punned_dense_gt { template search_result_t search_( // - scalar_at const* vector, std::size_t wanted, // + scalar_at const* vector, std::size_t wanted, std::size_t ef, // cast_t const& cast) const { thread_lock_t lock = thread_lock_(); search_config_t search_config; - search_config.expansion = expansion_search_; + search_config.expansion = (ef != 0) ? 
ef : expansion_search_; search_config.thread = lock.thread_id; return search_(vector, wanted, search_config, cast); } From 6983068f48c6f0ef76ee9f21e990ff7a91352203 Mon Sep 17 00:00:00 2001 From: Danyil Date: Sun, 15 Oct 2023 22:40:18 -0400 Subject: [PATCH 2/2] refactored code to make passing ef to usearch_search clearer --- c/lib.cpp | 59 +++++++++++++++----------- c/usearch.h | 5 +-- include/usearch/index_punned_dense.hpp | 31 +++++++++++--- 3 files changed, 60 insertions(+), 35 deletions(-) diff --git a/c/lib.cpp b/c/lib.cpp index 0b33b715..5fd89257 100644 --- a/c/lib.cpp +++ b/c/lib.cpp @@ -90,13 +90,24 @@ bool get_(index_t* index, label_t label, void* vector, scalar_kind_t kind) { } #endif +search_result_t custom_ef_search_(index_t* index, void const* vector, scalar_kind_t kind, size_t n, size_t ef = 0) { + switch (kind) { + case scalar_kind_t::f32_k: return index->custom_ef_search((f32_t const*)vector, n, ef); + case scalar_kind_t::f64_k: return index->custom_ef_search((f64_t const*)vector, n, ef); + case scalar_kind_t::f16_k: return index->custom_ef_search((f16_t const*)vector, n, ef); + case scalar_kind_t::f8_k: return index->custom_ef_search((f8_bits_t const*)vector, n, ef); + case scalar_kind_t::b1x8_k: return index->custom_ef_search((b1x8_t const*)vector, n, ef); + default: return index->empty_search_result().failed("Unknown scalar kind!"); + } +} + search_result_t search_(index_t* index, void const* vector, scalar_kind_t kind, size_t n, size_t ef = 0) { switch (kind) { - case scalar_kind_t::f32_k: return index->search((f32_t const*)vector, n, ef); - case scalar_kind_t::f64_k: return index->search((f64_t const*)vector, n, ef); - case scalar_kind_t::f16_k: return index->search((f16_t const*)vector, n, ef); - case scalar_kind_t::f8_k: return index->search((f8_bits_t const*)vector, n, ef); - case scalar_kind_t::b1x8_k: return index->search((b1x8_t const*)vector, n, ef); + case scalar_kind_t::f32_k: return index->search((f32_t const*)vector, n); + case 
scalar_kind_t::f64_k: return index->search((f64_t const*)vector, n); + case scalar_kind_t::f16_k: return index->search((f16_t const*)vector, n); + case scalar_kind_t::f8_k: return index->search((f8_bits_t const*)vector, n); + case scalar_kind_t::b1x8_k: return index->search((b1x8_t const*)vector, n); default: return index->empty_search_result().failed("Unknown scalar kind!"); } } @@ -269,28 +280,28 @@ USEARCH_EXPORT bool usearch_contains(usearch_index_t index, usearch_label_t labe } #endif -USEARCH_EXPORT size_t usearch_search( // - usearch_index_t index, void const* vector, usearch_scalar_kind_t kind, size_t results_limit, // - usearch_label_t* found_labels, usearch_distance_t* found_distances, usearch_error_t* error) { - search_result_t result = search_(reinterpret_cast(index), vector, to_native_scalar(kind), results_limit); - if (!result) { - *error = result.error.what(); - return 0; - } - - return result.dump_to(found_labels, found_distances); -} - -USEARCH_EXPORT size_t usearch_search_custom_ef( // +USEARCH_EXPORT size_t usearch_search( // usearch_index_t index, void const* vector, usearch_scalar_kind_t kind, size_t results_limit, size_t ef, // usearch_label_t* found_labels, usearch_distance_t* found_distances, usearch_error_t* error) { - search_result_t result = search_(reinterpret_cast(index), vector, to_native_scalar(kind), results_limit, ef); - if (!result) { - *error = result.error.what(); - return 0; + if (ef == USEARCH_SEARCH_EF_INVALID_VALUE) { + // use the ef already stored in the index for the search + search_result_t result = + search_(reinterpret_cast(index), vector, to_native_scalar(kind), results_limit); + if (!result) { + *error = result.error.what(); + return 0; + } + return result.dump_to(found_labels, found_distances); + } else { + // use the ef passed here during the search + search_result_t result = + custom_ef_search_(reinterpret_cast(index), vector, to_native_scalar(kind), results_limit, ef); + if (!result) { + *error = 
result.error.what(); + return 0; + } + return result.dump_to(found_labels, found_distances); } - - return result.dump_to(found_labels, found_distances); } #if USEARCH_LOOKUP_LABEL diff --git a/c/usearch.h b/c/usearch.h index db857125..c5cc9a2c 100644 --- a/c/usearch.h +++ b/c/usearch.h @@ -103,11 +103,8 @@ USEARCH_EXPORT bool usearch_contains(usearch_index_t, usearch_label_t, usearch_e * @brief Performs k-Approximate Nearest Neighbors Search. * @return Number of found matches. */ +#define USEARCH_SEARCH_EF_INVALID_VALUE 0 USEARCH_EXPORT size_t usearch_search( // - usearch_index_t, void const* query_vector, usearch_scalar_kind_t query_kind, size_t results_limit, // - usearch_label_t* found_labels, usearch_distance_t* found_distances, usearch_error_t*); - -USEARCH_EXPORT size_t usearch_search_custom_ef( // usearch_index_t, void const* query_vector, usearch_scalar_kind_t query_kind, size_t results_limit, size_t ef, // usearch_label_t* found_labels, usearch_distance_t* found_distances, usearch_error_t*); diff --git a/include/usearch/index_punned_dense.hpp b/include/usearch/index_punned_dense.hpp index 632bd625..f87c6e3b 100644 --- a/include/usearch/index_punned_dense.hpp +++ b/include/usearch/index_punned_dense.hpp @@ -361,11 +361,17 @@ class index_punned_dense_gt { add_result_t add(label_t label, f32_t const* vector, add_config_t config) { return add_(label, vector, config, casts_.from_f32); } add_result_t add(label_t label, f64_t const* vector, add_config_t config) { return add_(label, vector, config, casts_.from_f64); } - search_result_t search(b1x8_t const* vector, std::size_t wanted, std::size_t ef) const { return search_(vector, wanted, ef, casts_.from_b1x8); } - search_result_t search(f8_bits_t const* vector, std::size_t wanted, std::size_t ef) const { return search_(vector, wanted, ef, casts_.from_f8); } - search_result_t search(f16_t const* vector, std::size_t wanted, std::size_t ef) const { return search_(vector, wanted, ef, casts_.from_f16); } - 
search_result_t search(f32_t const* vector, std::size_t wanted, std::size_t ef) const { return search_(vector, wanted, ef, casts_.from_f32); } - search_result_t search(f64_t const* vector, std::size_t wanted, std::size_t ef) const { return search_(vector, wanted, ef, casts_.from_f64); } + search_result_t custom_ef_search(b1x8_t const* vector, std::size_t wanted, std::size_t ef) const { return custom_ef_search_(vector, wanted, ef, casts_.from_b1x8); } + search_result_t custom_ef_search(f8_bits_t const* vector, std::size_t wanted, std::size_t ef) const { return custom_ef_search_(vector, wanted, ef, casts_.from_f8); } + search_result_t custom_ef_search(f16_t const* vector, std::size_t wanted, std::size_t ef) const { return custom_ef_search_(vector, wanted, ef, casts_.from_f16); } + search_result_t custom_ef_search(f32_t const* vector, std::size_t wanted, std::size_t ef) const { return custom_ef_search_(vector, wanted, ef, casts_.from_f32); } + search_result_t custom_ef_search(f64_t const* vector, std::size_t wanted, std::size_t ef) const { return custom_ef_search_(vector, wanted, ef, casts_.from_f64); } + + search_result_t search(b1x8_t const* vector, std::size_t wanted) const { return search_(vector, wanted, casts_.from_b1x8); } + search_result_t search(f8_bits_t const* vector, std::size_t wanted) const { return search_(vector, wanted, casts_.from_f8); } + search_result_t search(f16_t const* vector, std::size_t wanted) const { return search_(vector, wanted, casts_.from_f16); } + search_result_t search(f32_t const* vector, std::size_t wanted) const { return search_(vector, wanted, casts_.from_f32); } + search_result_t search(f64_t const* vector, std::size_t wanted) const { return search_(vector, wanted, casts_.from_f64); } search_result_t search(b1x8_t const* vector, std::size_t wanted, search_config_t config) const { return search_(vector, wanted, config, casts_.from_b1x8); } search_result_t search(f8_bits_t const* vector, std::size_t wanted, search_config_t config) 
const { return search_(vector, wanted, config, casts_.from_f8); } @@ -914,12 +920,23 @@ class index_punned_dense_gt { } template - search_result_t search_( // + search_result_t custom_ef_search_( // scalar_at const* vector, std::size_t wanted, std::size_t ef, // cast_t const& cast) const { thread_lock_t lock = thread_lock_(); search_config_t search_config; - search_config.expansion = (ef != 0) ? ef : expansion_search_; + search_config.expansion = ef; + search_config.thread = lock.thread_id; + return search_(vector, wanted, search_config, cast); + } + + template + search_result_t search_( // + scalar_at const* vector, std::size_t wanted, // + cast_t const& cast) const { + thread_lock_t lock = thread_lock_(); + search_config_t search_config; + search_config.expansion = expansion_search_; search_config.thread = lock.thread_id; return search_(vector, wanted, search_config, cast); }