Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions thirdparty/faiss/faiss/IndexFlat.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ namespace faiss {

/** Index that stores the full vectors and performs exhaustive search */
struct IndexFlat : IndexFlatCodes {
/// database vectors, size ntotal * d
std::vector<float> xb;

explicit IndexFlat(
idx_t d, ///< dimensionality of the input vectors
MetricType metric = METRIC_L2,
Expand Down
29 changes: 20 additions & 9 deletions thirdparty/faiss/faiss/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,23 +275,23 @@ void hnsw_search(
FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
efSearch = params->efSearch;
}
size_t n1 = 0, n2 = 0, ndis = 0;
size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;

idx_t check_period = InterruptCallback::get_period_hint(
hnsw.max_level * index->d * efSearch);

for (idx_t i0 = 0; i0 < n; i0 += check_period) {
idx_t i1 = std::min(i0 + check_period, n);

#pragma omp parallel
#pragma omp parallel if (i1 - i0 > 1)
{
VisitedTable vt(index->ntotal);
typename BlockResultHandler::SingleResultHandler res(bres);

std::unique_ptr<DistanceComputer> dis(
storage_distance_computer(index->storage));

#pragma omp for reduction(+ : n1, n2, ndis) schedule(guided)
#pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided)
for (idx_t i = i0; i < i1; i++) {
res.begin(i);
dis->set_query(x + i * index->d);
Expand All @@ -300,13 +300,14 @@ void hnsw_search(
n1 += stats.n1;
n2 += stats.n2;
ndis += stats.ndis;
nhops += stats.nhops;
res.end();
}
}
InterruptCallback::check();
}

hnsw_stats.combine({n1, n2, ndis});
hnsw_stats.combine({n1, n2, ndis, nhops});
}

} // anonymous namespace
Expand Down Expand Up @@ -632,6 +633,10 @@ void IndexHNSW::permute_entries(const idx_t* perm) {
hnsw.permute_entries(perm);
}

DistanceComputer* IndexHNSW::get_distance_computer() const {
return storage->get_distance_computer();
}

/**************************************************************
* IndexHNSWFlat implementation
**************************************************************/
Expand All @@ -655,8 +660,13 @@ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)

IndexHNSWPQ::IndexHNSWPQ() = default;

IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits)
: IndexHNSW(new IndexPQ(d, pq_m, pq_nbits), M) {
IndexHNSWPQ::IndexHNSWPQ(
int d,
int pq_m,
int M,
int pq_nbits,
MetricType metric)
: IndexHNSW(new IndexPQ(d, pq_m, pq_nbits, metric), M) {
own_fields = true;
is_trained = false;
}
Expand Down Expand Up @@ -782,7 +792,7 @@ void IndexHNSW2Level::search(
IndexHNSW::search(n, x, k, distances, labels);

} else { // "mixed" search
size_t n1 = 0, n2 = 0, ndis = 0;
size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;

const IndexIVFPQ* index_ivfpq =
dynamic_cast<const IndexIVFPQ*>(storage);
Expand Down Expand Up @@ -814,7 +824,7 @@ void IndexHNSW2Level::search(
int candidates_size = hnsw.upper_beam;
MinimaxHeap candidates(candidates_size);

#pragma omp for reduction(+ : n1, n2, ndis)
#pragma omp for reduction(+ : n1, n2, ndis, nhops)
for (idx_t i = 0; i < n; i++) {
idx_t* idxi = labels + i * k;
float* simi = distances + i * k;
Expand Down Expand Up @@ -860,6 +870,7 @@ void IndexHNSW2Level::search(
n1 += search_stats.n1;
n2 += search_stats.n2;
ndis += search_stats.ndis;
nhops += search_stats.nhops;

vt.advance();
vt.advance();
Expand All @@ -868,7 +879,7 @@ void IndexHNSW2Level::search(
}
}

hnsw_stats.combine({n1, n2, ndis});
hnsw_stats.combine({n1, n2, ndis, nhops});
}
}

Expand Down
11 changes: 9 additions & 2 deletions thirdparty/faiss/faiss/IndexHNSW.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ struct IndexHNSW;
struct IndexHNSW : Index {
typedef HNSW::storage_idx_t storage_idx_t;

// the link strcuture
// the link structure
HNSW hnsw;

// the sequential storage
Expand Down Expand Up @@ -111,6 +111,8 @@ struct IndexHNSW : Index {
void link_singletons();

void permute_entries(const idx_t* perm);

DistanceComputer* get_distance_computer() const override;
};

/** Flat index topped with with a HNSW structure to access elements
Expand All @@ -127,7 +129,12 @@ struct IndexHNSWFlat : IndexHNSW {
*/
struct IndexHNSWPQ : IndexHNSW {
IndexHNSWPQ();
IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits = 8);
IndexHNSWPQ(
int d,
int pq_m,
int M,
int pq_nbits = 8,
MetricType metric = METRIC_L2);
void train(idx_t n, const float* x) override;
};

Expand Down
10 changes: 5 additions & 5 deletions thirdparty/faiss/faiss/IndexRefine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ template <class C>
static void reorder_2_heaps(
idx_t n,
idx_t k,
idx_t* labels,
float* distances,
idx_t* __restrict labels,
float* __restrict distances,
idx_t k_base,
const idx_t* base_labels,
const float* base_distances) {
#pragma omp parallel for
const idx_t* __restrict base_labels,
const float* __restrict base_distances) {
#pragma omp parallel for if (n > 1)
for (idx_t i = 0; i < n; i++) {
idx_t* idxo = labels + i * k;
float* diso = distances + i * k;
Expand Down
130 changes: 125 additions & 5 deletions thirdparty/faiss/faiss/impl/HNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,18 +409,22 @@ void search_neighbors_to_add(
**************************************************************/

/// greedily update a nearest vector at a given level
void greedy_update_nearest(
HNSWStats greedy_update_nearest(
const HNSW& hnsw,
DistanceComputer& qdis,
int level,
storage_idx_t& nearest,
float& d_nearest) {
HNSWStats stats;

for (;;) {
storage_idx_t prev_nearest = nearest;

size_t begin, end;
hnsw.neighbor_range(nearest, level, &begin, &end);
for (size_t i = begin; i < end; i++) {

size_t ndis = 0;
for (size_t i = begin; i < end; i++, ndis++) {
storage_idx_t v = hnsw.neighbors[i];
if (v < 0)
break;
Expand All @@ -430,8 +434,13 @@ void greedy_update_nearest(
d_nearest = dis;
}
}

// update stats
stats.ndis += ndis;
stats.nhops += 1;

if (nearest == prev_nearest) {
return;
return stats;
}
}
}
Expand Down Expand Up @@ -641,6 +650,7 @@ int search_from_candidates(
if (dis < threshold) {
if (res.add_result(dis, idx)) {
threshold = res.threshold;
nres += 1;
}
}
}
Expand Down Expand Up @@ -692,6 +702,7 @@ int search_from_candidates(
stats.n2++;
}
stats.ndis += ndis;
stats.nhops += nstep;
}

return nres;
Expand Down Expand Up @@ -814,6 +825,8 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
float dis = qdis(saved_j[icnt]);
add_to_heap(saved_j[icnt], dis);
}

stats.nhops += 1;
}

++stats.n1;
Expand Down Expand Up @@ -853,7 +866,9 @@ HNSWStats HNSW::search(
float d_nearest = qdis(nearest);

for (int level = max_level; level >= 1; level--) {
greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
HNSWStats local_stats = greedy_update_nearest(
*this, qdis, level, nearest, d_nearest);
stats.combine(local_stats);
}

int ef = std::max(params ? params->efSearch : efSearch, k);
Expand Down Expand Up @@ -916,11 +931,23 @@ HNSWStats HNSW::search(
if (level == 0) {
nres = search_from_candidates(
*this, qdis, res, candidates, vt, stats, 0);
nres = std::min(nres, candidates_size);
} else {
const auto nres_prev = nres;

resh.begin(0);
nres = search_from_candidates(
*this, qdis, resh, candidates, vt, stats, level);
nres = std::min(nres, candidates_size);
resh.end();

// if the search on a particular level produces no improvements,
// then we need to repopulate candidates.
// search_from_candidates() will always damage candidates
// by doing 1 pop_min().
if (nres == 0) {
nres = nres_prev;
}
}
vt.advance();
}
Expand Down Expand Up @@ -970,6 +997,7 @@ void HNSW::search_level_0(
0,
nres,
params);
nres = std::min(nres, candidates_size);
}
} else if (search_type == 2) {
int candidates_size = std::max(efSearch, int(k));
Expand Down Expand Up @@ -1051,7 +1079,99 @@ void HNSW::MinimaxHeap::clear() {
nvalid = k = 0;
}

#ifdef __AVX2__
#ifdef __AVX512F__

int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
assert(k > 0);
static_assert(
std::is_same<storage_idx_t, int32_t>::value,
"This code expects storage_idx_t to be int32_t");

int32_t min_idx = -1;
float min_dis = std::numeric_limits<float>::infinity();

__m512i min_indices = _mm512_set1_epi32(-1);
__m512 min_distances =
_mm512_set1_ps(std::numeric_limits<float>::infinity());
__m512i current_indices = _mm512_setr_epi32(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
__m512i offset = _mm512_set1_epi32(16);

// The following loop tracks the rightmost index with the min distance.
// -1 index values are ignored.
const int k16 = (k / 16) * 16;
for (size_t iii = 0; iii < k16; iii += 16) {
__m512i indices =
_mm512_loadu_si512((const __m512i*)(ids.data() + iii));
__m512 distances = _mm512_loadu_ps(dis.data() + iii);

// This mask filters out -1 values among indices.
__mmask16 m1mask =
_mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);

__mmask16 dmask =
_mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
__mmask16 finalmask = m1mask | dmask;

const __m512i min_indices_new = _mm512_mask_blend_epi32(
finalmask, current_indices, min_indices);
const __m512 min_distances_new =
_mm512_mask_blend_ps(finalmask, distances, min_distances);

min_indices = min_indices_new;
min_distances = min_distances_new;

current_indices = _mm512_add_epi32(current_indices, offset);
}

// leftovers
if (k16 != k) {
const __mmask16 kmask = (1 << (k - k16)) - 1;

__m512i indices = _mm512_mask_loadu_epi32(
_mm512_set1_epi32(-1), kmask, ids.data() + k16);
__m512 distances = _mm512_maskz_loadu_ps(kmask, dis.data() + k16);

// This mask filters out -1 values among indices.
__mmask16 m1mask =
_mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);

__mmask16 dmask =
_mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
__mmask16 finalmask = m1mask | dmask;

const __m512i min_indices_new = _mm512_mask_blend_epi32(
finalmask, current_indices, min_indices);
const __m512 min_distances_new =
_mm512_mask_blend_ps(finalmask, distances, min_distances);

min_indices = min_indices_new;
min_distances = min_distances_new;
}

// grab min distance
min_dis = _mm512_reduce_min_ps(min_distances);
// blend
__mmask16 mindmask =
_mm512_cmpeq_ps_mask(min_distances, _mm512_set1_ps(min_dis));
// pick the max one
min_idx = _mm512_mask_reduce_max_epi32(mindmask, min_indices);

if (min_idx == -1) {
return -1;
}

if (vmin_out) {
*vmin_out = min_dis;
}
int ret = ids[min_idx];
ids[min_idx] = -1;
--nvalid;
return ret;
}

#elif __AVX2__

int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
assert(k > 0);
static_assert(
Expand Down
Loading