Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 9 additions & 16 deletions faiss/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,19 +237,19 @@ void hnsw_search(
idx_t n,
const float* x,
BlockResultHandler& bres,
const SearchParameters* params_in) {
const SearchParameters* params) {
FAISS_THROW_IF_NOT_MSG(
index->storage,
"No storage index, please use IndexHNSWFlat (or variants) "
"instead of IndexHNSW directly");
const SearchParametersHNSW* params = nullptr;
const HNSW& hnsw = index->hnsw;

int efSearch = hnsw.efSearch;
if (params_in) {
params = dynamic_cast<const SearchParametersHNSW*>(params_in);
FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
efSearch = params->efSearch;
if (params) {
if (const SearchParametersHNSW* hnsw_params =
dynamic_cast<const SearchParametersHNSW*>(params)) {
efSearch = hnsw_params->efSearch;
}
}
size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;

Expand Down Expand Up @@ -294,13 +294,13 @@ void IndexHNSW::search(
idx_t k,
float* distances,
idx_t* labels,
const SearchParameters* params_in) const {
const SearchParameters* params) const {
FAISS_THROW_IF_NOT(k > 0);

using RH = HeapBlockResultHandler<HNSW::C>;
RH bres(n, distances, labels, k);

hnsw_search(this, n, x, bres, params_in);
hnsw_search(this, n, x, bres, params);

if (is_similarity_metric(this->metric_type)) {
// we need to revert the negated distances
Expand Down Expand Up @@ -408,17 +408,10 @@ void IndexHNSW::search_level_0(
idx_t* labels,
int nprobe,
int search_type,
const SearchParameters* params_in) const {
const SearchParameters* params) const {
FAISS_THROW_IF_NOT(k > 0);
FAISS_THROW_IF_NOT(nprobe > 0);

const SearchParametersHNSW* params = nullptr;

if (params_in) {
params = dynamic_cast<const SearchParametersHNSW*>(params_in);
FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
}

storage_idx_t ntotal = hnsw.levels.size();

using RH = HeapBlockResultHandler<HNSW::C>;
Expand Down
44 changes: 33 additions & 11 deletions faiss/impl/HNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -590,15 +590,22 @@ int search_from_candidates(
HNSWStats& stats,
int level,
int nres_in,
const SearchParametersHNSW* params) {
const SearchParameters* params) {
int nres = nres_in;
int ndis = 0;

// can be overridden by search params
bool do_dis_check = params ? params->check_relative_distance
: hnsw.check_relative_distance;
int efSearch = params ? params->efSearch : hnsw.efSearch;
const IDSelector* sel = params ? params->sel : nullptr;
bool do_dis_check = hnsw.check_relative_distance;
int efSearch = hnsw.efSearch;
const IDSelector* sel = nullptr;
if (params) {
if (const SearchParametersHNSW* hnsw_params =
dynamic_cast<const SearchParametersHNSW*>(params)) {
do_dis_check = hnsw_params->check_relative_distance;
efSearch = hnsw_params->efSearch;
}
sel = params->sel;
}

C::T threshold = res.threshold;
for (int i = 0; i < candidates.size(); i++) {
Expand Down Expand Up @@ -920,15 +927,22 @@ HNSWStats HNSW::search(
DistanceComputer& qdis,
ResultHandler<C>& res,
VisitedTable& vt,
const SearchParametersHNSW* params) const {
const SearchParameters* params) const {
HNSWStats stats;
if (entry_point == -1) {
return stats;
}
int k = extract_k_from_ResultHandler(res);

bool bounded_queue =
params ? params->bounded_queue : this->search_bounded_queue;
bool bounded_queue = this->search_bounded_queue;
int efSearch = this->efSearch;
if (params) {
if (const SearchParametersHNSW* hnsw_params =
dynamic_cast<const SearchParametersHNSW*>(params)) {
bounded_queue = hnsw_params->bounded_queue;
efSearch = hnsw_params->efSearch;
}
}

// greedy search on upper levels
storage_idx_t nearest = entry_point;
Expand All @@ -940,7 +954,7 @@ HNSWStats HNSW::search(
stats.combine(local_stats);
}

int ef = std::max(params ? params->efSearch : efSearch, k);
int ef = std::max(efSearch, k);
if (bounded_queue) { // this is the most common branch
MinimaxHeap candidates(ef);

Expand Down Expand Up @@ -980,9 +994,17 @@ void HNSW::search_level_0(
int search_type,
HNSWStats& search_stats,
VisitedTable& vt,
const SearchParametersHNSW* params) const {
const SearchParameters* params) const {
const HNSW& hnsw = *this;
auto efSearch = params ? params->efSearch : hnsw.efSearch;

auto efSearch = hnsw.efSearch;
if (params) {
if (const SearchParametersHNSW* hnsw_params =
dynamic_cast<const SearchParametersHNSW*>(params)) {
efSearch = hnsw_params->efSearch;
}
}

int k = extract_k_from_ResultHandler(res);

if (search_type == 1) {
Expand Down
6 changes: 3 additions & 3 deletions faiss/impl/HNSW.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ struct HNSW {
DistanceComputer& qdis,
ResultHandler<C>& res,
VisitedTable& vt,
const SearchParametersHNSW* params = nullptr) const;
const SearchParameters* params = nullptr) const;

/// search only in level 0 from a given vertex
void search_level_0(
Expand All @@ -212,7 +212,7 @@ struct HNSW {
int search_type,
HNSWStats& search_stats,
VisitedTable& vt,
const SearchParametersHNSW* params = nullptr) const;
const SearchParameters* params = nullptr) const;

void reset();

Expand Down Expand Up @@ -264,7 +264,7 @@ int search_from_candidates(
HNSWStats& stats,
int level,
int nres_in = 0,
const SearchParametersHNSW* params = nullptr);
const SearchParameters* params = nullptr);

HNSWStats greedy_update_nearest(
const HNSW& hnsw,
Expand Down
Loading