diff --git a/thirdparty/faiss/faiss/IndexFlat.h b/thirdparty/faiss/faiss/IndexFlat.h index 165b06aaa..4da54a6b5 100644 --- a/thirdparty/faiss/faiss/IndexFlat.h +++ b/thirdparty/faiss/faiss/IndexFlat.h @@ -18,9 +18,6 @@ namespace faiss { /** Index that stores the full vectors and performs exhaustive search */ struct IndexFlat : IndexFlatCodes { - /// database vectors, size ntotal * d - std::vector xb; - explicit IndexFlat( idx_t d, ///< dimensionality of the input vectors MetricType metric = METRIC_L2, diff --git a/thirdparty/faiss/faiss/IndexHNSW.cpp b/thirdparty/faiss/faiss/IndexHNSW.cpp index c0bb81c05..f9732ed38 100644 --- a/thirdparty/faiss/faiss/IndexHNSW.cpp +++ b/thirdparty/faiss/faiss/IndexHNSW.cpp @@ -275,7 +275,7 @@ void hnsw_search( FAISS_THROW_IF_NOT_MSG(params, "params type invalid"); efSearch = params->efSearch; } - size_t n1 = 0, n2 = 0, ndis = 0; + size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0; idx_t check_period = InterruptCallback::get_period_hint( hnsw.max_level * index->d * efSearch); @@ -283,7 +283,7 @@ void hnsw_search( for (idx_t i0 = 0; i0 < n; i0 += check_period) { idx_t i1 = std::min(i0 + check_period, n); -#pragma omp parallel +#pragma omp parallel if (i1 - i0 > 1) { VisitedTable vt(index->ntotal); typename BlockResultHandler::SingleResultHandler res(bres); @@ -291,7 +291,7 @@ void hnsw_search( std::unique_ptr dis( storage_distance_computer(index->storage)); -#pragma omp for reduction(+ : n1, n2, ndis) schedule(guided) +#pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided) for (idx_t i = i0; i < i1; i++) { res.begin(i); dis->set_query(x + i * index->d); @@ -300,13 +300,14 @@ void hnsw_search( n1 += stats.n1; n2 += stats.n2; ndis += stats.ndis; + nhops += stats.nhops; res.end(); } } InterruptCallback::check(); } - hnsw_stats.combine({n1, n2, ndis}); + hnsw_stats.combine({n1, n2, ndis, nhops}); } } // anonymous namespace @@ -632,6 +633,10 @@ void IndexHNSW::permute_entries(const idx_t* perm) { hnsw.permute_entries(perm); } +DistanceComputer* IndexHNSW::get_distance_computer() const { + return storage->get_distance_computer(); +} + /************************************************************** * IndexHNSWFlat implementation **************************************************************/ @@ -655,8 +660,13 @@ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric) IndexHNSWPQ::IndexHNSWPQ() = default; -IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits) - : IndexHNSW(new IndexPQ(d, pq_m, pq_nbits), M) { +IndexHNSWPQ::IndexHNSWPQ( + int d, + int pq_m, + int M, + int pq_nbits, + MetricType metric) + : IndexHNSW(new IndexPQ(d, pq_m, pq_nbits, metric), M) { own_fields = true; is_trained = false; } @@ -782,7 +792,7 @@ void IndexHNSW2Level::search( IndexHNSW::search(n, x, k, distances, labels); } else { // "mixed" search - size_t n1 = 0, n2 = 0, ndis = 0; + size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0; const IndexIVFPQ* index_ivfpq = dynamic_cast(storage); @@ -814,7 +824,7 @@ void IndexHNSW2Level::search( int candidates_size = hnsw.upper_beam; MinimaxHeap candidates(candidates_size); -#pragma omp for reduction(+ : n1, n2, ndis) +#pragma omp for reduction(+ : n1, n2, ndis, nhops) for (idx_t i = 0; i < n; i++) { idx_t* idxi = labels + i * k; float* simi = distances + i * k; @@ -860,6 +870,7 @@ void IndexHNSW2Level::search( n1 += search_stats.n1; n2 += search_stats.n2; ndis += search_stats.ndis; + nhops += search_stats.nhops; vt.advance(); vt.advance(); @@ -868,7 +879,7 @@ void IndexHNSW2Level::search( } } - hnsw_stats.combine({n1, n2, ndis}); + hnsw_stats.combine({n1, n2, ndis, nhops}); } } diff --git a/thirdparty/faiss/faiss/IndexHNSW.h b/thirdparty/faiss/faiss/IndexHNSW.h index 71807c653..0768eb88b 100644 --- a/thirdparty/faiss/faiss/IndexHNSW.h +++ b/thirdparty/faiss/faiss/IndexHNSW.h @@ -27,7 +27,7 @@ struct IndexHNSW; struct IndexHNSW : Index { typedef HNSW::storage_idx_t storage_idx_t; - // the link strcuture + // the link structure HNSW hnsw; // the sequential storage @@ -111,6 +111,8 @@ struct IndexHNSW : Index { void link_singletons(); void permute_entries(const idx_t* perm); + + DistanceComputer* get_distance_computer() const override; }; /** Flat index topped with with a HNSW structure to access elements @@ -127,7 +129,12 @@ struct IndexHNSWFlat : IndexHNSW { */ struct IndexHNSWPQ : IndexHNSW { IndexHNSWPQ(); - IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits = 8); + IndexHNSWPQ( + int d, + int pq_m, + int M, + int pq_nbits = 8, + MetricType metric = METRIC_L2); void train(idx_t n, const float* x) override; }; diff --git a/thirdparty/faiss/faiss/IndexRefine.cpp b/thirdparty/faiss/faiss/IndexRefine.cpp index c504718e5..a65664cc6 100644 --- a/thirdparty/faiss/faiss/IndexRefine.cpp +++ b/thirdparty/faiss/faiss/IndexRefine.cpp @@ -69,12 +69,12 @@ template static void reorder_2_heaps( idx_t n, idx_t k, - idx_t* labels, - float* distances, + idx_t* __restrict labels, + float* __restrict distances, idx_t k_base, - const idx_t* base_labels, - const float* base_distances) { -#pragma omp parallel for + const idx_t* __restrict base_labels, + const float* __restrict base_distances) { +#pragma omp parallel for if (n > 1) for (idx_t i = 0; i < n; i++) { idx_t* idxo = labels + i * k; float* diso = distances + i * k; diff --git a/thirdparty/faiss/faiss/impl/HNSW.cpp b/thirdparty/faiss/faiss/impl/HNSW.cpp index 3ba5f72f6..3647f0f6f 100644 --- a/thirdparty/faiss/faiss/impl/HNSW.cpp +++ b/thirdparty/faiss/faiss/impl/HNSW.cpp @@ -409,18 +409,22 @@ void search_neighbors_to_add( **************************************************************/ /// greedily update a nearest vector at a given level -void greedy_update_nearest( +HNSWStats greedy_update_nearest( const HNSW& hnsw, DistanceComputer& qdis, int level, storage_idx_t& nearest, float& d_nearest) { + HNSWStats stats; + for (;;) { storage_idx_t prev_nearest = nearest; size_t begin, end; hnsw.neighbor_range(nearest, level, &begin, &end); - for (size_t i = begin; i < end; i++) { + + size_t ndis = 0; + for (size_t i = begin; i < end; i++, ndis++) { storage_idx_t v = hnsw.neighbors[i]; if (v < 0) break; @@ -430,8 +434,13 @@ void greedy_update_nearest( d_nearest = dis; } } + + // update stats + stats.ndis += ndis; + stats.nhops += 1; + if (nearest == prev_nearest) { - return; + return stats; } } } @@ -641,6 +650,7 @@ int search_from_candidates( if (dis < threshold) { if (res.add_result(dis, idx)) { threshold = res.threshold; + nres += 1; } } } @@ -692,6 +702,7 @@ int search_from_candidates( stats.n2++; } stats.ndis += ndis; + stats.nhops += nstep; } return nres; @@ -814,6 +825,8 @@ std::priority_queue search_from_candidate_unbounded( float dis = qdis(saved_j[icnt]); add_to_heap(saved_j[icnt], dis); } + + stats.nhops += 1; } ++stats.n1; @@ -853,7 +866,9 @@ HNSWStats HNSW::search( float d_nearest = qdis(nearest); for (int level = max_level; level >= 1; level--) { - greedy_update_nearest(*this, qdis, level, nearest, d_nearest); + HNSWStats local_stats = greedy_update_nearest( + *this, qdis, level, nearest, d_nearest); + stats.combine(local_stats); } int ef = std::max(params ? params->efSearch : efSearch, k); @@ -916,11 +931,23 @@ HNSWStats HNSW::search( if (level == 0) { nres = search_from_candidates( *this, qdis, res, candidates, vt, stats, 0); + nres = std::min(nres, candidates_size); } else { + const auto nres_prev = nres; + resh.begin(0); nres = search_from_candidates( *this, qdis, resh, candidates, vt, stats, level); + nres = std::min(nres, candidates_size); resh.end(); + + // if the search on a particular level produces no improvements, + // then we need to repopulate candidates. + // search_from_candidates() will always damage candidates + // by doing 1 pop_min(). + if (nres == 0) { + nres = nres_prev; + } } vt.advance(); } @@ -970,6 +997,7 @@ void HNSW::search_level_0( 0, nres, params); + nres = std::min(nres, candidates_size); } } else if (search_type == 2) { int candidates_size = std::max(efSearch, int(k)); @@ -1051,7 +1079,99 @@ void HNSW::MinimaxHeap::clear() { nvalid = k = 0; } -#ifdef __AVX2__ +#ifdef __AVX512F__ + +int HNSW::MinimaxHeap::pop_min(float* vmin_out) { + assert(k > 0); + static_assert( + std::is_same::value, + "This code expects storage_idx_t to be int32_t"); + + int32_t min_idx = -1; + float min_dis = std::numeric_limits::infinity(); + + __m512i min_indices = _mm512_set1_epi32(-1); + __m512 min_distances = + _mm512_set1_ps(std::numeric_limits::infinity()); + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + __m512i offset = _mm512_set1_epi32(16); + + // The following loop tracks the rightmost index with the min distance. + // -1 index values are ignored. + const int k16 = (k / 16) * 16; + for (size_t iii = 0; iii < k16; iii += 16) { + __m512i indices = + _mm512_loadu_si512((const __m512i*)(ids.data() + iii)); + __m512 distances = _mm512_loadu_ps(dis.data() + iii); + + // This mask filters out -1 values among indices. + __mmask16 m1mask = + _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices); + + __mmask16 dmask = + _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS); + __mmask16 finalmask = m1mask | dmask; + + const __m512i min_indices_new = _mm512_mask_blend_epi32( + finalmask, current_indices, min_indices); + const __m512 min_distances_new = + _mm512_mask_blend_ps(finalmask, distances, min_distances); + + min_indices = min_indices_new; + min_distances = min_distances_new; + + current_indices = _mm512_add_epi32(current_indices, offset); + } + + // leftovers + if (k16 != k) { + const __mmask16 kmask = (1 << (k - k16)) - 1; + + __m512i indices = _mm512_mask_loadu_epi32( + _mm512_set1_epi32(-1), kmask, ids.data() + k16); + __m512 distances = _mm512_maskz_loadu_ps(kmask, dis.data() + k16); + + // This mask filters out -1 values among indices. + __mmask16 m1mask = + _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices); + + __mmask16 dmask = + _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS); + __mmask16 finalmask = m1mask | dmask; + + const __m512i min_indices_new = _mm512_mask_blend_epi32( + finalmask, current_indices, min_indices); + const __m512 min_distances_new = + _mm512_mask_blend_ps(finalmask, distances, min_distances); + + min_indices = min_indices_new; + min_distances = min_distances_new; + } + + // grab min distance + min_dis = _mm512_reduce_min_ps(min_distances); + // blend + __mmask16 mindmask = + _mm512_cmpeq_ps_mask(min_distances, _mm512_set1_ps(min_dis)); + // pick the max one + min_idx = _mm512_mask_reduce_max_epi32(mindmask, min_indices); + + if (min_idx == -1) { + return -1; + } + + if (vmin_out) { + *vmin_out = min_dis; + } + int ret = ids[min_idx]; + ids[min_idx] = -1; + --nvalid; + return ret; +} + +#elif __AVX2__ + int HNSW::MinimaxHeap::pop_min(float* vmin_out) { assert(k > 0); static_assert( diff --git a/thirdparty/faiss/faiss/impl/HNSW.h b/thirdparty/faiss/faiss/impl/HNSW.h index f3aacf8a5..f376c9fcc 100644 --- a/thirdparty/faiss/faiss/impl/HNSW.h +++ b/thirdparty/faiss/faiss/impl/HNSW.h @@ -234,20 +234,23 @@ struct HNSW { }; struct HNSWStats { - size_t n1 = 0; /// numbner of vectors searched + size_t n1 = 0; /// number of vectors searched size_t n2 = - 0; /// number of queries for which the candidate list is exhasted - size_t ndis = 0; /// number of distances computed + 0; /// number of queries for which the candidate list is exhausted + size_t ndis = 0; /// number of distances computed + size_t nhops = 0; /// number of hops aka number of edges traversed void reset() { n1 = n2 = 0; ndis = 0; + nhops = 0; } void combine(const HNSWStats& other) { n1 += other.n1; n2 += other.n2; ndis += other.ndis; + nhops += other.nhops; } }; diff --git a/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h b/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h index 64e4c4a56..814c5aea1 100644 --- a/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h +++ b/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h @@ -263,8 +263,8 @@ struct Quantizer8bitDirect_avx512<16> : public Quantizer8bitDirect_avx<8> { FAISS_ALWAYS_INLINE __m512 reconstruct_16_components(const uint8_t* code, int i) const { - __m256i x16 = _mm256_loadu_si256((__m256i*)(code + i)); // 16 * int8 - __m512i y16 = _mm512_cvtepu8_epi16(x16); // 16 * int32 + __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8 + __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32 return _mm512_cvtepi32_ps(y16); // 16 * float32 } }; @@ -295,8 +295,8 @@ struct Quantizer8bitDirectSigned_avx512<16> : public Quantizer8bitDirectSigned_a FAISS_ALWAYS_INLINE __m512 reconstruct_16_components(const uint8_t* code, int i) const { - __m256i x16 = _mm256_loadu_si256((__m256i*)(code + i)); // 16 * int8 - __m512i y16 = _mm512_cvtepu8_epi16(x16); // 16 * int32 + __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8 + __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32 __m512i c16 = _mm512_set1_epi32(128); __m512i z16 = _mm512_sub_epi32(y16, c16); // subtract 128 from all lanes return _mm512_cvtepi32_ps(z16); // 16 * float32 diff --git a/thirdparty/faiss/faiss/impl/code_distance/code_distance-avx512.h b/thirdparty/faiss/faiss/impl/code_distance/code_distance-avx512.h new file mode 100644 index 000000000..6c6afc7e0 --- /dev/null +++ b/thirdparty/faiss/faiss/impl/code_distance/code_distance-avx512.h @@ -0,0 +1,248 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#ifdef __AVX512F__ + +#include + +#include + +#include +#include + +namespace faiss { + +// According to experiments, the AVX-512 version may be SLOWER than +// the AVX2 version, which is somewhat unexpected. +// This version is not used for now, but it may be used later. +// +// TODO: test for AMD CPUs. + +template +typename std::enable_if::value, float>:: + type inline distance_single_code_avx512( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + const uint8_t* code) { + // default implementation + return distance_single_code_generic(M, nbits, sim_table, code); +} + +template +typename std::enable_if::value, float>:: + type inline distance_single_code_avx512( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + const uint8_t* code0) { + float result0 = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + constexpr intptr_t N = 1; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + const __m512i vksub = _mm512_set1_epi32(ksub); + __m512i offsets_0 = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m512 partialSums[N]; + for (intptr_t j = 0; j < N; j++) { + partialSums[j] = _mm512_setzero_ps(); + } + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); + + // process first 8 codes + for (intptr_t j = 0; j < N; j++) { + const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m512i indices_to_read_from = + _mm512_add_epi32(idx1, offsets_0); + + // gather 16 values, similar to 16 operations of tab[idx] + __m512 collected = _mm512_i32gather_ps( + indices_to_read_from, tab, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm512_add_ps(partialSums[j], collected); + } + tab += ksub * 16; + } + + // horizontal sum for partialSum + result0 += _mm512_reduce_add_ps(partialSums[0]); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 decoder0(code0 + m, nbits); + for (; m < M; m++) { + result0 += tab[decoder0.decode()]; + tab += ksub; + } + } + + return result0; +} + +template +typename std::enable_if::value, void>:: + type + distance_four_codes_avx512( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + distance_four_codes_generic( + M, + nbits, + sim_table, + code0, + code1, + code2, + code3, + result0, + result1, + result2, + result3); +} + +// Combines 4 operations of distance_single_code() +template +typename std::enable_if::value, void>::type +distance_four_codes_avx512( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + result0 = 0; + result1 = 0; + result2 = 0; + result3 = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + constexpr intptr_t N = 4; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + const __m512i vksub = _mm512_set1_epi32(ksub); + __m512i offsets_0 = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m512 partialSums[N]; + for (intptr_t j = 0; j < N; j++) { + partialSums[j] = _mm512_setzero_ps(); + } + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); + mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); + mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); + mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); + + // process first 8 codes + for (intptr_t j = 0; j < N; j++) { + const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m512i indices_to_read_from = + _mm512_add_epi32(idx1, offsets_0); + + // gather 16 values, similar to 16 operations of tab[idx] + __m512 collected = _mm512_i32gather_ps( + indices_to_read_from, tab, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm512_add_ps(partialSums[j], collected); + } + tab += ksub * 16; + } + + // horizontal sum for partialSum + result0 += _mm512_reduce_add_ps(partialSums[0]); + result1 += _mm512_reduce_add_ps(partialSums[1]); + result2 += _mm512_reduce_add_ps(partialSums[2]); + result3 += _mm512_reduce_add_ps(partialSums[3]); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 decoder0(code0 + m, nbits); + PQDecoder8 decoder1(code1 + m, nbits); + PQDecoder8 decoder2(code2 + m, nbits); + PQDecoder8 decoder3(code3 + m, nbits); + for (; m < M; m++) { + result0 += tab[decoder0.decode()]; + result1 += tab[decoder1.decode()]; + result2 += tab[decoder2.decode()]; + result3 += tab[decoder3.decode()]; + tab += ksub; + } + } +} + +} // namespace faiss + +#endif diff --git a/thirdparty/faiss/faiss/impl/code_distance/code_distance_avx512.h b/thirdparty/faiss/faiss/impl/code_distance/code_distance_avx512.h deleted file mode 100644 index 296e0df1b..000000000 --- a/thirdparty/faiss/faiss/impl/code_distance/code_distance_avx512.h +++ /dev/null @@ -1,102 +0,0 @@ -/** - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -// // // AVX-512 version. It is not used, but let it be for the future -// // // needs. -// // template -// // typename std::enable_if<(std::is_same::value), void>:: -// // type distance_four_codes( -// // const uint8_t* __restrict code0, -// // const uint8_t* __restrict code1, -// // const uint8_t* __restrict code2, -// // const uint8_t* __restrict code3, -// // float& result0, -// // float& result1, -// // float& result2, -// // float& result3 -// // ) const { -// // result0 = 0; -// // result1 = 0; -// // result2 = 0; -// // result3 = 0; - -// // size_t m = 0; -// // const size_t pqM16 = pq.M / 16; - -// // constexpr intptr_t N = 4; - -// // const float* tab = sim_table; - -// // if (pqM16 > 0) { -// // // process 16 values per loop -// // const __m512i ksub = _mm512_set1_epi32(pq.ksub); -// // __m512i offsets_0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, -// // 8, 9, 10, 11, 12, 13, 14, 15); -// // offsets_0 = _mm512_mullo_epi32(offsets_0, ksub); - -// // // accumulators of partial sums -// // __m512 partialSums[N]; -// // for (intptr_t j = 0; j < N; j++) { -// // partialSums[j] = _mm512_setzero_ps(); -// // } - -// // // loop -// // for (m = 0; m < pqM16 * 16; m += 16) { -// // // load 16 uint8 values -// // __m128i mm1[N]; -// // mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); -// // mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); -// // mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); -// // mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); - -// // // process first 8 codes -// // for (intptr_t j = 0; j < N; j++) { -// // // convert uint8 values (low part of __m128i) to int32 -// // // values -// // const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); - -// // // add offsets -// // const __m512i indices_to_read_from = -// // _mm512_add_epi32(idx1, offsets_0); - -// // // gather 8 values, similar to 8 operations of -// // // tab[idx] -// // __m512 collected = -// // _mm512_i32gather_ps( -// // indices_to_read_from, tab, sizeof(float)); - -// // // collect partial sums -// // partialSums[j] = _mm512_add_ps(partialSums[j], -// // collected); -// // } -// // tab += pq.ksub * 16; - -// // } - -// // // horizontal sum for partialSum -// // result0 += _mm512_reduce_add_ps(partialSums[0]); -// // result1 += _mm512_reduce_add_ps(partialSums[1]); -// // result2 += _mm512_reduce_add_ps(partialSums[2]); -// // result3 += _mm512_reduce_add_ps(partialSums[3]); -// // } - -// // // -// // if (m < pq.M) { -// // // process leftovers -// // PQDecoder decoder0(code0 + m, pq.nbits); -// // PQDecoder decoder1(code1 + m, pq.nbits); -// // PQDecoder decoder2(code2 + m, pq.nbits); -// // PQDecoder decoder3(code3 + m, pq.nbits); -// // for (; m < pq.M; m++) { -// // result0 += tab[decoder0.decode()]; -// // result1 += tab[decoder1.decode()]; -// // result2 += tab[decoder2.decode()]; -// // result3 += tab[decoder3.decode()]; -// // tab += pq.ksub; -// // } -// // } -// // } diff --git a/thirdparty/faiss/faiss/utils/Heap.h b/thirdparty/faiss/faiss/utils/Heap.h index cdb714f4d..b67707ecb 100644 --- a/thirdparty/faiss/faiss/utils/Heap.h +++ b/thirdparty/faiss/faiss/utils/Heap.h @@ -30,6 +30,7 @@ #include #include +#include #include @@ -200,6 +201,110 @@ inline void maxheap_replace_top( heap_replace_top>(k, bh_val, bh_ids, val, ids); } +/******************************************************************* + * Basic heap> ops: push and pop + *******************************************************************/ + +// This section contains a heap implementation that works with +// std::pair elements. + +/** Pops the top element from the heap defined by bh_val[0..k-1] and + * bh_ids[0..k-1]. on output the element at k-1 is undefined. + */ +template +inline void heap_pop(size_t k, std::pair* bh) { + bh--; /* Use 1-based indexing for easier node->child translation */ + typename C::T val = bh[k].first; + typename C::TI id = bh[k].second; + size_t i = 1, i1, i2; + while (1) { + i1 = i << 1; + i2 = i1 + 1; + if (i1 > k) + break; + if ((i2 == k + 1) || + C::cmp2(bh[i1].first, bh[i2].first, bh[i1].second, bh[i2].second)) { + if (C::cmp2(val, bh[i1].first, id, bh[i1].second)) { + break; + } + bh[i] = bh[i1]; + i = i1; + } else { + if (C::cmp2(val, bh[i2].first, id, bh[i2].second)) { + break; + } + bh[i] = bh[i2]; + i = i2; + } + } + bh[i] = bh[k]; +} + +/** Pushes the element (val, ids) into the heap bh_val[0..k-2] and + * bh_ids[0..k-2]. on output the element at k-1 is defined. + */ +template +inline void heap_push( + size_t k, + std::pair* bh, + typename C::T val, + typename C::TI id) { + bh--; /* Use 1-based indexing for easier node->child translation */ + size_t i = k, i_father; + while (i > 1) { + i_father = i >> 1; + auto bh_v = bh[i_father]; + if (!C::cmp2(val, bh_v.first, id, bh_v.second)) { + /* the heap structure is ok */ + break; + } + bh[i] = bh_v; + i = i_father; + } + bh[i] = std::make_pair(val, id); +} + +/** + * Replaces the top element from the heap defined by bh_val[0..k-1] and + * bh_ids[0..k-1], and for identical bh_val[] values also sorts by bh_ids[] + * values. + */ +template +inline void heap_replace_top( + size_t k, + std::pair* bh, + typename C::T val, + typename C::TI id) { + bh--; /* Use 1-based indexing for easier node->child translation */ + size_t i = 1, i1, i2; + while (1) { + i1 = i << 1; + i2 = i1 + 1; + if (i1 > k) { + break; + } + + // Note that C::cmp2() is a bool function answering + // `(a1 > b1) || ((a1 == b1) && (a2 > b2))` for max + // heap and same with the `<` sign for min heap. + if ((i2 == k + 1) || + C::cmp2(bh[i1].first, bh[i2].first, bh[i1].second, bh[i2].second)) { + if (C::cmp2(val, bh[i1].first, id, bh[i1].second)) { + break; + } + bh[i] = bh[i1]; + i = i1; + } else { + if (C::cmp2(val, bh[i2].first, id, bh[i2].second)) { + break; + } + bh[i] = bh[i2]; + i = i2; + } + } + bh[i] = std::make_pair(val, id); +} + /******************************************************************* * Heap initialization *******************************************************************/