diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 4cc58d211c..d6739b69e8 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -237,19 +237,19 @@ void hnsw_search( idx_t n, const float* x, BlockResultHandler& bres, - const SearchParameters* params_in) { + const SearchParameters* params) { FAISS_THROW_IF_NOT_MSG( index->storage, "No storage index, please use IndexHNSWFlat (or variants) " "instead of IndexHNSW directly"); - const SearchParametersHNSW* params = nullptr; const HNSW& hnsw = index->hnsw; int efSearch = hnsw.efSearch; - if (params_in) { - params = dynamic_cast(params_in); - FAISS_THROW_IF_NOT_MSG(params, "params type invalid"); - efSearch = params->efSearch; + if (params) { + if (const SearchParametersHNSW* hnsw_params = + dynamic_cast(params)) { + efSearch = hnsw_params->efSearch; + } } size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0; @@ -294,13 +294,13 @@ void IndexHNSW::search( idx_t k, float* distances, idx_t* labels, - const SearchParameters* params_in) const { + const SearchParameters* params) const { FAISS_THROW_IF_NOT(k > 0); using RH = HeapBlockResultHandler; RH bres(n, distances, labels, k); - hnsw_search(this, n, x, bres, params_in); + hnsw_search(this, n, x, bres, params); if (is_similarity_metric(this->metric_type)) { // we need to revert the negated distances @@ -408,17 +408,10 @@ void IndexHNSW::search_level_0( idx_t* labels, int nprobe, int search_type, - const SearchParameters* params_in) const { + const SearchParameters* params) const { FAISS_THROW_IF_NOT(k > 0); FAISS_THROW_IF_NOT(nprobe > 0); - const SearchParametersHNSW* params = nullptr; - - if (params_in) { - params = dynamic_cast(params_in); - FAISS_THROW_IF_NOT_MSG(params, "params type invalid"); - } - storage_idx_t ntotal = hnsw.levels.size(); using RH = HeapBlockResultHandler; diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 09b10e2b97..ece0281221 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -590,15 +590,22 @@ int search_from_candidates( HNSWStats& stats, int level, int nres_in, - const SearchParametersHNSW* params) { + const SearchParameters* params) { int nres = nres_in; int ndis = 0; // can be overridden by search params - bool do_dis_check = params ? params->check_relative_distance - : hnsw.check_relative_distance; - int efSearch = params ? params->efSearch : hnsw.efSearch; - const IDSelector* sel = params ? params->sel : nullptr; + bool do_dis_check = hnsw.check_relative_distance; + int efSearch = hnsw.efSearch; + const IDSelector* sel = nullptr; + if (params) { + if (const SearchParametersHNSW* hnsw_params = + dynamic_cast(params)) { + do_dis_check = hnsw_params->check_relative_distance; + efSearch = hnsw_params->efSearch; + } + sel = params->sel; + } C::T threshold = res.threshold; for (int i = 0; i < candidates.size(); i++) { @@ -920,15 +927,22 @@ HNSWStats HNSW::search( DistanceComputer& qdis, ResultHandler& res, VisitedTable& vt, - const SearchParametersHNSW* params) const { + const SearchParameters* params) const { HNSWStats stats; if (entry_point == -1) { return stats; } int k = extract_k_from_ResultHandler(res); - bool bounded_queue = - params ? params->bounded_queue : this->search_bounded_queue; + bool bounded_queue = this->search_bounded_queue; + int efSearch = this->efSearch; + if (params) { + if (const SearchParametersHNSW* hnsw_params = + dynamic_cast(params)) { + bounded_queue = hnsw_params->bounded_queue; + efSearch = hnsw_params->efSearch; + } + } // greedy search on upper levels storage_idx_t nearest = entry_point; @@ -940,7 +954,7 @@ HNSWStats HNSW::search( stats.combine(local_stats); } - int ef = std::max(params ? params->efSearch : efSearch, k); + int ef = std::max(efSearch, k); if (bounded_queue) { // this is the most common branch MinimaxHeap candidates(ef); @@ -980,9 +994,17 @@ void HNSW::search_level_0( int search_type, HNSWStats& search_stats, VisitedTable& vt, - const SearchParametersHNSW* params) const { + const SearchParameters* params) const { const HNSW& hnsw = *this; - auto efSearch = params ? params->efSearch : hnsw.efSearch; + + auto efSearch = hnsw.efSearch; + if (params) { + if (const SearchParametersHNSW* hnsw_params = + dynamic_cast(params)) { + efSearch = hnsw_params->efSearch; + } + } + int k = extract_k_from_ResultHandler(res); if (search_type == 1) { diff --git a/faiss/impl/HNSW.h b/faiss/impl/HNSW.h index aad26b1eda..f79f1c1199 100644 --- a/faiss/impl/HNSW.h +++ b/faiss/impl/HNSW.h @@ -200,7 +200,7 @@ struct HNSW { DistanceComputer& qdis, ResultHandler& res, VisitedTable& vt, - const SearchParametersHNSW* params = nullptr) const; + const SearchParameters* params = nullptr) const; /// search only in level 0 from a given vertex void search_level_0( @@ -212,7 +212,7 @@ struct HNSW { int search_type, HNSWStats& search_stats, VisitedTable& vt, - const SearchParametersHNSW* params = nullptr) const; + const SearchParameters* params = nullptr) const; void reset(); @@ -264,7 +264,7 @@ int search_from_candidates( HNSWStats& stats, int level, int nres_in = 0, - const SearchParametersHNSW* params = nullptr); + const SearchParameters* params = nullptr); HNSWStats greedy_update_nearest( const HNSW& hnsw,