From c9da261e070439e8bc46d16f1c7195f58d41a6fb Mon Sep 17 00:00:00 2001 From: Kaival Parikh Date: Tue, 4 Feb 2025 18:48:59 +0000 Subject: [PATCH 1/2] Handle plain SearchParameters in HNSW searches --- faiss/IndexHNSW.cpp | 24 ++++++++---------------- faiss/impl/HNSW.cpp | 41 ++++++++++++++++++++++++++++++----------- faiss/impl/HNSW.h | 6 +++--- 3 files changed, 41 insertions(+), 30 deletions(-) diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 4cc58d211c..70d48d81ea 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -237,19 +237,18 @@ void hnsw_search( idx_t n, const float* x, BlockResultHandler& bres, - const SearchParameters* params_in) { + const SearchParameters* params) { FAISS_THROW_IF_NOT_MSG( index->storage, "No storage index, please use IndexHNSWFlat (or variants) " "instead of IndexHNSW directly"); - const SearchParametersHNSW* params = nullptr; const HNSW& hnsw = index->hnsw; int efSearch = hnsw.efSearch; - if (params_in) { - params = dynamic_cast(params_in); - FAISS_THROW_IF_NOT_MSG(params, "params type invalid"); - efSearch = params->efSearch; + if (params) { + if (const SearchParametersHNSW* hnsw_params = dynamic_cast(params)) { + efSearch = hnsw_params->efSearch; + } } size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0; @@ -294,13 +293,13 @@ void IndexHNSW::search( idx_t k, float* distances, idx_t* labels, - const SearchParameters* params_in) const { + const SearchParameters* params) const { FAISS_THROW_IF_NOT(k > 0); using RH = HeapBlockResultHandler; RH bres(n, distances, labels, k); - hnsw_search(this, n, x, bres, params_in); + hnsw_search(this, n, x, bres, params); if (is_similarity_metric(this->metric_type)) { // we need to revert the negated distances @@ -408,17 +407,10 @@ void IndexHNSW::search_level_0( idx_t* labels, int nprobe, int search_type, - const SearchParameters* params_in) const { + const SearchParameters* params) const { FAISS_THROW_IF_NOT(k > 0); FAISS_THROW_IF_NOT(nprobe > 0); - const SearchParametersHNSW* params = nullptr; - - if (params_in) { - params = dynamic_cast(params_in); - FAISS_THROW_IF_NOT_MSG(params, "params type invalid"); - } - storage_idx_t ntotal = hnsw.levels.size(); using RH = HeapBlockResultHandler; diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 09b10e2b97..f5d66ad24d 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -590,15 +590,21 @@ int search_from_candidates( HNSWStats& stats, int level, int nres_in, - const SearchParametersHNSW* params) { + const SearchParameters* params) { int nres = nres_in; int ndis = 0; // can be overridden by search params - bool do_dis_check = params ? params->check_relative_distance - : hnsw.check_relative_distance; - int efSearch = params ? params->efSearch : hnsw.efSearch; - const IDSelector* sel = params ? params->sel : nullptr; + bool do_dis_check = hnsw.check_relative_distance; + int efSearch = hnsw.efSearch; + const IDSelector* sel = nullptr; + if (params) { + if (const SearchParametersHNSW* hnsw_params = dynamic_cast(params)) { + do_dis_check = hnsw_params->check_relative_distance; + efSearch = hnsw_params->efSearch; + } + sel = params->sel; + } C::T threshold = res.threshold; for (int i = 0; i < candidates.size(); i++) { @@ -920,15 +926,21 @@ HNSWStats HNSW::search( DistanceComputer& qdis, ResultHandler& res, VisitedTable& vt, - const SearchParametersHNSW* params) const { + const SearchParameters* params) const { HNSWStats stats; if (entry_point == -1) { return stats; } int k = extract_k_from_ResultHandler(res); - bool bounded_queue = - params ? params->bounded_queue : this->search_bounded_queue; + bool bounded_queue = this->search_bounded_queue; + int efSearch = this->efSearch; + if (params) { + if (const SearchParametersHNSW* hnsw_params = dynamic_cast(params)) { + bounded_queue = hnsw_params->bounded_queue; + efSearch = hnsw_params->efSearch; + } + } // greedy search on upper levels storage_idx_t nearest = entry_point; @@ -940,7 +952,7 @@ HNSWStats HNSW::search( stats.combine(local_stats); } - int ef = std::max(params ? params->efSearch : efSearch, k); + int ef = std::max(efSearch, k); if (bounded_queue) { // this is the most common branch MinimaxHeap candidates(ef); @@ -980,9 +992,16 @@ void HNSW::search_level_0( int search_type, HNSWStats& search_stats, VisitedTable& vt, - const SearchParametersHNSW* params) const { + const SearchParameters* params) const { const HNSW& hnsw = *this; - auto efSearch = params ? params->efSearch : hnsw.efSearch; + + auto efSearch = hnsw.efSearch; + if (params) { + if (const SearchParametersHNSW* hnsw_params = dynamic_cast(params)) { + efSearch = hnsw_params->efSearch; + } + } + int k = extract_k_from_ResultHandler(res); if (search_type == 1) { diff --git a/faiss/impl/HNSW.h b/faiss/impl/HNSW.h index aad26b1eda..f79f1c1199 100644 --- a/faiss/impl/HNSW.h +++ b/faiss/impl/HNSW.h @@ -200,7 +200,7 @@ struct HNSW { DistanceComputer& qdis, ResultHandler& res, VisitedTable& vt, - const SearchParametersHNSW* params = nullptr) const; + const SearchParameters* params = nullptr) const; /// search only in level 0 from a given vertex void search_level_0( @@ -212,7 +212,7 @@ struct HNSW { int search_type, HNSWStats& search_stats, VisitedTable& vt, - const SearchParametersHNSW* params = nullptr) const; + const SearchParameters* params = nullptr) const; void reset(); @@ -264,7 +264,7 @@ int search_from_candidates( HNSWStats& stats, int level, int nres_in = 0, - const SearchParametersHNSW* params = nullptr); + const SearchParameters* params = nullptr); HNSWStats greedy_update_nearest( const HNSW& hnsw, From dd57e0626f36d7c985d59ab941046842c168de18 Mon Sep 17 00:00:00 2001 From: Kaival Parikh Date: Thu, 6 Feb 2025 06:41:53 +0000 Subject: [PATCH 2/2] clang-format --- faiss/IndexHNSW.cpp | 5 +++-- faiss/impl/HNSW.cpp | 27 +++++++++++++++------------ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 70d48d81ea..d6739b69e8 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -246,8 +246,9 @@ void hnsw_search( int efSearch = hnsw.efSearch; if (params) { - if (const SearchParametersHNSW* hnsw_params = dynamic_cast(params)) { - efSearch = hnsw_params->efSearch; + if (const SearchParametersHNSW* hnsw_params = + dynamic_cast(params)) { + efSearch = hnsw_params->efSearch; } } size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0; diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index f5d66ad24d..ece0281221 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -599,11 +599,12 @@ int search_from_candidates( int efSearch = hnsw.efSearch; const IDSelector* sel = nullptr; if (params) { - if (const SearchParametersHNSW* hnsw_params = dynamic_cast(params)) { - do_dis_check = hnsw_params->check_relative_distance; - efSearch = hnsw_params->efSearch; - } - sel = params->sel; + if (const SearchParametersHNSW* hnsw_params = + dynamic_cast(params)) { + do_dis_check = hnsw_params->check_relative_distance; + efSearch = hnsw_params->efSearch; + } + sel = params->sel; } C::T threshold = res.threshold; @@ -936,10 +937,11 @@ HNSWStats HNSW::search( bool bounded_queue = this->search_bounded_queue; int efSearch = this->efSearch; if (params) { - if (const SearchParametersHNSW* hnsw_params = dynamic_cast(params)) { - bounded_queue = hnsw_params->bounded_queue; - efSearch = hnsw_params->efSearch; - } + if (const SearchParametersHNSW* hnsw_params = + dynamic_cast(params)) { + bounded_queue = hnsw_params->bounded_queue; + efSearch = hnsw_params->efSearch; + } } // greedy search on upper levels @@ -997,9 +999,10 @@ void HNSW::search_level_0( auto efSearch = hnsw.efSearch; if (params) { - if (const SearchParametersHNSW* hnsw_params = dynamic_cast(params)) { - efSearch = hnsw_params->efSearch; - } + if (const SearchParametersHNSW* hnsw_params = + dynamic_cast(params)) { + efSearch = hnsw_params->efSearch; + } } int k = extract_k_from_ResultHandler(res);