Skip to content

Commit

Permalink
Add new search function to support custom expansion factor (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
therealdarkknight authored Oct 23, 2023
1 parent 2115739 commit 80450b9
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 10 deletions.
41 changes: 32 additions & 9 deletions c/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,18 @@ bool get_(index_t* index, label_t label, void* vector, scalar_kind_t kind) {
}
#endif

search_result_t search_(index_t* index, void const* vector, scalar_kind_t kind, size_t n) {
search_result_t custom_ef_search_(index_t* index, void const* vector, scalar_kind_t kind, size_t n, size_t ef = 0) {
switch (kind) {
case scalar_kind_t::f32_k: return index->custom_ef_search((f32_t const*)vector, n, ef);
case scalar_kind_t::f64_k: return index->custom_ef_search((f64_t const*)vector, n, ef);
case scalar_kind_t::f16_k: return index->custom_ef_search((f16_t const*)vector, n, ef);
case scalar_kind_t::f8_k: return index->custom_ef_search((f8_bits_t const*)vector, n, ef);
case scalar_kind_t::b1x8_k: return index->custom_ef_search((b1x8_t const*)vector, n, ef);
default: return index->empty_search_result().failed("Unknown scalar kind!");
}
}

search_result_t search_(index_t* index, void const* vector, scalar_kind_t kind, size_t n, size_t ef = 0) {
switch (kind) {
case scalar_kind_t::f32_k: return index->search((f32_t const*)vector, n);
case scalar_kind_t::f64_k: return index->search((f64_t const*)vector, n);
Expand Down Expand Up @@ -269,16 +280,28 @@ USEARCH_EXPORT bool usearch_contains(usearch_index_t index, usearch_label_t labe
}
#endif

USEARCH_EXPORT size_t usearch_search( //
usearch_index_t index, void const* vector, usearch_scalar_kind_t kind, size_t results_limit, //
USEARCH_EXPORT size_t usearch_search( //
usearch_index_t index, void const* vector, usearch_scalar_kind_t kind, size_t results_limit, size_t ef, //
usearch_label_t* found_labels, usearch_distance_t* found_distances, usearch_error_t* error) {
search_result_t result = search_(reinterpret_cast<index_t*>(index), vector, to_native_scalar(kind), results_limit);
if (!result) {
*error = result.error.what();
return 0;
if (ef == USEARCH_SEARCH_EF_INVALID_VALUE) {
// use the ef already stored in the index for the search
search_result_t result =
search_(reinterpret_cast<index_t*>(index), vector, to_native_scalar(kind), results_limit);
if (!result) {
*error = result.error.what();
return 0;
}
return result.dump_to(found_labels, found_distances);
} else {
// use the ef passed here during the search
search_result_t result =
custom_ef_search_(reinterpret_cast<index_t*>(index), vector, to_native_scalar(kind), results_limit, ef);
if (!result) {
*error = result.error.what();
return 0;
}
return result.dump_to(found_labels, found_distances);
}

return result.dump_to(found_labels, found_distances);
}

#if USEARCH_LOOKUP_LABEL
Expand Down
3 changes: 2 additions & 1 deletion c/usearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@ USEARCH_EXPORT bool usearch_contains(usearch_index_t, usearch_label_t, usearch_e
* @brief Performs k-Approximate Nearest Neighbors Search.
* @return Number of found matches.
*/
#define USEARCH_SEARCH_EF_INVALID_VALUE 0
USEARCH_EXPORT size_t usearch_search( //
usearch_index_t, void const* query_vector, usearch_scalar_kind_t query_kind, size_t results_limit, //
usearch_index_t, void const* query_vector, usearch_scalar_kind_t query_kind, size_t results_limit, size_t ef, //
usearch_label_t* found_labels, usearch_distance_t* found_distances, usearch_error_t*);

USEARCH_EXPORT bool usearch_get( //
Expand Down
17 changes: 17 additions & 0 deletions include/usearch/index_punned_dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,12 @@ class index_punned_dense_gt {
add_result_t add(label_t label, f32_t const* vector, add_config_t config) { return add_(label, vector, config, casts_.from_f32); }
add_result_t add(label_t label, f64_t const* vector, add_config_t config) { return add_(label, vector, config, casts_.from_f64); }

search_result_t custom_ef_search(b1x8_t const* vector, std::size_t wanted, std::size_t ef) const { return custom_ef_search_(vector, wanted, ef, casts_.from_b1x8); }
search_result_t custom_ef_search(f8_bits_t const* vector, std::size_t wanted, std::size_t ef) const { return custom_ef_search_(vector, wanted, ef, casts_.from_f8); }
search_result_t custom_ef_search(f16_t const* vector, std::size_t wanted, std::size_t ef) const { return custom_ef_search_(vector, wanted, ef, casts_.from_f16); }
search_result_t custom_ef_search(f32_t const* vector, std::size_t wanted, std::size_t ef) const { return custom_ef_search_(vector, wanted, ef, casts_.from_f32); }
search_result_t custom_ef_search(f64_t const* vector, std::size_t wanted, std::size_t ef) const { return custom_ef_search_(vector, wanted, ef, casts_.from_f64); }

search_result_t search(b1x8_t const* vector, std::size_t wanted) const { return search_(vector, wanted, casts_.from_b1x8); }
search_result_t search(f8_bits_t const* vector, std::size_t wanted) const { return search_(vector, wanted, casts_.from_f8); }
search_result_t search(f16_t const* vector, std::size_t wanted) const { return search_(vector, wanted, casts_.from_f16); }
Expand Down Expand Up @@ -913,6 +919,17 @@ class index_punned_dense_gt {
return add_(label, vector, add_config, cast, level, tape);
}

template <typename scalar_at>
search_result_t custom_ef_search_( //
scalar_at const* vector, std::size_t wanted, std::size_t ef, //
cast_t const& cast) const {
thread_lock_t lock = thread_lock_();
search_config_t search_config;
search_config.expansion = ef;
search_config.thread = lock.thread_id;
return search_(vector, wanted, search_config, cast);
}

template <typename scalar_at>
search_result_t search_( //
scalar_at const* vector, std::size_t wanted, //
Expand Down

0 comments on commit 80450b9

Please sign in to comment.