Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
691ec1b
Add missing header
jaepil Dec 7, 2023
3e0b125
Merge branch 'facebookresearch:main' into main
jaepil Dec 8, 2023
893832d
Merge branch 'facebookresearch:main' into main
jaepil Dec 12, 2023
b644791
Merge branch 'facebookresearch:main' into main
jaepil Dec 19, 2023
519d5a7
Merge branch 'facebookresearch:main' into main
jaepil Dec 30, 2023
f048829
Merge branch 'facebookresearch:main' into main
jaepil Jan 17, 2024
179aaaa
Merge branch 'facebookresearch:main' into main
jaepil Feb 3, 2024
d38f51e
Merge branch 'facebookresearch:main' into main
jaepil Feb 17, 2024
151f9d3
Merge branch 'facebookresearch:main' into main
jaepil Feb 26, 2024
862e83e
Merge branch 'facebookresearch:main' into main
jaepil Mar 2, 2024
e6711d4
Merge branch 'facebookresearch:main' into main
jaepil Mar 29, 2024
dbdc3df
Merge branch 'facebookresearch:main' into main
jaepil Apr 6, 2024
210a84a
Merge branch 'facebookresearch:main' into main
jaepil Apr 18, 2024
6c68ea0
Merge branch 'facebookresearch:main' into main
jaepil Apr 25, 2024
3419736
Merge branch 'facebookresearch:main' into main
jaepil May 2, 2024
7b9e426
Merge branch 'facebookresearch:main' into main
jaepil May 21, 2024
51be2a6
Merge branch 'facebookresearch:main' into main
jaepil May 24, 2024
75f3071
Merge branch 'facebookresearch:main' into main
jaepil May 31, 2024
777a2c5
Merge branch 'facebookresearch:main' into main
jaepil Jun 10, 2024
9eacaea
Merge branch 'facebookresearch:main' into main
jaepil Jun 22, 2024
484a6aa
Merge branch 'facebookresearch:main' into main
jaepil Jun 27, 2024
1ab48be
Add IDSelector::is_member(idx, dist) function
jaepil Jun 29, 2024
8a48dd3
Merge branch 'facebookresearch:main' into main
jaepil Jul 13, 2024
f4dcf6f
Change IDSelector interface
jaepil Jul 13, 2024
ef19a28
Change IDSelector interface
jaepil Jul 13, 2024
8d9f3ed
Change the default behavior of is_member(idx, d) function
jaepil Jul 13, 2024
515fef3
Minor change
jaepil Jul 13, 2024
ae3b977
Minor change
jaepil Jul 13, 2024
cc2313e
Implement is_member() for C, Python, and SIMD
jaepil Jul 15, 2024
9894759
Make clang-format happy
jaepil Jul 16, 2024
ac45fc6
Update IDSelector
jaepil Jul 16, 2024
937d560
Merge branch 'facebookresearch:main' into main
jaepil Jul 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions c_api/impl/AuxIndexStructures_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ int faiss_IDSelector_is_member(const FaissIDSelector* sel, idx_t id) {
return reinterpret_cast<const IDSelector*>(sel)->is_member(id);
}

int faiss_IDSelector_is_member_with_dist(
const FaissIDSelector* sel,
idx_t id,
float dist) {
return reinterpret_cast<const IDSelector*>(sel)->is_member(id, dist);
}

DEFINE_DESTRUCTOR(IDSelectorRange)

DEFINE_GETTER(IDSelectorRange, idx_t, imin)
Expand Down
4 changes: 4 additions & 0 deletions c_api/impl/AuxIndexStructures_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ FAISS_DECLARE_CLASS(IDSelector)
FAISS_DECLARE_DESTRUCTOR(IDSelector)

int faiss_IDSelector_is_member(const FaissIDSelector* sel, idx_t id);
/** Same membership query as faiss_IDSelector_is_member, but also hands the
 * selector the distance `dist` associated with `id`.
 *
 * @param sel   opaque selector handle
 * @param id    id to test
 * @param dist  distance passed along to the selector
 * @return      the selector's answer converted to int
 */
int faiss_IDSelector_is_member_with_dist(
const FaissIDSelector* sel,
idx_t id,
float dist);

/** remove ids between [imin, imax) */
FAISS_DECLARE_CLASS(IDSelectorRange)
Expand Down
3 changes: 3 additions & 0 deletions faiss/IndexIDMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ struct IDSelectorTranslated : IDSelector {
bool is_member(idx_t id) const override {
return sel->is_member(id_map[id]);
}
/// Distance-aware variant: looks up `id` in id_map and asks the wrapped
/// selector about the mapped id, passing the optional distance through.
bool is_member(idx_t id, std::optional<float> d) const override {
    const auto mapped_id = id_map[id];
    return sel->is_member(mapped_id, d);
}
};

} // namespace faiss
4 changes: 2 additions & 2 deletions faiss/impl/HNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ int search_from_candidates(
idx_t v1 = candidates.ids[i];
float d = candidates.dis[i];
FAISS_ASSERT(v1 >= 0);
if (!sel || sel->is_member(v1)) {
if (!sel || sel->is_member(v1, d)) {
if (d < threshold) {
if (res.add_result(d, v1)) {
threshold = res.threshold;
Expand Down Expand Up @@ -637,7 +637,7 @@ int search_from_candidates(
threshold = res.threshold;

auto add_to_heap = [&](const size_t idx, const float dis) {
if (!sel || sel->is_member(idx)) {
if (!sel || sel->is_member(idx, dis)) {
if (dis < threshold) {
if (res.add_result(dis, idx)) {
threshold = res.threshold;
Expand Down
12 changes: 8 additions & 4 deletions faiss/impl/IDSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ namespace faiss {
IDSelectorRange::IDSelectorRange(idx_t imin, idx_t imax, bool assume_sorted)
: imin(imin), imax(imax), assume_sorted(assume_sorted) {}

bool IDSelectorRange::is_member(idx_t id) const {
/// True when `id` lies in the half-open interval [imin, imax). The optional
/// distance is accepted for interface compatibility and not used.
bool IDSelectorRange::is_member(idx_t id, std::optional<float> d) const {
    static_cast<void>(d);
    if (id < imin) {
        return false;
    }
    return id < imax;
}

Expand Down Expand Up @@ -69,7 +70,8 @@ void IDSelectorRange::find_sorted_ids_bounds(

IDSelectorArray::IDSelectorArray(size_t n, const idx_t* ids) : n(n), ids(ids) {}

bool IDSelectorArray::is_member(idx_t id) const {
bool IDSelectorArray::is_member(idx_t id, std::optional<float> d) const {
(void)d;
for (idx_t i = 0; i < n; i++) {
if (ids[i] == id)
return true;
Expand Down Expand Up @@ -99,7 +101,8 @@ IDSelectorBatch::IDSelectorBatch(size_t n, const idx_t* indices) {
}
}

bool IDSelectorBatch::is_member(idx_t i) const {
bool IDSelectorBatch::is_member(idx_t i, std::optional<float> d) const {
(void)d;
long im = i & mask;
if (!(bloom[im >> 3] & (1 << (im & 7)))) {
return 0;
Expand All @@ -114,7 +117,8 @@ bool IDSelectorBatch::is_member(idx_t i) const {
IDSelectorBitmap::IDSelectorBitmap(size_t n, const uint8_t* bitmap)
: n(n), bitmap(bitmap) {}

bool IDSelectorBitmap::is_member(idx_t ii) const {
bool IDSelectorBitmap::is_member(idx_t ii, std::optional<float> d) const {
(void)d;
uint64_t i = ii;
if ((i >> 3) >= n) {
return false;
Expand Down
48 changes: 43 additions & 5 deletions faiss/impl/IDSelector.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#pragma once

#include <optional>
#include <unordered_set>
#include <vector>

Expand All @@ -19,7 +20,14 @@ namespace faiss {

/** Encapsulates a set of ids to handle. */
struct IDSelector {
virtual bool is_member(idx_t id) const = 0;
/// Membership test on the id alone. This base implementation accepts every
/// id; subclasses override it to restrict the selection.
virtual bool is_member(idx_t id) const {
    static_cast<void>(id);
    return true;
}
/// Membership test that additionally receives an optional distance. This
/// base implementation ignores the distance and defers to is_member(id).
virtual bool is_member(idx_t id, std::optional<float> d) const {
    static_cast<void>(d);
    const bool accepted = is_member(id);
    return accepted;
}
virtual ~IDSelector() {}
};

Expand All @@ -33,7 +41,10 @@ struct IDSelectorRange : IDSelector {

IDSelectorRange(idx_t imin, idx_t imax, bool assume_sorted = false);

bool is_member(idx_t id) const final;
/// Id-only test: runs the distance-aware overload with no distance.
bool is_member(idx_t id) const final {
    const std::optional<float> no_distance = std::nullopt;
    return is_member(id, no_distance);
}
/// Distance-aware test; defined out of line.
bool is_member(idx_t id, std::optional<float> d) const final;

/// for sorted ids, find the range of list indices where the valid ids are
/// stored
Expand Down Expand Up @@ -62,7 +73,10 @@ struct IDSelectorArray : IDSelector {
* IDSelectorArray's lifetime
*/
IDSelectorArray(size_t n, const idx_t* ids);
bool is_member(idx_t id) const final;
/// Id-only test: delegates to the two-argument overload with an empty
/// distance.
bool is_member(idx_t id) const final {
    const std::optional<float> no_distance = std::nullopt;
    return is_member(id, no_distance);
}
/// Two-argument test taking an optional distance; defined out of line.
bool is_member(idx_t id, std::optional<float> d) const final;
~IDSelectorArray() override {}
};

Expand Down Expand Up @@ -92,7 +106,10 @@ struct IDSelectorBatch : IDSelector {
* construction
*/
IDSelectorBatch(size_t n, const idx_t* indices);
bool is_member(idx_t id) const final;
/// Id-only test: calls the distance-taking overload without a distance.
bool is_member(idx_t id) const final {
    const std::optional<float> no_distance = std::nullopt;
    return is_member(id, no_distance);
}
/// Overload that also takes an optional distance; defined out of line.
bool is_member(idx_t id, std::optional<float> d) const final;
~IDSelectorBatch() override {}
};

Expand All @@ -109,7 +126,10 @@ struct IDSelectorBitmap : IDSelector {
* (i%8) of bitmap[floor(i / 8)] is 1.
*/
IDSelectorBitmap(size_t n, const uint8_t* bitmap);
bool is_member(idx_t id) const final;
/// Id-only test: forwards to the overload below with std::nullopt as the
/// distance.
bool is_member(idx_t id) const final {
    const std::optional<float> no_distance = std::nullopt;
    return is_member(id, no_distance);
}
/// Overload with an optional distance argument; defined out of line.
bool is_member(idx_t id, std::optional<float> d) const final;
~IDSelectorBitmap() override {}
};

Expand All @@ -120,12 +140,21 @@ struct IDSelectorNot : IDSelector {
/// Negation of the wrapped selector's answer for `id`.
bool is_member(idx_t id) const final {
    const bool inner = sel->is_member(id);
    return !inner;
}
/// Negation of the wrapped selector's distance-aware answer for `id`.
bool is_member(idx_t id, std::optional<float> d) const final {
    const bool inner = sel->is_member(id, d);
    return !inner;
}
virtual ~IDSelectorNot() {}
};

/// selects all entries (useful for benchmarking)
struct IDSelectorAll : IDSelector {
/// Accepts every id.
bool is_member(idx_t id) const final {
    static_cast<void>(id);
    return true;
}
/// Accepts every id regardless of the supplied distance.
bool is_member(idx_t id, std::optional<float> d) const final {
    static_cast<void>(id);
    static_cast<void>(d);
    return true;
}
virtual ~IDSelectorAll() {}
Expand All @@ -141,6 +170,9 @@ struct IDSelectorAnd : IDSelector {
/// True only when both operands accept `id`; rhs is not consulted once lhs
/// has rejected it.
bool is_member(idx_t id) const final {
    if (!lhs->is_member(id)) {
        return false;
    }
    return rhs->is_member(id);
}
/// Distance-aware conjunction with the same short-circuit order.
bool is_member(idx_t id, std::optional<float> d) const final {
    if (!lhs->is_member(id, d)) {
        return false;
    }
    return rhs->is_member(id, d);
}
virtual ~IDSelectorAnd() {}
};

Expand All @@ -154,6 +186,9 @@ struct IDSelectorOr : IDSelector {
/// True when either operand accepts `id`; rhs is not consulted once lhs has
/// accepted it.
bool is_member(idx_t id) const final {
    if (lhs->is_member(id)) {
        return true;
    }
    return rhs->is_member(id);
}
/// Distance-aware disjunction with the same short-circuit order.
bool is_member(idx_t id, std::optional<float> d) const final {
    if (lhs->is_member(id, d)) {
        return true;
    }
    return rhs->is_member(id, d);
}
virtual ~IDSelectorOr() {}
};

Expand All @@ -167,6 +202,9 @@ struct IDSelectorXOr : IDSelector {
/// True when exactly one of the two operands accepts `id`; both operands
/// are always queried.
bool is_member(idx_t id) const final {
    const bool in_lhs = lhs->is_member(id);
    const bool in_rhs = rhs->is_member(id);
    return in_lhs != in_rhs;
}
/// Distance-aware exclusive-or; both operands are always queried.
bool is_member(idx_t id, std::optional<float> d) const final {
    const bool in_lhs = lhs->is_member(id, d);
    const bool in_rhs = rhs->is_member(id, d);
    return in_lhs != in_rhs;
}
virtual ~IDSelectorXOr() {}
};

Expand Down
5 changes: 3 additions & 2 deletions faiss/impl/ResultHandler.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <algorithm>
#include <iostream>
#include <optional>

namespace faiss {

Expand Down Expand Up @@ -62,8 +63,8 @@ struct BlockResultHandler {

virtual ~BlockResultHandler() {}

bool is_in_selection(idx_t i) const {
return !use_sel || sel->is_member(i);
/// Whether entry `i` passes the selector. Everything passes when no
/// selector is in use; otherwise the selector decides, receiving the
/// optional distance `d` (empty by default).
bool is_in_selection(idx_t i, std::optional<float> d = std::nullopt) const {
    if (!use_sel) {
        return true;
    }
    return sel->is_member(i, d);
}
};

Expand Down
16 changes: 8 additions & 8 deletions faiss/impl/simd_result_handlers.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
int j = __builtin_ctz(lt_mask);
auto real_idx = this->adjust_id(b, j);
lt_mask -= 1 << j;
if (this->sel->is_member(real_idx)) {
T d = d32tab[j];
T d = d32tab[j];
if (this->sel->is_member(real_idx, d)) {
if (C::cmp(idis[q], d)) {
idis[q] = d;
ids[q] = real_idx;
Expand Down Expand Up @@ -367,8 +367,8 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
int j = __builtin_ctz(lt_mask);
auto real_idx = this->adjust_id(b, j);
lt_mask -= 1 << j;
if (this->sel->is_member(real_idx)) {
T dis = d32tab[j];
T dis = d32tab[j];
if (this->sel->is_member(real_idx, dis)) {
if (C::cmp(heap_dis[0], dis)) {
heap_replace_top<C>(
k, heap_dis, heap_ids, dis, real_idx);
Expand Down Expand Up @@ -479,8 +479,8 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
int j = __builtin_ctz(lt_mask);
auto real_idx = this->adjust_id(b, j);
lt_mask -= 1 << j;
if (this->sel->is_member(real_idx)) {
T dis = d32tab[j];
T dis = d32tab[j];
if (this->sel->is_member(real_idx, dis)) {
res.add(dis, real_idx);
}
}
Expand Down Expand Up @@ -602,8 +602,8 @@ struct RangeHandler : ResultHandlerCompare<C, with_id_map> {
lt_mask -= 1 << j;

auto real_idx = this->adjust_id(b, j);
if (this->sel->is_member(real_idx)) {
T dis = d32tab[j];
T dis = d32tab[j];
if (this->sel->is_member(real_idx, dis)) {
n_per_query[q]++;
triplets.push_back({idx_t(q + q0), real_idx, dis});
}
Expand Down
18 changes: 18 additions & 0 deletions faiss/python/python_callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,24 @@ bool PyCallbackIDSelector::is_member(faiss::idx_t id) const {
return b;
}

/**
 * Distance-aware membership test backed by the Python callback.
 *
 * With an empty `d` this defers to the single-argument overload. Otherwise
 * the callback is invoked as callback(id, distance) while holding the GIL,
 * and the truth value of the returned object is the result.
 *
 * @param id  id to test; rejected (FAISS_THROW_IF_NOT) if any bit above the
 *            low 32 is set
 * @param d   optional distance associated with `id`
 * @return    truth value of the callback's return object
 *
 * Raises through FAISS_THROW_MSG when the callback call fails or when the
 * returned object's truth value cannot be evaluated.
 */
bool PyCallbackIDSelector::is_member(faiss::idx_t id, std::optional<float> d)
        const {
    if (!d.has_value()) {
        return is_member(id);
    }

    FAISS_THROW_IF_NOT((id >> 32) == 0);
    PyThreadLock gil;
    // The "n" format unit reads a Py_ssize_t from the varargs, so the id must
    // be passed as one. Passing an int mismatches the format and also turns
    // ids in [2^31, 2^32) negative although they pass the check above.
    // A float vararg is promoted to double; the cast makes that explicit.
    PyObject* result = PyObject_CallFunction(
            callback,
            "(nf)",
            static_cast<Py_ssize_t>(id),
            static_cast<double>(*d));
    if (result == nullptr) {
        FAISS_THROW_MSG("propagate py error");
    }
    // PyObject_IsTrue returns -1 on failure; do not report that as "true".
    const int truth = PyObject_IsTrue(result);
    Py_DECREF(result);
    if (truth < 0) {
        FAISS_THROW_MSG("propagate py error");
    }
    return truth != 0;
}

PyCallbackIDSelector::~PyCallbackIDSelector() {
PyThreadLock gil;
Py_DECREF(callback);
Expand Down
1 change: 1 addition & 0 deletions faiss/python/python_callbacks.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ struct PyCallbackIDSelector : faiss::IDSelector {
explicit PyCallbackIDSelector(PyObject* callback);

bool is_member(faiss::idx_t id) const override;
/// Overload that additionally receives an optional distance for `id`.
bool is_member(faiss::idx_t id, std::optional<float> d) const override;

~PyCallbackIDSelector() override;
};
6 changes: 6 additions & 0 deletions faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ void exhaustive_inner_product_seq(
continue;
}
float ip = fvec_inner_product(x_i, y_j, d);
if (!res.is_in_selection(j, ip)) {
continue;
}
resi.add_result(ip, j);
}
resi.end();
Expand Down Expand Up @@ -189,6 +192,9 @@ void exhaustive_L2sqr_seq(
continue;
}
float disij = fvec_L2sqr(x_i, y_j, d);
if (!res.is_in_selection(j, disij)) {
continue;
}
resi.add_result(disij, j);
}
resi.end();
Expand Down