Skip to content

Commit a005084

Browse files
committed
Add: Exact search shortcut
Closes #176
1 parent 05e908f commit a005084

File tree

4 files changed

+286
-24
lines changed

4 files changed

+286
-24
lines changed

include/usearch/index.hpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,21 @@ template <typename allocator_at = std::allocator<byte_t>> class visits_bitset_gt
468468
}
469469

470470
#endif
471+
472+
class lock_t {
473+
visits_bitset_gt& bitset_;
474+
std::size_t bit_offset_;
475+
476+
public:
477+
inline ~lock_t() noexcept { bitset_.atomic_reset(bit_offset_); }
478+
inline lock_t(visits_bitset_gt& bitset, std::size_t bit_offset) noexcept
479+
: bitset_(bitset), bit_offset_(bit_offset) {
480+
while (bitset_.atomic_set(bit_offset_))
481+
;
482+
}
483+
};
484+
485+
/// Acquires the bit at index `i` as a spin-lock and returns the guard holding it.
/// The `lock_t` constructor spins until the bit is taken; the guard's destructor resets the bit.
/// The braced return list-initializes the returned guard directly, with no intermediate temporary.
inline lock_t lock(std::size_t i) noexcept { return {*this, i}; }
471486
};
472487

473488
using visits_bitset_t = visits_bitset_gt<>;
@@ -2010,7 +2025,7 @@ class index_gt {
20102025
std::memmove(distances + offset + 1, distances + offset, count_worse * sizeof(distance_t));
20112026
keys[offset] = result.member.key;
20122027
distances[offset] = result.distance;
2013-
merged_count = (std::min)(merged_count + 1u, max_count);
2028+
merged_count += merged_count != max_count;
20142029
}
20152030
return merged_count;
20162031
}

python/lib.cpp

+134-6
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ static void search_typed( //
298298
if (!threads)
299299
threads = std::thread::hardware_concurrency();
300300

301-
std::vector<std::mutex> vectors_mutexes(static_cast<std::size_t>(vectors_count));
301+
std::vector<std::mutex> query_mutexes(static_cast<std::size_t>(vectors_count));
302302
executor_default_t{threads}.execute_bulk(indexes.shards_.size(), [&](std::size_t, std::size_t task_idx) {
303303
dense_index_py_t& index = *indexes.shards_[task_idx].get();
304304

@@ -318,7 +318,7 @@ static void search_typed( //
318318
dense_search_result_t result = index.search(vector, wanted, config);
319319
result.error.raise();
320320
{
321-
std::unique_lock<std::mutex> lock(vectors_mutexes[vector_idx]);
321+
std::unique_lock<std::mutex> lock(query_mutexes[vector_idx]);
322322
counts_py1d(vector_idx) = static_cast<Py_ssize_t>(result.merge_into( //
323323
&keys_py2d(vector_idx, 0), //
324324
&distances_py2d(vector_idx, 0), //
@@ -348,7 +348,7 @@ static py::tuple search_many_in_index( //
348348
index_at& index, py::buffer vectors, std::size_t wanted, bool exact, std::size_t threads) {
349349

350350
if (wanted == 0)
351-
return py::tuple(3);
351+
return py::tuple(5);
352352

353353
if (index.limits().threads_search < threads)
354354
throw std::invalid_argument("Can't use that many threads!");
@@ -388,6 +388,123 @@ static py::tuple search_many_in_index( //
388388
return results;
389389
}
390390

391+
template <typename scalar_at>
392+
static void search_typed_brute_force( //
393+
py::buffer_info& dataset_info, py::buffer_info& queries_info, //
394+
std::size_t wanted, std::size_t threads, metric_t const& metric, //
395+
py::array_t<key_t>& keys_py, py::array_t<distance_t>& distances_py, py::array_t<Py_ssize_t>& counts_py) {
396+
397+
auto keys_py2d = keys_py.template mutable_unchecked<2>();
398+
auto distances_py2d = distances_py.template mutable_unchecked<2>();
399+
auto counts_py1d = counts_py.template mutable_unchecked<1>();
400+
401+
std::size_t dataset_count = static_cast<std::size_t>(dataset_info.shape[0]);
402+
std::size_t queries_count = static_cast<std::size_t>(queries_info.shape[0]);
403+
std::size_t dimensions = static_cast<std::size_t>(dataset_info.shape[1]);
404+
405+
byte_t const* dataset_data = reinterpret_cast<byte_t const*>(dataset_info.ptr);
406+
byte_t const* queries_data = reinterpret_cast<byte_t const*>(queries_info.ptr);
407+
for (std::size_t query_idx = 0; query_idx != queries_count; ++query_idx)
408+
counts_py1d(query_idx) = 0;
409+
410+
if (!threads)
411+
threads = std::thread::hardware_concurrency();
412+
413+
std::size_t tasks_count = static_cast<std::size_t>(dataset_count * queries_count);
414+
visits_bitset_t query_mutexes(static_cast<std::size_t>(queries_count));
415+
if (!query_mutexes)
416+
throw std::bad_alloc();
417+
418+
executor_default_t{threads}.execute_bulk(tasks_count, [&](std::size_t, std::size_t task_idx) {
419+
//
420+
std::size_t dataset_idx = task_idx / queries_count;
421+
std::size_t query_idx = task_idx % queries_count;
422+
423+
byte_t const* dataset = dataset_data + dataset_idx * dataset_info.strides[0];
424+
byte_t const* query = queries_data + query_idx * queries_info.strides[0];
425+
distance_t distance = metric(dataset, query);
426+
427+
{
428+
auto lock = query_mutexes.lock(query_idx);
429+
key_t* keys = &keys_py2d(query_idx, 0);
430+
distance_t* distances = &distances_py2d(query_idx, 0);
431+
std::size_t& matches = reinterpret_cast<std::size_t&>(counts_py1d(query_idx));
432+
if (matches == wanted)
433+
if (distances[wanted - 1] <= distance)
434+
return;
435+
436+
std::size_t offset = std::lower_bound(distances, distances + matches, distance) - distances;
437+
438+
std::size_t count_worse = matches - offset - (wanted == matches);
439+
std::memmove(keys + offset + 1, keys + offset, count_worse * sizeof(key_t));
440+
std::memmove(distances + offset + 1, distances + offset, count_worse * sizeof(distance_t));
441+
keys[offset] = static_cast<key_t>(dataset_idx);
442+
distances[offset] = distance;
443+
matches += matches != wanted;
444+
}
445+
446+
if (PyErr_CheckSignals() != 0)
447+
throw py::error_already_set();
448+
});
449+
}
450+
451+
/**
 *  @brief  Exact (brute-force) search of every row of `queries` against every row of `dataset`.
 *
 *  @param dataset           2D buffer of candidate vectors, one per row.
 *  @param queries           2D buffer of query vectors; must match `dataset` in column count and scalar type.
 *  @param wanted            Maximum number of matches reported per query; zero returns an unfilled 5-tuple.
 *  @param threads           Forwarded to the typed kernel.
 *  @param metric_kind       Metric used when no user-defined function is supplied.
 *  @param metric_signature  Calling convention of the user-defined metric.
 *  @param metric_uintptr    Address of a user-defined metric, or zero to use the built-in `metric_kind`.
 *
 *  @return 5-tuple: keys (queries x wanted, holding dataset row indices), distances (queries x wanted),
 *          counts (one per query), the constant 0, and `dataset_count * queries_count`.
 *          NOTE(review): the last two slots presumably mirror the statistics reported by the indexed
 *          search - confirm against the Python wrapper.
 *
 *  @throws std::invalid_argument  On non-2D inputs, mismatching dimensions or scalar types,
 *                                 or an unsupported scalar type.
 */
static py::tuple search_many_brute_force(       //
    py::buffer dataset, py::buffer queries,     //
    std::size_t wanted, std::size_t threads,    //
    metric_kind_t metric_kind,                  //
    metric_signature_t metric_signature,        //
    std::uintptr_t metric_uintptr) {

    if (wanted == 0)
        return py::tuple(5);

    py::buffer_info dataset_info = dataset.request();
    py::buffer_info queries_info = queries.request();
    if (dataset_info.ndim != 2 || queries_info.ndim != 2)
        throw std::invalid_argument("Expects a matrix of dataset to add!");

    Py_ssize_t dataset_count = dataset_info.shape[0];
    Py_ssize_t dataset_dimensions = dataset_info.shape[1];
    Py_ssize_t queries_count = queries_info.shape[0];
    Py_ssize_t queries_dimensions = queries_info.shape[1];
    if (dataset_dimensions != queries_dimensions)
        throw std::invalid_argument("The number of vector dimensions doesn't match!");

    scalar_kind_t dataset_kind = numpy_string_to_kind(dataset_info.format);
    scalar_kind_t queries_kind = numpy_string_to_kind(queries_info.format);
    if (dataset_kind != queries_kind)
        throw std::invalid_argument("The types of vectors don't match!");

    // The kernel addresses all three outputs by query index, so they need one row per query.
    // Sizing them by `dataset_count` overflows whenever there are more queries than dataset rows.
    py::array_t<key_t> keys_py({queries_count, static_cast<Py_ssize_t>(wanted)});
    py::array_t<distance_t> distances_py({queries_count, static_cast<Py_ssize_t>(wanted)});
    py::array_t<Py_ssize_t> counts_py(queries_count);

    std::size_t dimensions = static_cast<std::size_t>(queries_dimensions);
    metric_t metric = //
        metric_uintptr //
            ? udf(metric_kind, metric_signature, metric_uintptr, queries_kind, dimensions)
            : metric_t(dimensions, metric_kind, queries_kind);

    // clang-format off
    switch (dataset_kind) {
    case scalar_kind_t::b1x8_k: search_typed_brute_force<b1x8_t>(dataset_info, queries_info, wanted, threads, metric, keys_py, distances_py, counts_py); break;
    case scalar_kind_t::i8_k: search_typed_brute_force<i8_bits_t>(dataset_info, queries_info, wanted, threads, metric, keys_py, distances_py, counts_py); break;
    case scalar_kind_t::f16_k: search_typed_brute_force<f16_t>(dataset_info, queries_info, wanted, threads, metric, keys_py, distances_py, counts_py); break;
    case scalar_kind_t::f32_k: search_typed_brute_force<f32_t>(dataset_info, queries_info, wanted, threads, metric, keys_py, distances_py, counts_py); break;
    case scalar_kind_t::f64_k: search_typed_brute_force<f64_t>(dataset_info, queries_info, wanted, threads, metric, keys_py, distances_py, counts_py); break;
    default: throw std::invalid_argument("Incompatible vector types: " + dataset_info.format);
    }
    // clang-format on

    py::tuple results(5);
    results[0] = keys_py;
    results[1] = distances_py;
    results[2] = counts_py;
    results[3] = 0;
    results[4] = static_cast<std::size_t>(dataset_count * queries_count);
    return results;
}
507+
391508
static std::unordered_map<key_t, key_t> join_index( //
392509
dense_index_py_t const& a, dense_index_py_t const& b, //
393510
std::size_t max_proposals, bool exact) {
@@ -505,7 +622,7 @@ PYBIND11_MODULE(compiled, m) {
505622
py::enum_<metric_kind_t>(m, "MetricKind")
506623
.value("Unknown", metric_kind_t::unknown_k)
507624

508-
.value("IP", metric_kind_t::ip_k)
625+
// "IP" must name the inner-product metric; binding it to `cos_k` silently aliases it to "Cos".
.value("IP", metric_kind_t::ip_k)
509626
.value("Cos", metric_kind_t::cos_k)
510627
.value("L2sq", metric_kind_t::l2sq_k)
511628

@@ -517,7 +634,7 @@ PYBIND11_MODULE(compiled, m) {
517634
.value("Sorensen", metric_kind_t::sorensen_k)
518635

519636
.value("Cosine", metric_kind_t::cos_k)
520-
.value("InnerProduct", metric_kind_t::ip_k);
637+
// "InnerProduct" must name the inner-product metric rather than duplicating "Cosine".
.value("InnerProduct", metric_kind_t::ip_k);
521638

522639
py::enum_<scalar_kind_t>(m, "ScalarKind")
523640
.value("Unknown", scalar_kind_t::unknown_k)
@@ -562,13 +679,24 @@ PYBIND11_MODULE(compiled, m) {
562679
return result;
563680
});
564681

682+
m.def("exact_search", &search_many_brute_force, //
683+
py::arg("dataset"), //
684+
py::arg("queries"), //
685+
py::arg("count") = 10, //
686+
py::kw_only(), //
687+
py::arg("threads") = 0, //
688+
py::arg("metric_kind") = metric_kind_t::cos_k, //
689+
py::arg("metric_signature") = metric_signature_t::array_array_k, //
690+
py::arg("metric_pointer") = 0 //
691+
);
692+
565693
auto i = py::class_<dense_index_py_t, std::shared_ptr<dense_index_py_t>>(m, "Index");
566694

567695
i.def(py::init(&make_index), //
568696
py::kw_only(), //
569697
py::arg("ndim") = 0, //
570698
py::arg("dtype") = scalar_kind_t::f32_k, //
571-
py::arg("metric_kind") = metric_kind_t::ip_k, //
699+
py::arg("metric_kind") = metric_kind_t::cos_k, //
572700
py::arg("connectivity") = default_connectivity(), //
573701
py::arg("expansion_add") = default_expansion_add(), //
574702
py::arg("expansion_search") = default_expansion_search(), //

python/scripts/test.py

+22
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from usearch.io import load_matrix, save_matrix
77
from usearch.eval import random_vectors
8+
from usearch.index import search
89

910
from usearch.index import (
1011
Index,
@@ -81,6 +82,27 @@ def test_serializing_ibin_matrix(rows: int, cols: int):
8182
os.remove(temporary_filename + ".ibin")
8283

8384

85+
@pytest.mark.parametrize("rows", batch_sizes)
@pytest.mark.parametrize("cols", dimensions)
def test_exact_search(rows: int, cols: int):
    """
    Check that exact search over a random matrix returns every row as its own best match.

    :param int rows: The number of vectors in the random matrix.
    :param int cols: The number of dimensions of every vector.
    """
    dataset = np.random.rand(rows, cols)

    # Batch request: the whole matrix is searched against itself.
    batch_matches: BatchMatches = search(dataset, dataset, 10, exact=True)
    if rows > 1:
        best_keys = [int(match.keys[0]) for match in batch_matches]
    else:
        best_keys = int(batch_matches.keys[0])
    assert np.all(best_keys == np.arange(rows))

    # Single request: the first row must be its own best match.
    single_matches: Matches = search(dataset, dataset[0], 10, exact=True)
    assert int(single_matches.keys[0]) == 0
104+
105+
84106
@pytest.mark.parametrize("ndim", dimensions)
85107
@pytest.mark.parametrize("metric", continuous_metrics)
86108
@pytest.mark.parametrize("index_type", index_types)

0 commit comments

Comments
 (0)