diff --git a/c_api/impl/AuxIndexStructures_c.cpp b/c_api/impl/AuxIndexStructures_c.cpp index c19c1b63be..ff78cc3525 100644 --- a/c_api/impl/AuxIndexStructures_c.cpp +++ b/c_api/impl/AuxIndexStructures_c.cpp @@ -88,6 +88,13 @@ int faiss_IDSelector_is_member(const FaissIDSelector* sel, idx_t id) { return reinterpret_cast<const faiss::IDSelector*>(sel)->is_member(id); } +int faiss_IDSelector_is_member_with_dist( + const FaissIDSelector* sel, + idx_t id, + float dist) { + return reinterpret_cast<const faiss::IDSelector*>(sel)->is_member(id, dist); +} + DEFINE_DESTRUCTOR(IDSelectorRange) DEFINE_GETTER(IDSelectorRange, idx_t, imin) diff --git a/c_api/impl/AuxIndexStructures_c.h b/c_api/impl/AuxIndexStructures_c.h index dba3026980..cd1306b7fa 100644 --- a/c_api/impl/AuxIndexStructures_c.h +++ b/c_api/impl/AuxIndexStructures_c.h @@ -53,6 +53,10 @@ FAISS_DECLARE_CLASS(IDSelector) FAISS_DECLARE_DESTRUCTOR(IDSelector) int faiss_IDSelector_is_member(const FaissIDSelector* sel, idx_t id); +int faiss_IDSelector_is_member_with_dist( + const FaissIDSelector* sel, + idx_t id, + float dist); /** remove ids between [imni, imax) */ FAISS_DECLARE_CLASS(IDSelectorRange) diff --git a/faiss/IndexIDMap.h b/faiss/IndexIDMap.h index 2d16412301..0047dc4b26 100644 --- a/faiss/IndexIDMap.h +++ b/faiss/IndexIDMap.h @@ -122,6 +122,9 @@ struct IDSelectorTranslated : IDSelector { bool is_member(idx_t id) const override { return sel->is_member(id_map[id]); } + bool is_member(idx_t id, std::optional<float> d) const override { + return sel->is_member(id_map[id], d); + } }; } // namespace faiss diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 3ba5f72f68..38dbddb894 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -568,7 +568,7 @@ int search_from_candidates( idx_t v1 = candidates.ids[i]; float d = candidates.dis[i]; FAISS_ASSERT(v1 >= 0); - if (!sel || sel->is_member(v1)) { + if (!sel || sel->is_member(v1, d)) { if (d < threshold) { if (res.add_result(d, v1)) { threshold = res.threshold; @@ -637,7 +637,7 @@ int search_from_candidates(
threshold = res.threshold; auto add_to_heap = [&](const size_t idx, const float dis) { - if (!sel || sel->is_member(idx)) { + if (!sel || sel->is_member(idx, dis)) { if (dis < threshold) { if (res.add_result(dis, idx)) { threshold = res.threshold; diff --git a/faiss/impl/IDSelector.cpp b/faiss/impl/IDSelector.cpp index e4a4bba967..214cf0413f 100644 --- a/faiss/impl/IDSelector.cpp +++ b/faiss/impl/IDSelector.cpp @@ -17,7 +17,8 @@ namespace faiss { IDSelectorRange::IDSelectorRange(idx_t imin, idx_t imax, bool assume_sorted) : imin(imin), imax(imax), assume_sorted(assume_sorted) {} -bool IDSelectorRange::is_member(idx_t id) const { +bool IDSelectorRange::is_member(idx_t id, std::optional<float> d) const { + (void)d; return id >= imin && id < imax; } @@ -69,7 +70,8 @@ void IDSelectorRange::find_sorted_ids_bounds( IDSelectorArray::IDSelectorArray(size_t n, const idx_t* ids) : n(n), ids(ids) {} -bool IDSelectorArray::is_member(idx_t id) const { +bool IDSelectorArray::is_member(idx_t id, std::optional<float> d) const { + (void)d; for (idx_t i = 0; i < n; i++) { if (ids[i] == id) return true; @@ -99,7 +101,8 @@ IDSelectorBatch::IDSelectorBatch(size_t n, const idx_t* indices) { } } -bool IDSelectorBatch::is_member(idx_t i) const { +bool IDSelectorBatch::is_member(idx_t i, std::optional<float> d) const { + (void)d; long im = i & mask; if (!(bloom[im >> 3] & (1 << (im & 7)))) { return 0; @@ -114,7 +117,8 @@ bool IDSelectorBatch::is_member(idx_t i) const { IDSelectorBitmap::IDSelectorBitmap(size_t n, const uint8_t* bitmap) : n(n), bitmap(bitmap) {} -bool IDSelectorBitmap::is_member(idx_t ii) const { +bool IDSelectorBitmap::is_member(idx_t ii, std::optional<float> d) const { + (void)d; uint64_t i = ii; if ((i >> 3) >= n) { return false; diff --git a/faiss/impl/IDSelector.h b/faiss/impl/IDSelector.h index 7760d2a9fa..de8b6aa185 100644 --- a/faiss/impl/IDSelector.h +++ b/faiss/impl/IDSelector.h @@ -7,6 +7,7 @@ #pragma once +#include <optional> #include #include @@ -19,7 +20,14 @@ namespace faiss { /** Encapsulates a
set of ids to handle. */ struct IDSelector { - virtual bool is_member(idx_t id) const = 0; + virtual bool is_member(idx_t id) const { + (void)id; + return true; + } + virtual bool is_member(idx_t id, std::optional<float> d) const { + (void)d; + return is_member(id); + } virtual ~IDSelector() {} }; @@ -33,7 +41,10 @@ struct IDSelectorRange : IDSelector { IDSelectorRange(idx_t imin, idx_t imax, bool assume_sorted = false); - bool is_member(idx_t id) const final; + bool is_member(idx_t id) const final { + return is_member(id, std::nullopt); + } + bool is_member(idx_t id, std::optional<float> d) const final; /// for sorted ids, find the range of list indices where the valid ids are /// stored @@ -62,7 +73,10 @@ struct IDSelectorArray : IDSelector { * IDSelectorArray's lifetime */ IDSelectorArray(size_t n, const idx_t* ids); - bool is_member(idx_t id) const final; + bool is_member(idx_t id) const final { + return is_member(id, std::nullopt); + } + bool is_member(idx_t id, std::optional<float> d) const final; ~IDSelectorArray() override {} }; @@ -92,7 +106,10 @@ struct IDSelectorBatch : IDSelector { * construction */ IDSelectorBatch(size_t n, const idx_t* indices); - bool is_member(idx_t id) const final; + bool is_member(idx_t id) const final { + return is_member(id, std::nullopt); + } + bool is_member(idx_t id, std::optional<float> d) const final; ~IDSelectorBatch() override {} }; @@ -109,7 +126,10 @@ struct IDSelectorBitmap : IDSelector { * (i%8) of bitmap[floor(i / 8)] is 1.
*/ IDSelectorBitmap(size_t n, const uint8_t* bitmap); - bool is_member(idx_t id) const final; + bool is_member(idx_t id) const final { + return is_member(id, std::nullopt); + } + bool is_member(idx_t id, std::optional<float> d) const final; ~IDSelectorBitmap() override {} }; @@ -120,12 +140,21 @@ struct IDSelectorNot : IDSelector { bool is_member(idx_t id) const final { return !sel->is_member(id); } + bool is_member(idx_t id, std::optional<float> d) const final { + return !sel->is_member(id, d); + } virtual ~IDSelectorNot() {} }; /// selects all entries (useful for benchmarking) struct IDSelectorAll : IDSelector { bool is_member(idx_t id) const final { + (void)id; + return true; + } + bool is_member(idx_t id, std::optional<float> d) const final { + (void)id; + (void)d; return true; } virtual ~IDSelectorAll() {} @@ -141,6 +170,9 @@ struct IDSelectorAnd : IDSelector { bool is_member(idx_t id) const final { return lhs->is_member(id) && rhs->is_member(id); } + bool is_member(idx_t id, std::optional<float> d) const final { + return lhs->is_member(id, d) && rhs->is_member(id, d); + } virtual ~IDSelectorAnd() {} }; @@ -154,6 +186,9 @@ struct IDSelectorOr : IDSelector { bool is_member(idx_t id) const final { return lhs->is_member(id) || rhs->is_member(id); } + bool is_member(idx_t id, std::optional<float> d) const final { + return lhs->is_member(id, d) || rhs->is_member(id, d); + } virtual ~IDSelectorOr() {} }; @@ -167,6 +202,9 @@ struct IDSelectorXOr : IDSelector { bool is_member(idx_t id) const final { return lhs->is_member(id) ^ rhs->is_member(id); } + bool is_member(idx_t id, std::optional<float> d) const final { + return lhs->is_member(id, d) ^ rhs->is_member(id, d); + } virtual ~IDSelectorXOr() {} }; diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h index 3116eb24df..f65025fee4 100644 --- a/faiss/impl/ResultHandler.h +++ b/faiss/impl/ResultHandler.h @@ -19,6 +19,7 @@ #include #include +#include <optional> namespace faiss { @@ -62,8 +63,8 @@ struct BlockResultHandler { virtual ~BlockResultHandler() {}
- bool is_in_selection(idx_t i) const { - return !use_sel || sel->is_member(i); + bool is_in_selection(idx_t i, std::optional<float> d = std::nullopt) const { + return !use_sel || sel->is_member(i, d); } }; diff --git a/faiss/impl/simd_result_handlers.h b/faiss/impl/simd_result_handlers.h index 2fa18fa340..ee5d755d65 100644 --- a/faiss/impl/simd_result_handlers.h +++ b/faiss/impl/simd_result_handlers.h @@ -271,8 +271,8 @@ struct SingleResultHandler : ResultHandlerCompare { int j = __builtin_ctz(lt_mask); auto real_idx = this->adjust_id(b, j); lt_mask -= 1 << j; - if (this->sel->is_member(real_idx)) { - T d = d32tab[j]; + T d = d32tab[j]; + if (this->sel->is_member(real_idx, d)) { if (C::cmp(idis[q], d)) { idis[q] = d; ids[q] = real_idx; @@ -367,8 +367,8 @@ struct HeapHandler : ResultHandlerCompare { int j = __builtin_ctz(lt_mask); auto real_idx = this->adjust_id(b, j); lt_mask -= 1 << j; - if (this->sel->is_member(real_idx)) { - T dis = d32tab[j]; + T dis = d32tab[j]; + if (this->sel->is_member(real_idx, dis)) { if (C::cmp(heap_dis[0], dis)) { heap_replace_top( k, heap_dis, heap_ids, dis, real_idx); @@ -479,8 +479,8 @@ struct ReservoirHandler : ResultHandlerCompare { int j = __builtin_ctz(lt_mask); auto real_idx = this->adjust_id(b, j); lt_mask -= 1 << j; - if (this->sel->is_member(real_idx)) { - T dis = d32tab[j]; + T dis = d32tab[j]; + if (this->sel->is_member(real_idx, dis)) { res.add(dis, real_idx); } } @@ -602,8 +602,8 @@ struct RangeHandler : ResultHandlerCompare { lt_mask -= 1 << j; auto real_idx = this->adjust_id(b, j); - if (this->sel->is_member(real_idx)) { - T dis = d32tab[j]; + T dis = d32tab[j]; + if (this->sel->is_member(real_idx, dis)) { n_per_query[q]++; triplets.push_back({idx_t(q + q0), real_idx, dis}); } diff --git a/faiss/python/python_callbacks.cpp b/faiss/python/python_callbacks.cpp index 06b5c18cfc..62d4a31382 100644 --- a/faiss/python/python_callbacks.cpp +++ b/faiss/python/python_callbacks.cpp @@ -130,6 +130,24 @@ bool
PyCallbackIDSelector::is_member(faiss::idx_t id) const { return b; } +bool PyCallbackIDSelector::is_member(faiss::idx_t id, std::optional<float> d) + const { + if (!d.has_value()) { + return is_member(id); + } + + FAISS_THROW_IF_NOT((id >> 32) == 0); + PyThreadLock gil; + PyObject* result = + PyObject_CallFunction(callback, "(nf)", Py_ssize_t(id), d.value()); + if (result == nullptr) { + FAISS_THROW_MSG("propagate py error"); + } + bool b = PyObject_IsTrue(result); + Py_DECREF(result); + return b; +} + PyCallbackIDSelector::~PyCallbackIDSelector() { PyThreadLock gil; Py_DECREF(callback); diff --git a/faiss/python/python_callbacks.h b/faiss/python/python_callbacks.h index 421239bd0f..f3af0fa1c0 100644 --- a/faiss/python/python_callbacks.h +++ b/faiss/python/python_callbacks.h @@ -55,6 +55,7 @@ struct PyCallbackIDSelector : faiss::IDSelector { explicit PyCallbackIDSelector(PyObject* callback); bool is_member(faiss::idx_t id) const override; + bool is_member(faiss::idx_t id, std::optional<float> d) const override; ~PyCallbackIDSelector() override; }; diff --git a/faiss/utils/distances.cpp b/faiss/utils/distances.cpp index 1506bee5cf..99d7ab413a 100644 --- a/faiss/utils/distances.cpp +++ b/faiss/utils/distances.cpp @@ -157,6 +157,9 @@ void exhaustive_inner_product_seq( continue; } float ip = fvec_inner_product(x_i, y_j, d); + if (!res.is_in_selection(j, ip)) { + continue; + } resi.add_result(ip, j); } resi.end(); @@ -189,6 +192,9 @@ void exhaustive_L2sqr_seq( continue; } float disij = fvec_L2sqr(x_i, y_j, d); + if (!res.is_in_selection(j, disij)) { + continue; + } resi.add_result(disij, j); } resi.end();