From c5b7ccd518f5e1f7aa4bdceef5a335cdac76b6b7 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Tue, 25 Jul 2023 13:55:33 +0000 Subject: [PATCH] Add: Multi-`Index` lookups --- include/usearch/index.hpp | 21 +++ python/lib.cpp | 110 ++++++++++++--- python/scripts/test.py | 15 ++ python/usearch/index.py | 285 ++++++++++++++++++++++++-------------- 4 files changed, 308 insertions(+), 123 deletions(-) diff --git a/include/usearch/index.hpp b/include/usearch/index.hpp index fae0af91..9b8ef47f 100644 --- a/include/usearch/index.hpp +++ b/include/usearch/index.hpp @@ -2083,6 +2083,27 @@ class index_gt { node_t node = index_.node_with_id_(candidate.id); return {member_cref_t{node.label(), node.vector_view(), candidate.id}, candidate.distance}; } + inline std::size_t merge_into( // + label_t* labels, distance_t* distances, // + std::size_t old_count, std::size_t max_count) const noexcept { + + std::size_t merged_count = old_count; + for (std::size_t i = 0; i != count; ++i) { + match_t result = operator[](i); + auto merged_end = distances + merged_count; + auto offset = std::lower_bound(distances, merged_end, result.distance) - distances; + if (offset == max_count) + continue; + + std::size_t count_worse = merged_count - offset - (max_count == merged_count); + std::memmove(labels + offset + 1, labels + offset, count_worse * sizeof(label_t)); + std::memmove(distances + offset + 1, distances + offset, count_worse * sizeof(distance_t)); + labels[merged_count] = result.member.label; + distances[merged_count] = result.distance; + merged_count += 1; + } + return merged_count; + } inline std::size_t dump_to(label_t* labels, distance_t* distances) const noexcept { for (std::size_t i = 0; i != count; ++i) { match_t result = operator[](i); diff --git a/python/lib.cpp b/python/lib.cpp index f22bf849..a1915208 100644 --- a/python/lib.cpp +++ b/python/lib.cpp @@ -59,6 +59,26 @@ struct dense_index_py_t : public dense_index_t { dense_index_py_t(native_t&& base) : native_t(std::move(base)) {} }; +struct dense_indexes_py_t { + std::vector> shards_; + + void add(std::shared_ptr shard) { shards_.push_back(shard); } + std::size_t scalar_words() const noexcept { return shards_.empty() ? 0 : shards_[0]->scalar_words(); } + index_limits_t limits() const noexcept { return {size(), std::numeric_limits::max()}; } + + std::size_t size() const noexcept { + std::size_t result = 0; + for (auto const& shard : shards_) + result += shard->size(); + return result; + } + + void reserve(index_limits_t) { + for (auto const& shard : shards_) + shard->reserve({shard->size(), 1}); + } +}; + using set_member_t = std::uint32_t; using set_view_t = span_gt; using sparse_index_t = index_gt, label_t, id_t>; @@ -194,8 +214,9 @@ static void add_typed_to_index( // }); } -static void add_many_to_index( // - dense_index_py_t& index, py::buffer labels, py::buffer vectors, // +template +static void add_many_to_index( // + index_at& index, py::buffer labels, py::buffer vectors, // bool copy, std::size_t threads) { py::buffer_info labels_info = labels.request(); @@ -267,6 +288,49 @@ static void search_typed( // }); } +template +static void search_typed( // + dense_indexes_py_t const& indexes, py::buffer_info& vectors_info, // + std::size_t wanted, bool exact, std::size_t threads, // + py::array_t& labels_py, py::array_t& distances_py, py::array_t& counts_py) { + + auto labels_py2d = labels_py.template mutable_unchecked<2>(); + auto distances_py2d = distances_py.template mutable_unchecked<2>(); + auto counts_py1d = counts_py.template mutable_unchecked<1>(); + + Py_ssize_t vectors_count = vectors_info.shape[0]; + char const* vectors_data = reinterpret_cast(vectors_info.ptr); + for (std::size_t vector_idx = 0; vector_idx != static_cast(vectors_count); ++vector_idx) + counts_py1d(vector_idx) = 0; + + if (!threads) + threads = std::thread::hardware_concurrency(); + + std::vector vectors_mutexes(static_cast(vectors_count)); + executor_default_t{threads}.execute_bulk(indexes.size(), [&](std::size_t, std::size_t task_idx) { + dense_index_py_t const& index = *indexes.shards_[task_idx].get(); + + search_config_t config; + config.thread = 0; + config.exact = exact; + for (std::size_t vector_idx = 0; vector_idx != static_cast(vectors_count); ++vector_idx) { + scalar_at const* vector = (scalar_at const*)(vectors_data + vector_idx * vectors_info.strides[0]); + dense_search_result_t result = index.search(vector, wanted, config); + result.error.raise(); + { + std::unique_lock lock(vectors_mutexes[vector_idx]); + counts_py1d(vector_idx) = static_cast(result.merge_into( // + &labels_py2d(vector_idx, 0), // + &distances_py2d(vector_idx, 0), // + static_cast(counts_py1d(vector_idx)), // + wanted)); + } + if (PyErr_CheckSignals() != 0) + throw py::error_already_set(); + } + }); +} + /** * @param vectors Matrix of vectors to search for. * @param wanted Number of matches per request. @@ -276,8 +340,9 @@ static void search_typed( // * 2. matrix of distances, * 3. array with match counts. */ +template static py::tuple search_many_in_index( // - dense_index_py_t& index, py::buffer vectors, std::size_t wanted, bool exact, std::size_t threads) { + index_at& index, py::buffer vectors, std::size_t wanted, bool exact, std::size_t threads) { if (wanted == 0) return py::tuple(3); @@ -560,7 +625,7 @@ PYBIND11_MODULE(compiled, m) { h.def_readonly("bytes_for_vectors", &file_head_result_t::bytes_for_vectors); h.def_readonly("bytes_checksum", &file_head_result_t::bytes_checksum); - auto i = py::class_(m, "Index"); + auto i = py::class_>(m, "Index"); i.def(py::init(&make_index), // py::kw_only(), // @@ -575,21 +640,21 @@ PYBIND11_MODULE(compiled, m) { py::arg("tune") = false // ); - i.def( // - "add", &add_many_to_index, // - py::arg("labels"), // - py::arg("vectors"), // - py::kw_only(), // - py::arg("copy") = true, // - py::arg("threads") = 0 // + i.def( // + "add", &add_many_to_index, // + py::arg("labels"), // + py::arg("vectors"), // + py::kw_only(), // + py::arg("copy") = true, // + py::arg("threads") = 0 // ); - i.def( // - "search", &search_many_in_index, // - py::arg("query"), // - py::arg("count") = 10, // - py::arg("exact") = false, // - py::arg("threads") = 0 // + i.def( // + "search", &search_many_in_index, // + py::arg("query"), // + py::arg("count") = 10, // + py::arg("exact") = false, // + py::arg("threads") = 0 // ); i.def( @@ -677,6 +742,17 @@ PYBIND11_MODULE(compiled, m) { i.def_property_readonly("levels_stats", &compute_stats); i.def("level_stats", &compute_level_stats, py::arg("level")); + auto is = py::class_(m, "Indexes"); + is.def(py::init()); + is.def("add", &dense_indexes_py_t::add); + is.def( // + "search", &search_many_in_index, // + py::arg("query"), // + py::arg("count") = 10, // + py::arg("exact") = false, // + py::arg("threads") = 0 // + ); + auto si = py::class_(m, "SparseIndex"); si.def( // diff --git a/python/scripts/test.py b/python/scripts/test.py index b38f4527..bbd9894d 100644 --- a/python/scripts/test.py +++ b/python/scripts/test.py @@ -8,6 +8,7 @@ from usearch.index import ( Index, + Indexes, SparseIndex, MetricKind, ScalarKind, @@ -247,6 +248,20 @@ def test_exact_recall( assert man == woman, "Stable marriage failed" +def test_indexes(): + ndim = 10 + index_a = Index(ndim=ndim) + index_b = Index(ndim=ndim) + + vectors = random_vectors(count=3, ndim=ndim) + index_a.add(42, vectors[0]) + index_b.add(43, vectors[1]) + + indexes = Indexes([index_a, index_b]) + matches = indexes.search(vectors[2], 10) + assert len(matches) == 2 + + @pytest.mark.parametrize("bits", dimensions) @pytest.mark.parametrize("metric", hash_metrics) @pytest.mark.parametrize("connectivity", connectivity_options) diff --git a/python/usearch/index.py b/python/usearch/index.py index 47d04964..10500744 100644 --- a/python/usearch/index.py +++ b/python/usearch/index.py @@ -12,9 +12,11 @@ import numpy as np from tqdm import tqdm -from usearch.compiled import Index as _CompiledIndex, IndexMetadata, IndexStats +from usearch.compiled import Index as _CompiledIndex +from usearch.compiled import Indexes as _CompiledIndexes from usearch.compiled import SparseIndex as _CompiledSetsIndex +from usearch.compiled import IndexMetadata, IndexStats from usearch.compiled import MetricKind, ScalarKind, MetricSignature from usearch.compiled import ( DEFAULT_CONNECTIVITY, @@ -103,6 +105,136 @@ def _normalize_metric(metric): return metric +def _search_in_compiled( + *, + compiled: Union[_CompiledIndex, _CompiledIndexes], + vectors: np.ndarray, + k: int, + threads: int, + exact: bool, + log: Union[str, bool], + batch_size: int, +) -> Union[Matches, BatchMatches]: + # + assert isinstance(vectors, np.ndarray), "Expects a NumPy array" + assert vectors.ndim == 1 or vectors.ndim == 2, "Expects a matrix or vector" + if vectors.ndim == 1: + vectors = vectors.reshape(1, len(vectors)) + count_vectors = vectors.shape[0] + + def distil_batch(batch_matches: BatchMatches) -> Union[BatchMatches, Matches]: + return batch_matches[0] if count_vectors == 1 else batch_matches + + if log and batch_size == 0: + batch_size = int(math.ceil(count_vectors / 100)) + + if batch_size: + tasks = [ + vectors[start_row : start_row + batch_size, :] + for start_row in range(0, count_vectors, batch_size) + ] + tasks_matches = [] + name = log if isinstance(log, str) else "Search" + pbar = tqdm( + tasks, + desc=name, + total=count_vectors, + unit="vector", + disable=log is False, + ) + for vectors in tasks: + tuple_ = compiled.search( + vectors, + k, + exact=exact, + threads=threads, + ) + tasks_matches.append(BatchMatches(*tuple_)) + pbar.update(vectors.shape[0]) + + pbar.close() + return distil_batch( + BatchMatches( + labels=np.vstack([m.labels for m in tasks_matches]), + distances=np.vstack([m.distances for m in tasks_matches]), + counts=np.concatenate([m.counts for m in tasks_matches], axis=None), + ) + ) + + else: + tuple_ = compiled.search( + vectors, + k, + exact=exact, + threads=threads, + ) + return distil_batch(BatchMatches(*tuple_)) + + +def _add_to_compiled( + *, + compiled, + labels, + vectors, + copy: bool, + threads: int, + log: Union[str, bool], + batch_size: int, +) -> Union[int, np.ndarray]: + assert isinstance(vectors, np.ndarray), "Expects a NumPy array" + assert vectors.ndim == 1 or vectors.ndim == 2, "Expects a matrix or vector" + if vectors.ndim == 1: + vectors = vectors.reshape(1, len(vectors)) + + # Validate or generate the labels + count_vectors = vectors.shape[0] + generate_labels = labels is None + if generate_labels: + start_id = len(compiled) + labels = np.arange(start_id, start_id + count_vectors, dtype=Label) + else: + if not isinstance(labels, Iterable): + assert count_vectors == 1, "Each vector must have a label" + labels = [labels] + labels = np.array(labels).astype(Label) + + assert len(labels) == count_vectors + + # If logging is requested, and batch size is undefined, set it to grow 1% at a time: + if log and batch_size == 0: + batch_size = int(math.ceil(count_vectors / 100)) + + # Split into batches and log progress, if needed + if batch_size: + labels = [ + labels[start_row : start_row + batch_size] + for start_row in range(0, count_vectors, batch_size) + ] + vectors = [ + vectors[start_row : start_row + batch_size, :] + for start_row in range(0, count_vectors, batch_size) + ] + tasks = zip(labels, vectors) + name = log if isinstance(log, str) else "Add" + pbar = tqdm( + tasks, + desc=name, + total=count_vectors, + unit="vector", + disable=log is False, + ) + for labels, vectors in tasks: + compiled.add(labels, vectors, copy=copy, threads=threads) + pbar.update(len(labels)) + + pbar.close() + + else: + compiled.add(labels, vectors, copy=copy, threads=threads) + + return labels + + @dataclass class Match: label: int @@ -352,58 +484,15 @@ def add( :return: Inserted label or labels :type: Union[int, np.ndarray] """ - assert isinstance(vectors, np.ndarray), "Expects a NumPy array" - assert vectors.ndim == 1 or vectors.ndim == 2, "Expects a matrix or vector" - if vectors.ndim == 1: - vectors = vectors.reshape(1, len(vectors)) - - # Validate or generate the labels - count_vectors = vectors.shape[0] - generate_labels = labels is None - if generate_labels: - start_id = len(self._compiled) - labels = np.arange(start_id, start_id + count_vectors, dtype=Label) - else: - if not isinstance(labels, Iterable): - assert count_vectors == 1, "Each vector must have a label" - labels = [labels] - labels = np.array(labels).astype(Label) - - assert len(labels) == count_vectors - - # If logging is requested, and batch size is undefined, set it to grow 1% at a time: - if log and batch_size == 0: - batch_size = int(math.ceil(count_vectors / 100)) - - # Split into batches and log progress, if needed - if batch_size: - labels = [ - labels[start_row : start_row + batch_size] - for start_row in range(0, count_vectors, batch_size) - ] - vectors = [ - vectors[start_row : start_row + batch_size, :] - for start_row in range(0, count_vectors, batch_size) - ] - tasks = zip(labels, vectors) - name = log if isinstance(log, str) else "Add" - pbar = tqdm( - tasks, - desc=name, - total=count_vectors, - unit="vector", - disable=log is False, - ) - for labels, vectors in tasks: - self._compiled.add(labels, vectors, copy=copy, threads=threads) - pbar.update(len(labels)) - - pbar.close() - - else: - self._compiled.add(labels, vectors, copy=copy, threads=threads) - - return labels + return _add_to_compiled( + compiled=self._compiled, + labels=labels, + vectors=vectors, + copy=copy, + threads=threads, + log=log, + batch_size=batch_size, + ) def search( self, @@ -434,59 +523,15 @@ def search( :rtype: Union[Matches, BatchMatches] """ - assert isinstance(vectors, np.ndarray), "Expects a NumPy array" - assert vectors.ndim == 1 or vectors.ndim == 2, "Expects a matrix or vector" - if vectors.ndim == 1: - vectors = vectors.reshape(1, len(vectors)) - count_vectors = vectors.shape[0] - - def distil_batch(batch_matches: BatchMatches) -> Union[BatchMatches, Matches]: - return batch_matches[0] if count_vectors == 1 else batch_matches - - if log and batch_size == 0: - batch_size = int(math.ceil(count_vectors / 100)) - - if batch_size: - tasks = [ - vectors[start_row : start_row + batch_size, :] - for start_row in range(0, count_vectors, batch_size) - ] - tasks_matches = [] - name = log if isinstance(log, str) else "Search" - pbar = tqdm( - tasks, - desc=name, - total=count_vectors, - unit="vector", - disable=log is False, - ) - for vectors in tasks: - tuple_ = self._compiled.search( - vectors, - k, - exact=exact, - threads=threads, - ) - tasks_matches.append(BatchMatches(*tuple_)) - pbar.update(vectors.shape[0]) - - pbar.close() - return distil_batch( - BatchMatches( - labels=np.vstack([m.labels for m in tasks_matches]), - distances=np.vstack([m.distances for m in tasks_matches]), - counts=np.concatenate([m.counts for m in tasks_matches], axis=None), - ) - ) - - else: - tuple_ = self._compiled.search( - vectors, - k, - exact=exact, - threads=threads, - ) - return distil_batch(BatchMatches(*tuple_)) + return _search_in_compiled( + compiled=self._compiled, + vectors=vectors, + k=k, + exact=exact, + threads=threads, + log=log, + batch_size=batch_size, + ) def remove( self, @@ -753,3 +798,31 @@ def _repr_pretty_(self) -> str: *level_stats, ] ) + + +class Indexes: + def __init__(self, indexes: Iterable[Index]) -> None: + self._compiled = _CompiledIndexes() + for index in indexes: + self._compiled.add(index._compiled) + + def add(self, index: Index): + self._compiled.add(index._compiled) + + def search( + self, + vectors, + k: int = 10, + *, + threads: int = 0, + exact: bool = False, + ): + return _search_in_compiled( + compiled=self._compiled, + vectors=vectors, + k=k, + exact=exact, + threads=threads, + log=False, + batch_size=None, + )