From 691ec1bf304a157ba02023a59eaf10f1fe8931e0 Mon Sep 17 00:00:00 2001 From: Jaepil Jeong Date: Thu, 7 Dec 2023 12:11:01 +0900 Subject: [PATCH 01/10] Add missing header --- faiss/utils/WorkerThread.h | 1 + 1 file changed, 1 insertion(+) diff --git a/faiss/utils/WorkerThread.h b/faiss/utils/WorkerThread.h index 72529be0a3..3f2377eba2 100644 --- a/faiss/utils/WorkerThread.h +++ b/faiss/utils/WorkerThread.h @@ -9,6 +9,7 @@ #include #include +#include #include #include From 1ab48be173c183f2648bf91526abbe844a0b1476 Mon Sep 17 00:00:00 2001 From: Jaepil Jeong Date: Sat, 29 Jun 2024 23:33:05 +0900 Subject: [PATCH 02/10] Add IDSelector::is_member(idx, dist) function --- faiss/IndexIDMap.h | 3 +++ faiss/impl/HNSW.cpp | 4 ++-- faiss/impl/IDSelector.h | 5 +++++ faiss/utils/distances.cpp | 6 ++++++ 4 files changed, 16 insertions(+), 2 deletions(-) diff --git a/faiss/IndexIDMap.h b/faiss/IndexIDMap.h index 2d16412301..c79c9f455b 100644 --- a/faiss/IndexIDMap.h +++ b/faiss/IndexIDMap.h @@ -122,6 +122,9 @@ struct IDSelectorTranslated : IDSelector { bool is_member(idx_t id) const override { return sel->is_member(id_map[id]); } + bool is_member(idx_t id, float distance) const override { + return sel->is_member(id_map[id], distance); + } }; } // namespace faiss diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 3ba5f72f68..38dbddb894 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -568,7 +568,7 @@ int search_from_candidates( idx_t v1 = candidates.ids[i]; float d = candidates.dis[i]; FAISS_ASSERT(v1 >= 0); - if (!sel || sel->is_member(v1)) { + if (!sel || sel->is_member(v1, d)) { if (d < threshold) { if (res.add_result(d, v1)) { threshold = res.threshold; @@ -637,7 +637,7 @@ int search_from_candidates( threshold = res.threshold; auto add_to_heap = [&](const size_t idx, const float dis) { - if (!sel || sel->is_member(idx)) { + if (!sel || sel->is_member(idx, dis)) { if (dis < threshold) { if (res.add_result(dis, idx)) { threshold = res.threshold; diff --git a/faiss/impl/IDSelector.h b/faiss/impl/IDSelector.h index 7760d2a9fa..d14d9baa6f 100644 --- a/faiss/impl/IDSelector.h +++ b/faiss/impl/IDSelector.h @@ -20,6 +20,11 @@ namespace faiss { /** Encapsulates a set of ids to handle. */ struct IDSelector { virtual bool is_member(idx_t id) const = 0; + virtual bool is_member(idx_t id, float distance) const { + (void)id; + (void)distance; + return true; + } virtual ~IDSelector() {} }; diff --git a/faiss/utils/distances.cpp b/faiss/utils/distances.cpp index 74b56bcc87..d4e4502e4f 100644 --- a/faiss/utils/distances.cpp +++ b/faiss/utils/distances.cpp @@ -160,6 +160,9 @@ void exhaustive_inner_product_seq( continue; } float ip = fvec_inner_product(x_i, y_j, d); + if (use_sel && !sel->is_member(j, 1.f - ip)) { + continue; + } resi.add_result(ip, j); } resi.end(); @@ -195,6 +198,9 @@ void exhaustive_L2sqr_seq( continue; } float disij = fvec_L2sqr(x_i, y_j, d); + if (use_sel && !sel->is_member(j, disij)) { + continue; + } resi.add_result(disij, j); } resi.end(); From f4dcf6f8cc06fd02794e0926a9e59f95c2c50247 Mon Sep 17 00:00:00 2001 From: Jaepil Jeong Date: Sat, 13 Jul 2024 15:17:43 +0900 Subject: [PATCH 03/10] Change IDSelector interface --- faiss/IndexIDMap.h | 9 ++++---- faiss/impl/IDSelector.cpp | 12 ++++++---- faiss/impl/IDSelector.h | 47 ++++++++++++++++++++++++--------------- faiss/utils/distances.cpp | 2 +- 4 files changed, 42 insertions(+), 28 deletions(-) diff --git a/faiss/IndexIDMap.h b/faiss/IndexIDMap.h index c79c9f455b..a145c1bdfe 100644 --- a/faiss/IndexIDMap.h +++ b/faiss/IndexIDMap.h @@ -119,11 +119,10 @@ struct IDSelectorTranslated : IDSelector { IDSelectorTranslated(IndexIDMap& index_idmap, const IDSelector* sel) : id_map(index_idmap.id_map), sel(sel) {} - bool is_member(idx_t id) const override { - return sel->is_member(id_map[id]); - } - bool is_member(idx_t id, float distance) const override { - return sel->is_member(id_map[id], distance); + bool is_member( + idx_t id, + std::optional d = std::nullopt) const override { + return sel->is_member(id_map[id], d); } }; diff --git a/faiss/impl/IDSelector.cpp b/faiss/impl/IDSelector.cpp index e4a4bba967..214cf0413f 100644 --- a/faiss/impl/IDSelector.cpp +++ b/faiss/impl/IDSelector.cpp @@ -17,7 +17,8 @@ namespace faiss { IDSelectorRange::IDSelectorRange(idx_t imin, idx_t imax, bool assume_sorted) : imin(imin), imax(imax), assume_sorted(assume_sorted) {} -bool IDSelectorRange::is_member(idx_t id) const { +bool IDSelectorRange::is_member(idx_t id, std::optional d) const { + (void)d; return id >= imin && id < imax; } @@ -69,7 +70,8 @@ void IDSelectorRange::find_sorted_ids_bounds( IDSelectorArray::IDSelectorArray(size_t n, const idx_t* ids) : n(n), ids(ids) {} -bool IDSelectorArray::is_member(idx_t id) const { +bool IDSelectorArray::is_member(idx_t id, std::optional d) const { + (void)d; for (idx_t i = 0; i < n; i++) { if (ids[i] == id) return true; @@ -99,7 +101,8 @@ IDSelectorBatch::IDSelectorBatch(size_t n, const idx_t* indices) { } } -bool IDSelectorBatch::is_member(idx_t i) const { +bool IDSelectorBatch::is_member(idx_t i, std::optional d) const { + (void)d; long im = i & mask; if (!(bloom[im >> 3] & (1 << (im & 7)))) { return 0; @@ -114,7 +117,8 @@ bool IDSelectorBatch::is_member(idx_t i) const { IDSelectorBitmap::IDSelectorBitmap(size_t n, const uint8_t* bitmap) : n(n), bitmap(bitmap) {} -bool IDSelectorBitmap::is_member(idx_t ii) const { +bool IDSelectorBitmap::is_member(idx_t ii, std::optional d) const { + (void)d; uint64_t i = ii; if ((i >> 3) >= n) { return false; diff --git a/faiss/impl/IDSelector.h b/faiss/impl/IDSelector.h index d14d9baa6f..1ba545d1b2 100644 --- a/faiss/impl/IDSelector.h +++ b/faiss/impl/IDSelector.h @@ -7,6 +7,7 @@ #pragma once +#include #include #include @@ -19,12 +20,9 @@ namespace faiss { /** Encapsulates a set of ids to handle. */ struct IDSelector { - virtual bool is_member(idx_t id) const = 0; - virtual bool is_member(idx_t id, float distance) const { - (void)id; - (void)distance; - return true; - } + virtual bool is_member( + idx_t id, + std::optional d = std::nullopt) const = 0; virtual ~IDSelector() {} }; @@ -38,7 +36,7 @@ struct IDSelectorRange : IDSelector { IDSelectorRange(idx_t imin, idx_t imax, bool assume_sorted = false); - bool is_member(idx_t id) const final; + bool is_member(idx_t id, std::optional d = std::nullopt) const final; /// for sorted ids, find the range of list indices where the valid ids are /// stored @@ -67,7 +65,7 @@ struct IDSelectorArray : IDSelector { * IDSelectorArray's lifetime */ IDSelectorArray(size_t n, const idx_t* ids); - bool is_member(idx_t id) const final; + bool is_member(idx_t id, std::optional d = std::nullopt) const final; ~IDSelectorArray() override {} }; @@ -97,7 +95,7 @@ struct IDSelectorBatch : IDSelector { * construction */ IDSelectorBatch(size_t n, const idx_t* indices); - bool is_member(idx_t id) const final; + bool is_member(idx_t id, std::optional d = std::nullopt) const final; ~IDSelectorBatch() override {} }; @@ -114,7 +112,7 @@ struct IDSelectorBitmap : IDSelector { * (i%8) of bitmap[floor(i / 8)] is 1. */ IDSelectorBitmap(size_t n, const uint8_t* bitmap); - bool is_member(idx_t id) const final; + bool is_member(idx_t id, std::optional d = std::nullopt) const final; ~IDSelectorBitmap() override {} }; @@ -122,7 +120,10 @@ struct IDSelectorBitmap : IDSelector { struct IDSelectorNot : IDSelector { const IDSelector* sel; IDSelectorNot(const IDSelector* sel) : sel(sel) {} - bool is_member(idx_t id) const final { + bool is_member( + idx_t id, + std::optional d = std::nullopt) const final { + (void)d; return !sel->is_member(id); } virtual ~IDSelectorNot() {} @@ -130,7 +131,11 @@ struct IDSelectorNot : IDSelector { /// selects all entries (useful for benchmarking) struct IDSelectorAll : IDSelector { - bool is_member(idx_t id) const final { + bool is_member( + idx_t id, + std::optional d = std::nullopt) const final { + (void)id; + (void)d; return true; } virtual ~IDSelectorAll() {} @@ -143,8 +148,10 @@ struct IDSelectorAnd : IDSelector { const IDSelector* rhs; IDSelectorAnd(const IDSelector* lhs, const IDSelector* rhs) : lhs(lhs), rhs(rhs) {} - bool is_member(idx_t id) const final { - return lhs->is_member(id) && rhs->is_member(id); + bool is_member( + idx_t id, + std::optional d = std::nullopt) const final { + return lhs->is_member(id, d) && rhs->is_member(id, d); } virtual ~IDSelectorAnd() {} }; @@ -156,8 +163,10 @@ struct IDSelectorOr : IDSelector { const IDSelector* rhs; IDSelectorOr(const IDSelector* lhs, const IDSelector* rhs) : lhs(lhs), rhs(rhs) {} - bool is_member(idx_t id) const final { - return lhs->is_member(id) || rhs->is_member(id); + bool is_member( + idx_t id, + std::optional d = std::nullopt) const final { + return lhs->is_member(id, d) || rhs->is_member(id, d); } virtual ~IDSelectorOr() {} }; @@ -169,8 +178,10 @@ struct IDSelectorXOr : IDSelector { const IDSelector* rhs; IDSelectorXOr(const IDSelector* lhs, const IDSelector* rhs) : lhs(lhs), rhs(rhs) {} - bool is_member(idx_t id) const final { - return lhs->is_member(id) ^ rhs->is_member(id); + bool is_member( + idx_t id, + std::optional d = std::nullopt) const final { + return lhs->is_member(id, d) ^ rhs->is_member(id, d); } virtual ~IDSelectorXOr() {} }; diff --git a/faiss/utils/distances.cpp b/faiss/utils/distances.cpp index de1c093c8b..a114afaa00 100644 --- a/faiss/utils/distances.cpp +++ b/faiss/utils/distances.cpp @@ -157,7 +157,7 @@ void exhaustive_inner_product_seq( continue; } float ip = fvec_inner_product(x_i, y_j, d); - if (use_sel && !sel->is_member(j, 1.f - ip)) { + if (use_sel && !sel->is_member(j, ip)) { continue; } resi.add_result(ip, j); From ef19a285cf929212762102de49f67b13355ca953 Mon Sep 17 00:00:00 2001 From: Jaepil Jeong Date: Sat, 13 Jul 2024 16:02:13 +0900 Subject: [PATCH 04/10] Change IDSelector interface --- faiss/IndexIDMap.h | 7 ++-- faiss/impl/IDSelector.h | 67 +++++++++++++++++++++++++------------- faiss/impl/ResultHandler.h | 5 +-- faiss/utils/distances.cpp | 4 +-- 4 files changed, 53 insertions(+), 30 deletions(-) diff --git a/faiss/IndexIDMap.h b/faiss/IndexIDMap.h index a145c1bdfe..0047dc4b26 100644 --- a/faiss/IndexIDMap.h +++ b/faiss/IndexIDMap.h @@ -119,9 +119,10 @@ struct IDSelectorTranslated : IDSelector { IDSelectorTranslated(IndexIDMap& index_idmap, const IDSelector* sel) : id_map(index_idmap.id_map), sel(sel) {} - bool is_member( - idx_t id, - std::optional d = std::nullopt) const override { + bool is_member(idx_t id) const override { + return sel->is_member(id_map[id]); + } + bool is_member(idx_t id, std::optional d) const override { return sel->is_member(id_map[id], d); } }; diff --git a/faiss/impl/IDSelector.h b/faiss/impl/IDSelector.h index 1ba545d1b2..8ff767c6d7 100644 --- a/faiss/impl/IDSelector.h +++ b/faiss/impl/IDSelector.h @@ -20,9 +20,13 @@ namespace faiss { /** Encapsulates a set of ids to handle. */ struct IDSelector { - virtual bool is_member( - idx_t id, - std::optional d = std::nullopt) const = 0; + virtual bool is_member(idx_t id) const = 0; + virtual bool is_member(idx_t id, std::optional d) const { + if (!d.has_value()) { + return is_member(id); + } + return true; + } virtual ~IDSelector() {} }; @@ -36,7 +40,10 @@ struct IDSelectorRange : IDSelector { IDSelectorRange(idx_t imin, idx_t imax, bool assume_sorted = false); - bool is_member(idx_t id, std::optional d = std::nullopt) const final; + bool is_member(idx_t id) const final { + return is_member(id, std::nullopt); + } + bool is_member(idx_t id, std::optional d) const final; /// for sorted ids, find the range of list indices where the valid ids are /// stored @@ -65,7 +72,10 @@ struct IDSelectorArray : IDSelector { * IDSelectorArray's lifetime */ IDSelectorArray(size_t n, const idx_t* ids); - bool is_member(idx_t id, std::optional d = std::nullopt) const final; + bool is_member(idx_t id) const final { + return is_member(id, std::nullopt); + } + bool is_member(idx_t id, std::optional d) const final; ~IDSelectorArray() override {} }; @@ -95,7 +105,10 @@ struct IDSelectorBatch : IDSelector { * construction */ IDSelectorBatch(size_t n, const idx_t* indices); - bool is_member(idx_t id, std::optional d = std::nullopt) const final; + bool is_member(idx_t id) const final { + return is_member(id, std::nullopt); + } + bool is_member(idx_t id, std::optional d) const final; ~IDSelectorBatch() override {} }; @@ -112,7 +125,10 @@ struct IDSelectorBitmap : IDSelector { * (i%8) of bitmap[floor(i / 8)] is 1. */ IDSelectorBitmap(size_t n, const uint8_t* bitmap); - bool is_member(idx_t id, std::optional d = std::nullopt) const final; + bool is_member(idx_t id) const final { + return is_member(id, std::nullopt); + } + bool is_member(idx_t id, std::optional d) const final; ~IDSelectorBitmap() override {} }; @@ -120,20 +136,22 @@ struct IDSelectorBitmap : IDSelector { struct IDSelectorNot : IDSelector { const IDSelector* sel; IDSelectorNot(const IDSelector* sel) : sel(sel) {} - bool is_member( - idx_t id, - std::optional d = std::nullopt) const final { - (void)d; + bool is_member(idx_t id) const final { return !sel->is_member(id); } + bool is_member(idx_t id, std::optional d) const final { + return !sel->is_member(id, d); + } virtual ~IDSelectorNot() {} }; /// selects all entries (useful for benchmarking) struct IDSelectorAll : IDSelector { - bool is_member( - idx_t id, - std::optional d = std::nullopt) const final { + bool is_member(idx_t id) const final { + (void)id; + return true; + } + bool is_member(idx_t id, std::optional d) const final { (void)id; (void)d; return true; @@ -148,9 +166,10 @@ struct IDSelectorAnd : IDSelector { const IDSelector* rhs; IDSelectorAnd(const IDSelector* lhs, const IDSelector* rhs) : lhs(lhs), rhs(rhs) {} - bool is_member( - idx_t id, - std::optional d = std::nullopt) const final { + bool is_member(idx_t id) const final { + return lhs->is_member(id) && rhs->is_member(id); + } + bool is_member(idx_t id, std::optional d) const final { return lhs->is_member(id, d) && rhs->is_member(id, d); } virtual ~IDSelectorAnd() {} @@ -163,9 +182,10 @@ struct IDSelectorOr : IDSelector { const IDSelector* rhs; IDSelectorOr(const IDSelector* lhs, const IDSelector* rhs) : lhs(lhs), rhs(rhs) {} - bool is_member( - idx_t id, - std::optional d = std::nullopt) const final { + bool is_member(idx_t id) const final { + return lhs->is_member(id) || rhs->is_member(id); + } + bool is_member(idx_t id, std::optional d) const final { return lhs->is_member(id, d) || rhs->is_member(id, d); } virtual ~IDSelectorOr() {} @@ -178,9 +198,10 @@ struct IDSelectorXOr : IDSelector { const IDSelector* rhs; IDSelectorXOr(const IDSelector* lhs, const IDSelector* rhs) : lhs(lhs), rhs(rhs) {} - bool is_member( - idx_t id, - std::optional d = std::nullopt) const final { + bool is_member(idx_t id) const final { + return lhs->is_member(id) ^ rhs->is_member(id); + } + bool is_member(idx_t id, std::optional d) const final { return lhs->is_member(id, d) ^ rhs->is_member(id, d); } virtual ~IDSelectorXOr() {} diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h index 3116eb24df..f65025fee4 100644 --- a/faiss/impl/ResultHandler.h +++ b/faiss/impl/ResultHandler.h @@ -19,6 +19,7 @@ #include #include +#include namespace faiss { @@ -62,8 +63,8 @@ struct BlockResultHandler { virtual ~BlockResultHandler() {} - bool is_in_selection(idx_t i) const { - return !use_sel || sel->is_member(i); + bool is_in_selection(idx_t i, std::optional d = std::nullopt) const { + return !use_sel || sel->is_member(i, d); } }; diff --git a/faiss/utils/distances.cpp b/faiss/utils/distances.cpp index a114afaa00..99d7ab413a 100644 --- a/faiss/utils/distances.cpp +++ b/faiss/utils/distances.cpp @@ -157,7 +157,7 @@ void exhaustive_inner_product_seq( continue; } float ip = fvec_inner_product(x_i, y_j, d); - if (use_sel && !sel->is_member(j, ip)) { + if (!res.is_in_selection(j, ip)) { continue; } resi.add_result(ip, j); @@ -192,7 +192,7 @@ void exhaustive_L2sqr_seq( continue; } float disij = fvec_L2sqr(x_i, y_j, d); - if (use_sel && !sel->is_member(j, disij)) { + if (!res.is_in_selection(j, disij)) { continue; } resi.add_result(disij, j); From 8d9f3eda140e8295aa3c1435c1ff2250b7ba7df4 Mon Sep 17 00:00:00 2001 From: Jaepil Jeong Date: Sat, 13 Jul 2024 16:16:47 +0900 Subject: [PATCH 05/10] Change the default behavior of is_member(idx, d) function --- faiss/impl/IDSelector.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/faiss/impl/IDSelector.h b/faiss/impl/IDSelector.h index 8ff767c6d7..ea7ecebb8d 100644 --- a/faiss/impl/IDSelector.h +++ b/faiss/impl/IDSelector.h @@ -22,10 +22,10 @@ namespace faiss { struct IDSelector { virtual bool is_member(idx_t id) const = 0; virtual bool is_member(idx_t id, std::optional d) const { - if (!d.has_value()) { - return is_member(id); - } - return true; + (void)d; + // default implementation ignores the distance for backward + // compatibility + return is_member(id); } virtual ~IDSelector() {} }; From 515fef31889b6098467379d3670a9d594b32ebec Mon Sep 17 00:00:00 2001 From: Jaepil Jeong Date: Sat, 13 Jul 2024 16:31:35 +0900 Subject: [PATCH 06/10] Minor change --- faiss/impl/IDSelector.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/faiss/impl/IDSelector.h b/faiss/impl/IDSelector.h index ea7ecebb8d..c0926d68d2 100644 --- a/faiss/impl/IDSelector.h +++ b/faiss/impl/IDSelector.h @@ -23,8 +23,8 @@ struct IDSelector { virtual bool is_member(idx_t id) const = 0; virtual bool is_member(idx_t id, std::optional d) const { (void)d; - // default implementation ignores the distance for backward - // compatibility + /// default implementation ignores the distance for backward + /// compatibility return is_member(id); } virtual ~IDSelector() {} From ae3b977d6b388b2bafd22e8c3a35763affe05ab7 Mon Sep 17 00:00:00 2001 From: Jaepil Jeong Date: Sat, 13 Jul 2024 16:39:17 +0900 Subject: [PATCH 07/10] Minor change --- faiss/impl/IDSelector.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/faiss/impl/IDSelector.h b/faiss/impl/IDSelector.h index c0926d68d2..1e3f784f85 100644 --- a/faiss/impl/IDSelector.h +++ b/faiss/impl/IDSelector.h @@ -23,8 +23,6 @@ struct IDSelector { virtual bool is_member(idx_t id) const = 0; virtual bool is_member(idx_t id, std::optional d) const { (void)d; - /// default implementation ignores the distance for backward - /// compatibility return is_member(id); } virtual ~IDSelector() {} From cc2313e969a928f9a68c12a978053b6f4cc8b002 Mon Sep 17 00:00:00 2001 From: Jaepil Jeong Date: Mon, 15 Jul 2024 16:55:07 +0900 Subject: [PATCH 08/10] Implement is_member() for C, Python, and SIMD --- c_api/impl/AuxIndexStructures_c.cpp | 7 +++++++ c_api/impl/AuxIndexStructures_c.h | 4 ++++ faiss/impl/simd_result_handlers.h | 16 ++++++++-------- faiss/python/python_callbacks.cpp | 19 +++++++++++++++++++ faiss/python/python_callbacks.h | 1 + 5 files changed, 39 insertions(+), 8 deletions(-) diff --git a/c_api/impl/AuxIndexStructures_c.cpp b/c_api/impl/AuxIndexStructures_c.cpp index c19c1b63be..ff78cc3525 100644 --- a/c_api/impl/AuxIndexStructures_c.cpp +++ b/c_api/impl/AuxIndexStructures_c.cpp @@ -88,6 +88,13 @@ int faiss_IDSelector_is_member(const FaissIDSelector* sel, idx_t id) { return reinterpret_cast(sel)->is_member(id); } +int faiss_IDSelector_is_member_with_dist( + const FaissIDSelector* sel, + idx_t id, + float dist) { + return reinterpret_cast(sel)->is_member(id, dist); +} + DEFINE_DESTRUCTOR(IDSelectorRange) DEFINE_GETTER(IDSelectorRange, idx_t, imin) diff --git a/c_api/impl/AuxIndexStructures_c.h b/c_api/impl/AuxIndexStructures_c.h index dba3026980..cd1306b7fa 100644 --- a/c_api/impl/AuxIndexStructures_c.h +++ b/c_api/impl/AuxIndexStructures_c.h @@ -53,6 +53,10 @@ FAISS_DECLARE_CLASS(IDSelector) FAISS_DECLARE_DESTRUCTOR(IDSelector) int faiss_IDSelector_is_member(const FaissIDSelector* sel, idx_t id); +int faiss_IDSelector_is_member_with_dist( + const FaissIDSelector* sel, + idx_t id, + float dist); /** remove ids between [imni, imax) */ FAISS_DECLARE_CLASS(IDSelectorRange) diff --git a/faiss/impl/simd_result_handlers.h b/faiss/impl/simd_result_handlers.h index 2fa18fa340..ee5d755d65 100644 --- a/faiss/impl/simd_result_handlers.h +++ b/faiss/impl/simd_result_handlers.h @@ -271,8 +271,8 @@ struct SingleResultHandler : ResultHandlerCompare { int j = __builtin_ctz(lt_mask); auto real_idx = this->adjust_id(b, j); lt_mask -= 1 << j; - if (this->sel->is_member(real_idx)) { - T d = d32tab[j]; + T d = d32tab[j]; + if (this->sel->is_member(real_idx, d)) { if (C::cmp(idis[q], d)) { idis[q] = d; ids[q] = real_idx; @@ -367,8 +367,8 @@ struct HeapHandler : ResultHandlerCompare { int j = __builtin_ctz(lt_mask); auto real_idx = this->adjust_id(b, j); lt_mask -= 1 << j; - if (this->sel->is_member(real_idx)) { - T dis = d32tab[j]; + T dis = d32tab[j]; + if (this->sel->is_member(real_idx, dis)) { if (C::cmp(heap_dis[0], dis)) { heap_replace_top( k, heap_dis, heap_ids, dis, real_idx); @@ -479,8 +479,8 @@ struct ReservoirHandler : ResultHandlerCompare { int j = __builtin_ctz(lt_mask); auto real_idx = this->adjust_id(b, j); lt_mask -= 1 << j; - if (this->sel->is_member(real_idx)) { - T dis = d32tab[j]; + T dis = d32tab[j]; + if (this->sel->is_member(real_idx, dis)) { res.add(dis, real_idx); } } @@ -602,8 +602,8 @@ struct RangeHandler : ResultHandlerCompare { lt_mask -= 1 << j; auto real_idx = this->adjust_id(b, j); - if (this->sel->is_member(real_idx)) { - T dis = d32tab[j]; + T dis = d32tab[j]; + if (this->sel->is_member(real_idx, dis)) { n_per_query[q]++; triplets.push_back({idx_t(q + q0), real_idx, dis}); } diff --git a/faiss/python/python_callbacks.cpp b/faiss/python/python_callbacks.cpp index 06b5c18cfc..133ae3c09b 100644 --- a/faiss/python/python_callbacks.cpp +++ b/faiss/python/python_callbacks.cpp @@ -130,6 +130,25 @@ bool PyCallbackIDSelector::is_member(faiss::idx_t id) const { return b; } +bool PyCallbackIDSelector::is_member( + faiss::idx_t id, + std::optional d) const { + if (!d.has_value()) { + return is_member(id); + } + + FAISS_THROW_IF_NOT((id >> 32) == 0); + PyThreadLock gil; + PyObject* result = PyObject_CallFunction( + callback, "(nf)", int(id), d.value()); + if (result == nullptr) { + FAISS_THROW_MSG("propagate py error"); + } + bool b = PyObject_IsTrue(result); + Py_DECREF(result); + return b; +} + PyCallbackIDSelector::~PyCallbackIDSelector() { PyThreadLock gil; Py_DECREF(callback); diff --git a/faiss/python/python_callbacks.h b/faiss/python/python_callbacks.h index 421239bd0f..f3af0fa1c0 100644 --- a/faiss/python/python_callbacks.h +++ b/faiss/python/python_callbacks.h @@ -55,6 +55,7 @@ struct PyCallbackIDSelector : faiss::IDSelector { explicit PyCallbackIDSelector(PyObject* callback); bool is_member(faiss::idx_t id) const override; + bool is_member(faiss::idx_t id, std::optional d) const override; ~PyCallbackIDSelector() override; }; From 9894759d2df192bcbd16eb0d14e2b844c7c18147 Mon Sep 17 00:00:00 2001 From: Jaepil Jeong Date: Tue, 16 Jul 2024 11:33:07 +0900 Subject: [PATCH 09/10] Make clang-format happy --- faiss/python/python_callbacks.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/faiss/python/python_callbacks.cpp b/faiss/python/python_callbacks.cpp index 133ae3c09b..62d4a31382 100644 --- a/faiss/python/python_callbacks.cpp +++ b/faiss/python/python_callbacks.cpp @@ -130,17 +130,16 @@ bool PyCallbackIDSelector::is_member(faiss::idx_t id) const { return b; } -bool PyCallbackIDSelector::is_member( - faiss::idx_t id, - std::optional d) const { +bool PyCallbackIDSelector::is_member(faiss::idx_t id, std::optional d) + const { if (!d.has_value()) { return is_member(id); } FAISS_THROW_IF_NOT((id >> 32) == 0); PyThreadLock gil; - PyObject* result = PyObject_CallFunction( - callback, "(nf)", int(id), d.value()); + PyObject* result = + PyObject_CallFunction(callback, "(nf)", int(id), d.value()); if (result == nullptr) { FAISS_THROW_MSG("propagate py error"); } From ac45fc6dcc26a6a05810f831c3706d455f032e8b Mon Sep 17 00:00:00 2001 From: Jaepil Jeong Date: Tue, 16 Jul 2024 13:40:23 +0900 Subject: [PATCH 10/10] Update IDSelector --- faiss/impl/IDSelector.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/faiss/impl/IDSelector.h b/faiss/impl/IDSelector.h index 1e3f784f85..de8b6aa185 100644 --- a/faiss/impl/IDSelector.h +++ b/faiss/impl/IDSelector.h @@ -20,7 +20,10 @@ namespace faiss { /** Encapsulates a set of ids to handle. */ struct IDSelector { - virtual bool is_member(idx_t id) const = 0; + virtual bool is_member(idx_t id) const { + (void)id; + return true; + } virtual bool is_member(idx_t id, std::optional d) const { (void)d; return is_member(id);