Skip to content

Commit

Permalink
Refactor: label_t to key_t
Browse files Browse the repository at this point in the history
Similarly, `id_t` changed `compressed_slot_t`, and vectors are now
values. Metrics can now differn on per-call basis, allowing truly
heterogenous lookups.
  • Loading branch information
ashvardanian committed Jul 28, 2023
1 parent e996b38 commit 0d6c800
Show file tree
Hide file tree
Showing 43 changed files with 2,341 additions and 1,964 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ matches: Matches = index.search(vector, 10)

assert len(index) == 1
assert len(matches) == 1
assert matches[0].label == 42
assert matches[0].key == 42
assert matches[0].distance <= 0.001
assert np.allclose(index[42], vector)
```
Expand Down Expand Up @@ -214,17 +214,17 @@ model = uform.get_model('unum-cloud/uform-vl-multilingual')
index = usearch.index.Index(ndim=256)

@server
def add(label: int, photo: pil.Image.Image):
def add(key: int, photo: pil.Image.Image):
image = model.preprocess_image(photo)
vector = model.encode_image(image).detach().numpy()
index.add(label, vector.flatten(), copy=True)
index.add(key, vector.flatten(), copy=True)

@server
def search(query: str) -> np.ndarray:
tokens = model.preprocess_text(query)
vector = model.encode_text(tokens).detach().numpy()
matches = index.search(vector.flatten(), 3)
return matches.labels
return matches.keys

server.run()
```
Expand Down Expand Up @@ -268,9 +268,9 @@ fingerprints = np.vstack([encoder.GetFingerprint(x) for x in molecules])
fingerprints = np.packbits(fingerprints, axis=1)

index = Index(ndim=2048, metric=MetricKind.Tanimoto)
labels = np.arange(len(molecules))
keys = np.arange(len(molecules))

index.add(labels, fingerprints)
index.add(keys, fingerprints)
matches = index.search(fingerprints, 10)
```

Expand Down
2 changes: 1 addition & 1 deletion build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn main() {
println!("cargo:rerun-if-changed=rust/lib.rs");
println!("cargo:rerun-if-changed=rust/lib.cpp");
println!("cargo:rerun-if-changed=rust/lib.hpp");
println!("cargo:rerun-if-changed=include/index_punned_helpers.hpp");
println!("cargo:rerun-if-changed=include/index_plugins.hpp");
println!("cargo:rerun-if-changed=include/index_dense.hpp");
println!("cargo:rerun-if-changed=include/usearch/index.hpp");
}
71 changes: 36 additions & 35 deletions c/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@ extern "C" {
using namespace unum::usearch;
using namespace unum;

using label_t = usearch_label_t;
using distance_t = usearch_distance_t;
using index_t = index_dense_gt<label_t>;
using add_result_t = index_t::add_result_t;
using search_result_t = index_t::search_result_t;
using labeling_result_t = index_t::labeling_result_t;
using serialization_result_t = index_t::serialization_result_t;
using index_t = index_dense_t;
using key_t = typename index_t::key_t;
using distance_t = typename index_t::distance_t;
using add_result_t = typename index_t::add_result_t;
using search_result_t = typename index_t::search_result_t;
using labeling_result_t = typename index_t::labeling_result_t;
using vector_view_t = span_gt<float>;

static_assert(std::is_same<usearch_key_t, key_t>::value, "Type mismatch between C and C++");

// helper functions that are not part of the C ABI
metric_kind_t to_native_metric(usearch_metric_kind_t kind) {
switch (kind) {
Expand Down Expand Up @@ -49,25 +50,25 @@ scalar_kind_t to_native_scalar(usearch_scalar_kind_t kind) {
}
}

add_result_t add_(index_t* index, usearch_label_t label, void const* vector, scalar_kind_t kind) {
add_result_t add_(index_t* index, usearch_key_t key, void const* vector, scalar_kind_t kind) {
switch (kind) {
case scalar_kind_t::f32_k: return index->add(label, (f32_t const*)vector);
case scalar_kind_t::f64_k: return index->add(label, (f64_t const*)vector);
case scalar_kind_t::f16_k: return index->add(label, (f16_t const*)vector);
case scalar_kind_t::f8_k: return index->add(label, (f8_bits_t const*)vector);
case scalar_kind_t::b1x8_k: return index->add(label, (b1x8_t const*)vector);
case scalar_kind_t::f32_k: return index->add(key, (f32_t const*)vector);
case scalar_kind_t::f64_k: return index->add(key, (f64_t const*)vector);
case scalar_kind_t::f16_k: return index->add(key, (f16_t const*)vector);
case scalar_kind_t::f8_k: return index->add(key, (f8_bits_t const*)vector);
case scalar_kind_t::b1x8_k: return index->add(key, (b1x8_t const*)vector);
default: return add_result_t{}.failed("Unknown scalar kind!");
}
}

bool get_(index_t* index, label_t label, void* vector, scalar_kind_t kind) {
bool get_(index_t* index, key_t key, void* vector, scalar_kind_t kind) {
switch (kind) {
case scalar_kind_t::f32_k: return index->get(label, (f32_t*)vector);
case scalar_kind_t::f64_k: return index->get(label, (f64_t*)vector);
case scalar_kind_t::f16_k: return index->get(label, (f16_t*)vector);
case scalar_kind_t::f8_k: return index->get(label, (f8_bits_t*)vector);
case scalar_kind_t::b1x8_k: return index->get(label, (b1x8_t*)vector);
default: return index->empty_search_result().failed("Unknown scalar kind!");
case scalar_kind_t::f32_k: return index->get(key, (f32_t*)vector);
case scalar_kind_t::f64_k: return index->get(key, (f64_t*)vector);
case scalar_kind_t::f16_k: return index->get(key, (f16_t*)vector);
case scalar_kind_t::f8_k: return index->get(key, (f8_bits_t*)vector);
case scalar_kind_t::b1x8_k: return index->get(key, (b1x8_t*)vector);
default: return search_result_t().failed("Unknown scalar kind!");
}
}

Expand All @@ -78,14 +79,14 @@ search_result_t search_(index_t* index, void const* vector, scalar_kind_t kind,
case scalar_kind_t::f16_k: return index->search((f16_t const*)vector, n);
case scalar_kind_t::f8_k: return index->search((f8_bits_t const*)vector, n);
case scalar_kind_t::b1x8_k: return index->search((b1x8_t const*)vector, n);
default: return index->empty_search_result().failed("Unknown scalar kind!");
default: return search_result_t().failed("Unknown scalar kind!");
}
}

index_dense_metric_t udf(metric_kind_t kind, usearch_metric_t raw_ptr) {
index_dense_metric_t result;
metric_punned_t udf(metric_kind_t kind, usearch_metric_t raw_ptr) {
metric_punned_t result;
result.kind_ = kind;
result.func_ = [raw_ptr](punned_vector_view_t a, punned_vector_view_t b) -> distance_t {
result.func_ = [raw_ptr](span_punned_t a, span_punned_t b) -> distance_t {
return raw_ptr((void const*)a.data(), (void const*)b.data());
};
return result;
Expand Down Expand Up @@ -154,21 +155,21 @@ USEARCH_EXPORT void usearch_reserve(usearch_index_t index, size_t capacity, usea
reinterpret_cast<index_t*>(index)->reserve(capacity);
}

USEARCH_EXPORT void usearch_add( //
usearch_index_t index, usearch_label_t label, void const* vector, usearch_scalar_kind_t kind, //
USEARCH_EXPORT void usearch_add( //
usearch_index_t index, usearch_key_t key, void const* vector, usearch_scalar_kind_t kind, //
usearch_error_t* error) {
add_result_t result = add_(reinterpret_cast<index_t*>(index), label, vector, to_native_scalar(kind));
add_result_t result = add_(reinterpret_cast<index_t*>(index), key, vector, to_native_scalar(kind));
if (!result)
*error = result.error.what();
}

USEARCH_EXPORT bool usearch_contains(usearch_index_t index, usearch_label_t label, usearch_error_t*) {
return reinterpret_cast<index_t*>(index)->contains(label);
USEARCH_EXPORT bool usearch_contains(usearch_index_t index, usearch_key_t key, usearch_error_t*) {
return reinterpret_cast<index_t*>(index)->contains(key);
}

USEARCH_EXPORT size_t usearch_search( //
usearch_index_t index, void const* vector, usearch_scalar_kind_t kind, size_t results_limit, //
usearch_label_t* found_labels, usearch_distance_t* found_distances, usearch_error_t* error) {
usearch_key_t* found_labels, usearch_distance_t* found_distances, usearch_error_t* error) {
search_result_t result = search_(reinterpret_cast<index_t*>(index), vector, to_native_scalar(kind), results_limit);
if (!result) {
*error = result.error.what();
Expand All @@ -178,14 +179,14 @@ USEARCH_EXPORT size_t usearch_search(
return result.dump_to(found_labels, found_distances);
}

USEARCH_EXPORT bool usearch_get( //
usearch_index_t index, usearch_label_t label, //
USEARCH_EXPORT bool usearch_get( //
usearch_index_t index, usearch_key_t key, //
void* vector, usearch_scalar_kind_t kind, usearch_error_t*) {
return get_(reinterpret_cast<index_t*>(index), label, vector, to_native_scalar(kind));
return get_(reinterpret_cast<index_t*>(index), key, vector, to_native_scalar(kind));
}

USEARCH_EXPORT bool usearch_remove(usearch_index_t index, usearch_label_t label, usearch_error_t* error) {
labeling_result_t result = reinterpret_cast<index_t*>(index)->remove(label);
USEARCH_EXPORT bool usearch_remove(usearch_index_t index, usearch_key_t key, usearch_error_t* error) {
labeling_result_t result = reinterpret_cast<index_t*>(index)->remove(key);
if (!result)
*error = result.error.what();
return result.completed;
Expand Down
40 changes: 20 additions & 20 deletions c/test.c
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ void test_add_vector(size_t vectors_count, size_t vector_dimension, float const*

// Add vectors
for (size_t i = 0; i < vectors_count; ++i) {
usearch_label_t label = i;
usearch_add(idx, label, data + i * vector_dimension, usearch_scalar_f32_k, &error);
usearch_key_t key = i;
usearch_add(idx, key, data + i * vector_dimension, usearch_scalar_f32_k, &error);
ASSERT(!error, error);
}

Expand All @@ -117,10 +117,10 @@ void test_add_vector(size_t vectors_count, size_t vector_dimension, float const*

// Check vectors in the index
for (size_t i = 0; i < vectors_count; ++i) {
usearch_label_t label = i;
ASSERT(usearch_contains(idx, label, &error), error);
usearch_key_t key = i;
ASSERT(usearch_contains(idx, key, &error), error);
}
ASSERT(!usearch_contains(idx, -1, &error), error); // Non existing label
ASSERT(!usearch_contains(idx, -1, &error), error); // Non existing key

usearch_free(idx, &error);
printf("Test: Add Vector - PASSED\n");
Expand All @@ -137,21 +137,21 @@ void test_find_vector(size_t vectors_count, size_t vector_dimension, float const

// Create result buffers
int results_count = 10;
usearch_label_t* labels = (usearch_label_t*)malloc(results_count * sizeof(usearch_label_t));
usearch_key_t* keys = (usearch_key_t*)malloc(results_count * sizeof(usearch_key_t));
float* distances = (float*)malloc(results_count * sizeof(float));
ASSERT(labels && distances, "Failed to allocate memory");
ASSERT(keys && distances, "Failed to allocate memory");

// Add vectors
for (size_t i = 0; i < vectors_count; ++i) {
usearch_label_t label = i;
usearch_add(idx, label, data + i * vector_dimension, usearch_scalar_f32_k, &error);
usearch_key_t key = i;
usearch_add(idx, key, data + i * vector_dimension, usearch_scalar_f32_k, &error);
ASSERT(!error, error);
}

// Find the vectors
for (size_t i = 0; i < vectors_count; i++) {
const void *query_vector = data + i * vector_dimension;
size_t found_count = usearch_search(idx, query_vector, usearch_scalar_f32_k, results_count, labels, distances, &error);
size_t found_count = usearch_search(idx, query_vector, usearch_scalar_f32_k, results_count, keys, distances, &error);
ASSERT(!error, error);
ASSERT(found_count = results_count, "Vector is missing");
}
Expand All @@ -171,15 +171,15 @@ void test_remove_vector(size_t vectors_count, size_t vector_dimension, float con

// Add vectors
for (size_t i = 0; i < vectors_count; ++i) {
usearch_label_t label = i;
usearch_add(idx, label, data + i * vector_dimension, usearch_scalar_f32_k, &error);
usearch_key_t key = i;
usearch_add(idx, key, data + i * vector_dimension, usearch_scalar_f32_k, &error);
ASSERT(!error, error);
}

// Remove the vectors
for (size_t i = 0; i < vectors_count; i++) {
usearch_label_t label = i;
usearch_remove(idx, label, &error);
usearch_key_t key = i;
usearch_remove(idx, key, &error);
ASSERT(error, "Currently, Remove is not supported");
}

Expand All @@ -198,8 +198,8 @@ void test_save_load(size_t vectors_count, size_t vector_dimension, float const*

// Add vectors
for (size_t i = 0; i < vectors_count; ++i) {
usearch_label_t label = i;
usearch_add(idx, label, data + i * vector_dimension, usearch_scalar_f32_k, &error);
usearch_key_t key = i;
usearch_add(idx, key, data + i * vector_dimension, usearch_scalar_f32_k, &error);
ASSERT(!error, error);
}

Expand All @@ -224,8 +224,8 @@ void test_save_load(size_t vectors_count, size_t vector_dimension, float const*

// Check vectors in the index
for (size_t i = 0; i < vectors_count; ++i) {
usearch_label_t label = i;
ASSERT(usearch_contains(idx, label, &error), error);
usearch_key_t key = i;
ASSERT(usearch_contains(idx, key, &error), error);
}

usearch_free(idx, &error);
Expand All @@ -243,8 +243,8 @@ void test_view(size_t vectors_count, size_t vector_dimension, float const* data)

// Add vectors
for (size_t i = 0; i < vectors_count; ++i) {
usearch_label_t label = i;
usearch_add(idx, label, data + i * vector_dimension, usearch_scalar_f32_k, &error);
usearch_key_t key = i;
usearch_add(idx, key, data + i * vector_dimension, usearch_scalar_f32_k, &error);
ASSERT(!error, error);
}

Expand Down
30 changes: 15 additions & 15 deletions c/usearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ extern "C" {
#include <stdint.h> // `size_t`

USEARCH_EXPORT typedef void* usearch_index_t;
USEARCH_EXPORT typedef uint32_t usearch_label_t;
USEARCH_EXPORT typedef uint32_t usearch_key_t;
USEARCH_EXPORT typedef float usearch_distance_t;
USEARCH_EXPORT typedef char const* usearch_error_t;

Expand Down Expand Up @@ -120,51 +120,51 @@ USEARCH_EXPORT size_t usearch_connectivity(usearch_index_t, usearch_error_t* err
USEARCH_EXPORT void usearch_reserve(usearch_index_t, size_t capacity, usearch_error_t* error);

/**
* @brief Adds a vector with a label to the index.
* @param label The label associated with the vector.
* @brief Adds a vector with a key to the index.
* @param key The key associated with the vector.
* @param vector Pointer to the vector data.
* @param vector_kind The scalar type used in the vector data.
* @param[out] error Pointer to a string where the error message will be stored, if an error occurs.
*/
USEARCH_EXPORT void usearch_add( //
usearch_index_t, usearch_label_t label, //
usearch_index_t, usearch_key_t key, //
void const* vector, usearch_scalar_kind_t vector_kind, usearch_error_t* error);

/**
* @brief Checks if the index contains a vector with a specific label.
* @param label The label to be checked.
* @brief Checks if the index contains a vector with a specific key.
* @param key The key to be checked.
* @param[out] error Pointer to a string where the error message will be stored, if an error occurs.
* @return `true` if the index contains the vector with the given label, `false` otherwise.
* @return `true` if the index contains the vector with the given key, `false` otherwise.
*/
USEARCH_EXPORT bool usearch_contains(usearch_index_t, usearch_label_t, usearch_error_t* error);
USEARCH_EXPORT bool usearch_contains(usearch_index_t, usearch_key_t, usearch_error_t* error);

/**
* @brief Performs k-Approximate Nearest Neighbors Search.
* @return Number of found matches.
*/
USEARCH_EXPORT size_t usearch_search( //
usearch_index_t, void const* query_vector, usearch_scalar_kind_t query_kind, size_t results_limit, //
usearch_label_t* found_labels, usearch_distance_t* found_distances, usearch_error_t* error);
usearch_key_t* found_labels, usearch_distance_t* found_distances, usearch_error_t* error);

/**
* @brief Retrieves the vector associated with the given label from the index.
* @param label The label of the vector to retrieve.
* @brief Retrieves the vector associated with the given key from the index.
* @param key The key of the vector to retrieve.
* @param[out] vector Pointer to the memory where the vector data will be copied.
* @param vector_kind The scalar type used in the vector data.
* @param[out] error Pointer to a string where the error message will be stored, if an error occurs.
* @return `true` if the vector is successfully retrieved, `false` if the vector is not found.
*/
USEARCH_EXPORT bool usearch_get( //
usearch_index_t, usearch_label_t label, //
usearch_index_t, usearch_key_t key, //
void* vector, usearch_scalar_kind_t vector_kind, usearch_error_t* error);

/**
* @brief Removes the vector associated with the given label from the index.
* @param label The label of the vector to be removed.
* @brief Removes the vector associated with the given key from the index.
* @param key The key of the vector to be removed.
* @param[out] error Pointer to a string where the error message will be stored, if an error occurs.
* @return `true` if the vector is successfully removed, `false` if the vector is not found.
*/
USEARCH_EXPORT bool usearch_remove(usearch_index_t, usearch_label_t label, usearch_error_t* error);
USEARCH_EXPORT bool usearch_remove(usearch_index_t, usearch_key_t key, usearch_error_t* error);

#ifdef __cplusplus
}
Expand Down
Loading

0 comments on commit 0d6c800

Please sign in to comment.