From 5c729b1b648f4bb9c21bc0356c96e9b4dbe0578c Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Thu, 21 Mar 2024 12:40:03 -0400 Subject: [PATCH 1/6] more filtering Signed-off-by: Alexandr Guzhva --- faiss/IVFlib.cpp | 3 +- faiss/IndexFastScan.cpp | 10 +- faiss/IndexIVFFastScan.cpp | 179 +++++++++++++++++++---------- faiss/IndexIVFFastScan.h | 21 ++-- faiss/impl/simd_result_handlers.h | 183 ++++++++++++++++++++++-------- 5 files changed, 277 insertions(+), 119 deletions(-) diff --git a/faiss/IVFlib.cpp b/faiss/IVFlib.cpp index 91aa7af7f3..d7dfbaf582 100644 --- a/faiss/IVFlib.cpp +++ b/faiss/IVFlib.cpp @@ -352,7 +352,8 @@ void search_with_parameters( const IndexIVF* index_ivf = dynamic_cast(index); FAISS_THROW_IF_NOT(index_ivf); - index_ivf->quantizer->search(n, x, params->nprobe, Dq.data(), Iq.data()); + SearchParameters* quantizer_params = (params) ? params->quantizer_params : nullptr; + index_ivf->quantizer->search(n, x, params->nprobe, Dq.data(), Iq.data(), quantizer_params); if (nb_dis_ptr) { *nb_dis_ptr = count_ndis(index_ivf, n * params->nprobe, Iq.data()); diff --git a/faiss/IndexFastScan.cpp b/faiss/IndexFastScan.cpp index 2dfb2f55fd..529465da3e 100644 --- a/faiss/IndexFastScan.cpp +++ b/faiss/IndexFastScan.cpp @@ -189,6 +189,7 @@ void estimators_from_tables_generic( dt += index.ksub; } } + if (C::cmp(heap_dis[0], dis)) { heap_pop(k, heap_dis, heap_ids); heap_push(k, heap_dis, heap_ids, dis, j); @@ -203,17 +204,18 @@ ResultHandlerCompare* make_knn_handler( idx_t k, size_t ntotal, float* distances, - idx_t* labels) { + idx_t* labels, + const IDSelector* sel = nullptr) { using HeapHC = HeapHandler; using ReservoirHC = ReservoirHandler; using SingleResultHC = SingleResultHandler; if (k == 1) { - return new SingleResultHC(n, ntotal, distances, labels); + return new SingleResultHC(n, ntotal, distances, labels, sel); } else if (impl % 2 == 0) { - return new HeapHC(n, ntotal, k, distances, labels); + return new HeapHC(n, ntotal, k, distances, labels, sel); } else /* if (impl % 2 == 1) */ { - return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels); + return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel); } } diff --git a/faiss/IndexIVFFastScan.cpp b/faiss/IndexIVFFastScan.cpp index 00bc6c823e..e454e9140d 100644 --- a/faiss/IndexIVFFastScan.cpp +++ b/faiss/IndexIVFFastScan.cpp @@ -211,7 +211,7 @@ void estimators_from_tables_generic( int64_t* heap_ids, const NormTableScaler* scaler) { using accu_t = typename C::T; - int nscale = scaler ? scaler->nscale : 0; + size_t nscale = scaler ? scaler->nscale : 0; for (size_t j = 0; j < ncodes; ++j) { BitstringReader bsr(codes + j * index.code_size, index.code_size); accu_t dis = bias; @@ -271,7 +271,7 @@ void IndexIVFFastScan::compute_LUT_uint8( } #pragma omp parallel for if (n > 100) - for (int64_t i = 0; i < n; i++) { + for (size_t i = 0; i < n; i++) { const float* t_in = dis_tables_float.get() + i * dim123; const float* b_in = nullptr; uint8_t* t_out = dis_tables.get() + i * dim123_2; @@ -306,11 +306,16 @@ void IndexIVFFastScan::search( idx_t k, float* distances, idx_t* labels, - const SearchParameters* params) const { - auto paramsi = dynamic_cast(params); - FAISS_THROW_IF_NOT_MSG(!params || paramsi, "need IVFSearchParameters"); + const SearchParameters* params_in) const { + const IVFSearchParameters* params = nullptr; + if (params_in) { + params = dynamic_cast(params_in); + FAISS_THROW_IF_NOT_MSG( + params, "IndexIVFFastScan params have incorrect type"); + } + search_preassigned( - n, x, k, nullptr, nullptr, distances, labels, false, paramsi); + n, x, k, nullptr, nullptr, distances, labels, false, params); } void IndexIVFFastScan::search_preassigned( @@ -326,18 +331,17 @@ void IndexIVFFastScan::search_preassigned( IndexIVFStats* stats) const { size_t nprobe = this->nprobe; if (params) { - FAISS_THROW_IF_NOT_MSG( - !params->quantizer_params, "quantizer params not supported"); FAISS_THROW_IF_NOT(params->max_codes == 0); nprobe = params->nprobe; } + FAISS_THROW_IF_NOT_MSG( !store_pairs, "store_pairs not supported for this index"); FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index"); FAISS_THROW_IF_NOT(k > 0); const CoarseQuantized cq = {nprobe, centroid_dis, assign}; - search_dispatch_implem(n, x, k, distances, labels, cq, nullptr); + search_dispatch_implem(n, x, k, distances, labels, cq, nullptr, params); } void IndexIVFFastScan::range_search( @@ -345,10 +349,18 @@ void IndexIVFFastScan::range_search( const float* x, float radius, RangeSearchResult* result, - const SearchParameters* params) const { - FAISS_THROW_IF_NOT(!params); + const SearchParameters* params_in) const { + size_t nprobe = this->nprobe; + const IVFSearchParameters* params = nullptr; + if (params_in) { + params = dynamic_cast(params_in); + FAISS_THROW_IF_NOT_MSG( + params, "IndexIVFFastScan params have incorrect type"); + nprobe = params->nprobe; + } + const CoarseQuantized cq = {nprobe, nullptr, nullptr}; - range_search_dispatch_implem(n, x, radius, *result, cq, nullptr); + range_search_dispatch_implem(n, x, radius, *result, cq, nullptr, params); } namespace { @@ -359,17 +371,18 @@ ResultHandlerCompare* make_knn_handler_fixC( idx_t n, idx_t k, float* distances, - idx_t* labels) { + idx_t* labels, + const IDSelector* sel) { using HeapHC = HeapHandler; using ReservoirHC = ReservoirHandler; using SingleResultHC = SingleResultHandler; if (k == 1) { - return new SingleResultHC(n, 0, distances, labels); + return new SingleResultHC(n, 0, distances, labels, sel); } else if (impl % 2 == 0) { - return new HeapHC(n, 0, k, distances, labels); + return new HeapHC(n, 0, k, distances, labels, sel); } else /* if (impl % 2 == 1) */ { - return new ReservoirHC(n, 0, k, 2 * k, distances, labels); + return new ReservoirHC(n, 0, k, 2 * k, distances, labels, sel); } } @@ -379,13 +392,14 @@ SIMDResultHandlerToFloat* make_knn_handler( idx_t n, idx_t k, float* distances, - idx_t* labels) { + idx_t* labels, + const IDSelector* sel) { if (is_max) { return make_knn_handler_fixC>( - impl, n, k, distances, labels); + impl, n, k, distances, labels, sel); } else { return make_knn_handler_fixC>( - impl, n, k, distances, labels); + impl, n, k, distances, labels, sel); } } @@ -402,10 +416,15 @@ struct CoarseQuantizedWithBuffer : CoarseQuantized { std::vector ids_buffer; std::vector dis_buffer; - void quantize(const Index* quantizer, idx_t n, const float* x) { + void quantize( + const Index* quantizer, + idx_t n, + const float* x, + const SearchParameters* quantizer_params + ) { dis_buffer.resize(nprobe * n); ids_buffer.resize(nprobe * n); - quantizer->search(n, x, nprobe, dis_buffer.data(), ids_buffer.data()); + quantizer->search(n, x, nprobe, dis_buffer.data(), ids_buffer.data(), quantizer_params); dis = dis_buffer.data(); ids = ids_buffer.data(); } @@ -421,8 +440,8 @@ struct CoarseQuantizedSlice : CoarseQuantizedWithBuffer { } } - void quantize_slice(const Index* quantizer, const float* x) { - quantize(quantizer, i1 - i0, x + quantizer->d * i0); + void quantize_slice(const Index* quantizer, const float* x, const SearchParameters* quantizer_params) { + quantize(quantizer, i1 - i0, x + quantizer->d * i0, quantizer_params); } }; @@ -459,7 +478,13 @@ void IndexIVFFastScan::search_dispatch_implem( float* distances, idx_t* labels, const CoarseQuantized& cq_in, - const NormTableScaler* scaler) const { + const NormTableScaler* scaler, + const IVFSearchParameters* params) const { + const idx_t nprobe = params ? params->nprobe : this->nprobe; + const IDSelector* sel = (params) ? params->sel : nullptr; + const SearchParameters* quantizer_params = + params ? params->quantizer_params : nullptr; + bool is_max = !is_similarity_metric(metric_type); using RH = SIMDResultHandlerToFloat; @@ -489,52 +514,69 @@ void IndexIVFFastScan::search_dispatch_implem( } CoarseQuantizedWithBuffer cq(cq_in); + cq.nprobe = nprobe; if (!cq.done() && !multiple_threads) { // we do the coarse quantization here execpt when search is // sliced over threads (then it is more efficient to have each thread do // its own coarse quantization) - cq.quantize(quantizer, n, x); + cq.quantize(quantizer, n, x, quantizer_params); } if (impl == 1) { if (is_max) { search_implem_1>( - n, x, k, distances, labels, cq, scaler); + n, x, k, distances, labels, cq, scaler, params); } else { search_implem_1>( - n, x, k, distances, labels, cq, scaler); + n, x, k, distances, labels, cq, scaler, params); } } else if (impl == 2) { if (is_max) { search_implem_2>( - n, x, k, distances, labels, cq, scaler); + n, x, k, distances, labels, cq, scaler, params); } else { search_implem_2>( - n, x, k, distances, labels, cq, scaler); + n, x, k, distances, labels, cq, scaler, params); } - } else if (impl >= 10 && impl <= 15) { size_t ndis = 0, nlist_visited = 0; if (!multiple_threads) { // clang-format off if (impl == 12 || impl == 13) { - std::unique_ptr handler(make_knn_handler(is_max, impl, n, k, distances, labels)); + std::unique_ptr handler( + make_knn_handler( + is_max, + impl, + n, + k, + distances, + labels, sel + ) + ); search_implem_12( n, x, *handler.get(), - cq, &ndis, &nlist_visited, scaler); - + cq, &ndis, &nlist_visited, scaler, params); } else if (impl == 14 || impl == 15) { - search_implem_14( n, x, k, distances, labels, - cq, impl, scaler); + cq, impl, scaler, params); } else { - std::unique_ptr handler(make_knn_handler(is_max, impl, n, k, distances, labels)); + std::unique_ptr handler( + make_knn_handler( + is_max, + impl, + n, + k, + distances, + labels, + sel + ) + ); search_implem_10( n, x, *handler.get(), cq, - &ndis, &nlist_visited, scaler); + &ndis, &nlist_visited, scaler, params); } // clang-format on } else { @@ -543,7 +585,7 @@ void IndexIVFFastScan::search_dispatch_implem( if (impl == 14 || impl == 15) { // this might require slicing if there are too // many queries (for now we keep this simple) - search_implem_14(n, x, k, distances, labels, cq, impl, scaler); + search_implem_14(n, x, k, distances, labels, cq, impl, scaler, params); } else { #pragma omp parallel for reduction(+ : ndis, nlist_visited) for (int slice = 0; slice < nslice; slice++) { @@ -553,19 +595,19 @@ void IndexIVFFastScan::search_dispatch_implem( idx_t* lab_i = labels + i0 * k; CoarseQuantizedSlice cq_i(cq, i0, i1); if (!cq_i.done()) { - cq_i.quantize_slice(quantizer, x); + cq_i.quantize_slice(quantizer, x, quantizer_params); } std::unique_ptr handler(make_knn_handler( - is_max, impl, i1 - i0, k, dis_i, lab_i)); + is_max, impl, i1 - i0, k, dis_i, lab_i, sel)); // clang-format off if (impl == 12 || impl == 13) { search_implem_12( i1 - i0, x + i0 * d, *handler.get(), - cq_i, &ndis, &nlist_visited, scaler); + cq_i, &ndis, &nlist_visited, scaler, params); } else { search_implem_10( i1 - i0, x + i0 * d, *handler.get(), - cq_i, &ndis, &nlist_visited, scaler); + cq_i, &ndis, &nlist_visited, scaler, params); } // clang-format on } @@ -585,7 +627,13 @@ void IndexIVFFastScan::range_search_dispatch_implem( float radius, RangeSearchResult& rres, const CoarseQuantized& cq_in, - const NormTableScaler* scaler) const { + const NormTableScaler* scaler, + const IVFSearchParameters* params) const { + // const idx_t nprobe = params ? params->nprobe : this->nprobe; + const IDSelector* sel = (params) ? params->sel : nullptr; + const SearchParameters* quantizer_params = + params ? params->quantizer_params : nullptr; + bool is_max = !is_similarity_metric(metric_type); if (n == 0) { @@ -613,7 +661,7 @@ void IndexIVFFastScan::range_search_dispatch_implem( } if (!multiple_threads && !cq.done()) { - cq.quantize(quantizer, n, x); + cq.quantize(quantizer, n, x, quantizer_params); } size_t ndis = 0, nlist_visited = 0; @@ -622,10 +670,10 @@ void IndexIVFFastScan::range_search_dispatch_implem( std::unique_ptr handler; if (is_max) { handler.reset(new RangeHandler, true>( - rres, radius, 0)); + rres, radius, 0, sel)); } else { handler.reset(new RangeHandler, true>( - rres, radius, 0)); + rres, radius, 0, sel)); } if (impl == 12) { search_implem_12( @@ -649,17 +697,17 @@ void IndexIVFFastScan::range_search_dispatch_implem( idx_t i1 = n * (slice + 1) / nslice; CoarseQuantizedSlice cq_i(cq, i0, i1); if (!cq_i.done()) { - cq_i.quantize_slice(quantizer, x); + cq_i.quantize_slice(quantizer, x, quantizer_params); } std::unique_ptr handler; if (is_max) { handler.reset(new PartialRangeHandler< CMax, - true>(pres, radius, 0, i0, i1)); + true>(pres, radius, 0, i0, i1, sel)); } else { handler.reset(new PartialRangeHandler< CMin, - true>(pres, radius, 0, i0, i1)); + true>(pres, radius, 0, i0, i1, sel)); } if (impl == 12 || impl == 13) { @@ -670,7 +718,8 @@ void IndexIVFFastScan::range_search_dispatch_implem( cq_i, &ndis, &nlist_visited, - scaler); + scaler, + params); } else { search_implem_10( i1 - i0, @@ -679,7 +728,8 @@ void IndexIVFFastScan::range_search_dispatch_implem( cq_i, &ndis, &nlist_visited, - scaler); + scaler, + params); } } pres.finalize(); @@ -699,7 +749,8 @@ void IndexIVFFastScan::search_implem_1( float* distances, idx_t* labels, const CoarseQuantized& cq, - const NormTableScaler* scaler) const { + const NormTableScaler* scaler, + const IVFSearchParameters* params) const { FAISS_THROW_IF_NOT(orig_invlists); size_t dim12 = ksub * M; @@ -766,7 +817,8 @@ void IndexIVFFastScan::search_implem_2( float* distances, idx_t* labels, const CoarseQuantized& cq, - const NormTableScaler* scaler) const { + const NormTableScaler* scaler, + const IVFSearchParameters* params) const { FAISS_THROW_IF_NOT(orig_invlists); size_t dim12 = ksub * M2; @@ -848,7 +900,12 @@ void IndexIVFFastScan::search_implem_10( const CoarseQuantized& cq, size_t* ndis_out, size_t* nlist_out, - const NormTableScaler* scaler) const { + const NormTableScaler* scaler, + const IVFSearchParameters* params) const { + const size_t max_codes = params ? params->max_codes : this->max_codes; + const SearchParameters* quantizer_params = + params ? params->quantizer_params : nullptr; + size_t dim12 = ksub * M2; AlignedTable dis_tables; AlignedTable biases; @@ -909,6 +966,7 @@ void IndexIVFFastScan::search_implem_10( ndis++; } } + handler.end(); *ndis_out = ndis; *nlist_out = nlist; @@ -921,7 +979,8 @@ void IndexIVFFastScan::search_implem_12( const CoarseQuantized& cq, size_t* ndis_out, size_t* nlist_out, - const NormTableScaler* scaler) const { + const NormTableScaler* scaler, + const IVFSearchParameters* params) const { if (n == 0) { // does not work well with reservoir return; } @@ -933,6 +992,7 @@ void IndexIVFFastScan::search_implem_12( std::unique_ptr normalizers(new float[2 * n]); compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get()); + handler.begin(skip & 16 ? nullptr : normalizers.get()); struct QC { @@ -958,6 +1018,7 @@ void IndexIVFFastScan::search_implem_12( return a.list_no < b.list_no; }); } + // prepare the result handlers int qbs2 = this->qbs2 ? this->qbs2 : 11; @@ -1049,12 +1110,15 @@ void IndexIVFFastScan::search_implem_14( idx_t* labels, const CoarseQuantized& cq, int impl, - const NormTableScaler* scaler) const { + const NormTableScaler* scaler, + const IVFSearchParameters* params) const { if (n == 0) { // does not work well with reservoir return; } FAISS_THROW_IF_NOT(bbs == 32); + const IDSelector* sel = params ? params->sel : nullptr; + size_t dim12 = ksub * M2; AlignedTable dis_tables; AlignedTable biases; @@ -1157,7 +1221,7 @@ void IndexIVFFastScan::search_implem_14( // prepare the result handlers std::unique_ptr handler(make_knn_handler( - is_max, impl, n, k, local_dis.data(), local_idx.data())); + is_max, impl, n, k, local_dis.data(), local_idx.data(), sel)); handler->begin(normalizers.get()); int qbs2 = this->qbs2 ? this->qbs2 : 11; @@ -1167,6 +1231,7 @@ void IndexIVFFastScan::search_implem_14( tmp_bias.resize(qbs2); handler->dbias = tmp_bias.data(); } + std::set q_set; uint64_t t_copy_pack = 0, t_scan = 0; #pragma omp for schedule(dynamic) diff --git a/faiss/IndexIVFFastScan.h b/faiss/IndexIVFFastScan.h index 159a3a7098..9d4c4910d3 100644 --- a/faiss/IndexIVFFastScan.h +++ b/faiss/IndexIVFFastScan.h @@ -148,7 +148,8 @@ struct IndexIVFFastScan : IndexIVF { float* distances, idx_t* labels, const CoarseQuantized& cq, - const NormTableScaler* scaler) const; + const NormTableScaler* scaler, + const IVFSearchParameters* params = nullptr) const; void range_search_dispatch_implem( idx_t n, @@ -156,7 +157,8 @@ struct IndexIVFFastScan : IndexIVF { float radius, RangeSearchResult& rres, const CoarseQuantized& cq_in, - const NormTableScaler* scaler) const; + const NormTableScaler* scaler, + const IVFSearchParameters* params = nullptr) const; // impl 1 and 2 are just for verification template @@ -167,7 +169,8 @@ struct IndexIVFFastScan : IndexIVF { float* distances, idx_t* labels, const CoarseQuantized& cq, - const NormTableScaler* scaler) const; + const NormTableScaler* scaler, + const IVFSearchParameters* params = nullptr) const; template void search_implem_2( @@ -177,7 +180,8 @@ struct IndexIVFFastScan : IndexIVF { float* distances, idx_t* labels, const CoarseQuantized& cq, - const NormTableScaler* scaler) const; + const NormTableScaler* scaler, + const IVFSearchParameters* params = nullptr) const; // implem 10 and 12 are not multithreaded internally, so // export search stats @@ -188,7 +192,8 @@ struct IndexIVFFastScan : IndexIVF { const CoarseQuantized& cq, size_t* ndis_out, size_t* nlist_out, - const NormTableScaler* scaler) const; + const NormTableScaler* scaler, + const IVFSearchParameters* params = nullptr) const; void search_implem_12( idx_t n, @@ -197,7 +202,8 @@ struct IndexIVFFastScan : IndexIVF { const CoarseQuantized& cq, size_t* ndis_out, size_t* nlist_out, - const NormTableScaler* scaler) const; + const NormTableScaler* scaler, + const IVFSearchParameters* params = nullptr) const; // implem 14 is multithreaded internally across nprobes and queries void search_implem_14( @@ -208,7 +214,8 @@ struct IndexIVFFastScan : IndexIVF { idx_t* labels, const CoarseQuantized& cq, int impl, - const NormTableScaler* scaler) const; + const NormTableScaler* scaler, + const IVFSearchParameters* params = nullptr) const; // reconstruct vectors from packed invlists void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons) diff --git a/faiss/impl/simd_result_handlers.h b/faiss/impl/simd_result_handlers.h index 633d480990..d5c62b0304 100644 --- a/faiss/impl/simd_result_handlers.h +++ b/faiss/impl/simd_result_handlers.h @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -57,8 +58,7 @@ struct SIMDResultHandlerToFloat : SIMDResultHandler { nullptr; // table of biases to add to each query (for IVF L2 search) const float* normalizers = nullptr; // size 2 * nq, to convert - SIMDResultHandlerToFloat(size_t nq, size_t ntotal) - : nq(nq), ntotal(ntotal) {} + SIMDResultHandlerToFloat(size_t nq, size_t ntotal) : nq(nq), ntotal(ntotal) {} virtual void begin(const float* norms) { normalizers = norms; @@ -137,6 +137,7 @@ struct FixedStorageHandler : SIMDResultHandler { } } } + virtual ~FixedStorageHandler() {} }; @@ -150,8 +151,10 @@ struct ResultHandlerCompare : SIMDResultHandlerToFloat { int64_t i0 = 0; // query origin int64_t j0 = 0; // db origin - ResultHandlerCompare(size_t nq, size_t ntotal) - : SIMDResultHandlerToFloat(nq, ntotal) { + const IDSelector* sel; + + ResultHandlerCompare(size_t nq, size_t ntotal, const IDSelector* sel_in) + : SIMDResultHandlerToFloat(nq, ntotal), sel{sel_in} { this->is_CMax = C::is_max; this->sizeof_ids = sizeof(typename C::TI); this->with_fields = with_id_map; @@ -217,6 +220,20 @@ struct ResultHandlerCompare : SIMDResultHandlerToFloat { return lt_mask; } + uint32_t get_lt_mask_for_range_search(size_t b) { + uint32_t lt_mask = 0xffffffff; + + uint64_t idx = j0 + b * 32; + if (idx + 32 > ntotal) { + if (idx >= ntotal) { + return 0; + } + int nbit = (ntotal - idx); + lt_mask &= (uint32_t(1) << nbit) - 1; + } + return lt_mask; + } + virtual ~ResultHandlerCompare() {} }; @@ -232,9 +249,9 @@ struct SingleResultHandler : ResultHandlerCompare { float* dis; int64_t* ids; - SingleResultHandler(size_t nq, size_t ntotal, float* dis, int64_t* ids) - : RHC(nq, ntotal), idis(nq), dis(dis), ids(ids) { - for (int i = 0; i < nq; i++) { + SingleResultHandler(size_t nq, size_t ntotal, float* dis, int64_t* ids, const IDSelector* sel_in) + : RHC(nq, ntotal, sel_in), idis(nq), dis(dis), ids(ids) { + for (size_t i = 0; i < nq; i++) { ids[i] = -1; idis[i] = C::neutral(); } @@ -256,20 +273,37 @@ struct SingleResultHandler : ResultHandlerCompare { d0.store(d32tab); d1.store(d32tab + 16); - while (lt_mask) { - // find first non-zero - int j = __builtin_ctz(lt_mask); - lt_mask -= 1 << j; - T d = d32tab[j]; - if (C::cmp(idis[q], d)) { - idis[q] = d; - ids[q] = this->adjust_id(b, j); + if (this->sel != nullptr) { + while (lt_mask) { + // find first non-zero + int j = __builtin_ctz(lt_mask); + auto real_idx = this->adjust_id(b, j); + lt_mask -= 1 << j; + if (this->sel->is_member(real_idx)) { + T d = d32tab[j]; + if (C::cmp(idis[q], d)) { + idis[q] = d; + ids[q] = real_idx; + } + } + } + } + else { + while (lt_mask) { + // find first non-zero + int j = __builtin_ctz(lt_mask); + lt_mask -= 1 << j; + T d = d32tab[j]; + if (C::cmp(idis[q], d)) { + idis[q] = d; + ids[q] = this->adjust_id(b, j); + } } } } void end() { - for (int q = 0; q < this->nq; q++) { + for (size_t q = 0; q < this->nq; q++) { if (!normalizers) { dis[q] = idis[q]; } else { @@ -296,8 +330,8 @@ struct HeapHandler : ResultHandlerCompare { int64_t k; // number of results to keep - HeapHandler(size_t nq, size_t ntotal, int64_t k, float* dis, int64_t* ids) - : RHC(nq, ntotal), + HeapHandler(size_t nq, size_t ntotal, int64_t k, float* dis, int64_t* ids, const IDSelector* sel_in) + : RHC(nq, ntotal, sel_in), idis(nq * k), iids(nq * k), dis(dis), @@ -330,21 +364,38 @@ struct HeapHandler : ResultHandlerCompare { d0.store(d32tab); d1.store(d32tab + 16); - while (lt_mask) { - // find first non-zero - int j = __builtin_ctz(lt_mask); - lt_mask -= 1 << j; - T dis = d32tab[j]; - if (C::cmp(heap_dis[0], dis)) { - int64_t idx = this->adjust_id(b, j); - heap_pop(k, heap_dis, heap_ids); - heap_push(k, heap_dis, heap_ids, dis, idx); + if (this->sel != nullptr) { + while (lt_mask) { + // find first non-zero + int j = __builtin_ctz(lt_mask); + auto real_idx = this->adjust_id(b, j); + lt_mask -= 1 << j; + if (this->sel->is_member(real_idx)) { + T dis = d32tab[j]; + if (C::cmp(heap_dis[0], dis)) { + heap_pop(k, heap_dis, heap_ids); + heap_push(k, heap_dis, heap_ids, dis, real_idx); + } + } + } + } + else { + while (lt_mask) { + // find first non-zero + int j = __builtin_ctz(lt_mask); + lt_mask -= 1 << j; + T dis = d32tab[j]; + if (C::cmp(heap_dis[0], dis)) { + int64_t idx = this->adjust_id(b, j); + heap_pop(k, heap_dis, heap_ids); + heap_push(k, heap_dis, heap_ids, dis, idx); + } } } } void end() override { - for (int q = 0; q < this->nq; q++) { + for (size_t q = 0; q < this->nq; q++) { T* heap_dis_in = idis.data() + q * k; TI* heap_ids_in = iids.data() + q * k; heap_reorder(k, heap_dis_in, heap_ids_in); @@ -393,8 +444,9 @@ struct ReservoirHandler : ResultHandlerCompare { size_t k, size_t cap, float* dis, - int64_t* ids) - : RHC(nq, ntotal), capacity((cap + 15) & ~15), dis(dis), ids(ids) { + int64_t* ids, + const IDSelector* sel_in) + : RHC(nq, ntotal, sel_in), capacity((cap + 15) & ~15), dis(dis), ids(ids) { assert(capacity % 16 == 0); all_ids.resize(nq * capacity); all_vals.resize(nq * capacity); @@ -423,12 +475,26 @@ struct ReservoirHandler : ResultHandlerCompare { d0.store(d32tab); d1.store(d32tab + 16); - while (lt_mask) { - // find first non-zero - int j = __builtin_ctz(lt_mask); - lt_mask -= 1 << j; - T dis = d32tab[j]; - res.add(dis, this->adjust_id(b, j)); + if (this->sel != nullptr) { + while (lt_mask) { + // find first non-zero + int j = __builtin_ctz(lt_mask); + auto real_idx = this->adjust_id(b, j); + lt_mask -= 1 << j; + if (this->sel->is_member(real_idx)) { + T dis = d32tab[j]; + res.add(dis, real_idx); + } + } + } + else { + while (lt_mask) { + // find first non-zero + int j = __builtin_ctz(lt_mask); + lt_mask -= 1 << j; + T dis = d32tab[j]; + res.add(dis, this->adjust_id(b, j)); + } } } @@ -439,7 +505,7 @@ struct ReservoirHandler : ResultHandlerCompare { CMin>::type; std::vector perm(reservoirs[0].n); - for (int q = 0; q < reservoirs.size(); q++) { + for (size_t q = 0; q < reservoirs.size(); q++) { ReservoirTopN& res = reservoirs[q]; size_t n = res.n; @@ -454,14 +520,14 @@ struct ReservoirHandler : ResultHandlerCompare { one_a = 1 / normalizers[2 * q]; b = normalizers[2 * q + 1]; } - for (int i = 0; i < res.i; i++) { + for (size_t i = 0; i < res.i; i++) { perm[i] = i; } // indirect sort of result arrays std::sort(perm.begin(), perm.begin() + res.i, [&res](int i, int j) { return C::cmp(res.vals[j], res.vals[i]); }); - for (int i = 0; i < res.i; i++) { + for (size_t i = 0; i < res.i; i++) { heap_dis[i] = res.vals[perm[i]] * one_a + b; heap_ids[i] = res.ids[perm[i]]; } @@ -499,8 +565,8 @@ struct RangeHandler : ResultHandlerCompare { }; std::vector triplets; - RangeHandler(RangeSearchResult& rres, float radius, size_t ntotal) - : RHC(rres.nq, ntotal), rres(rres), radius(radius) { + RangeHandler(RangeSearchResult& rres, float radius, size_t ntotal, const IDSelector* sel_in) + : RHC(rres.nq, ntotal, sel_in), rres(rres), radius(radius) { thresholds.resize(nq); n_per_query.resize(nq + 1); } @@ -528,13 +594,28 @@ struct RangeHandler : ResultHandlerCompare { d0.store(d32tab); d1.store(d32tab + 16); - while (lt_mask) { - // find first non-zero - int j = __builtin_ctz(lt_mask); - lt_mask -= 1 << j; - T dis = d32tab[j]; - n_per_query[q]++; - triplets.push_back({idx_t(q + q0), this->adjust_id(b, j), dis}); + if (this->sel != nullptr) { + while (lt_mask) { + // find first non-zero + int j = __builtin_ctz(lt_mask); + lt_mask -= 1 << j; + + auto real_idx = this->adjust_id(b, j); + if (this->sel->is_member(real_idx)) { + T dis = d32tab[j]; + n_per_query[q]++; + triplets.push_back({idx_t(q + q0), real_idx, dis}); + } + } + } else { + while (lt_mask) { + // find first non-zero + int j = __builtin_ctz(lt_mask); + lt_mask -= 1 << j; + T dis = d32tab[j]; + n_per_query[q]++; + triplets.push_back({idx_t(q + q0), this->adjust_id(b, j), dis}); + } } } @@ -578,8 +659,9 @@ struct PartialRangeHandler : RangeHandler { float radius, size_t ntotal, size_t q0, - size_t q1) - : RangeHandler(*pres.res, radius, ntotal), + size_t q1, + const IDSelector* sel_in) + : RangeHandler(*pres.res, radius, ntotal, sel_in), pres(pres) { nq = q1 - q0; this->q0 = q0; @@ -698,6 +780,7 @@ void dispatch_SIMDResultHanlder( FAISS_THROW_FMT("Unknown id size %d", res.sizeof_ids); } } + } // namespace simd_result_handlers } // namespace faiss From 1623c3f2860b84b550b64961bccbcc5ea6572635 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Thu, 21 Mar 2024 13:12:55 -0400 Subject: [PATCH 2/6] update formatting Signed-off-by: Alexandr Guzhva --- faiss/IVFlib.cpp | 6 +++-- faiss/IndexIVFFastScan.cpp | 29 ++++++++++++++-------- faiss/impl/simd_result_handlers.h | 40 +++++++++++++++++++++---------- 3 files changed, 51 insertions(+), 24 deletions(-) diff --git a/faiss/IVFlib.cpp b/faiss/IVFlib.cpp index d7dfbaf582..f2c975f4de 100644 --- a/faiss/IVFlib.cpp +++ b/faiss/IVFlib.cpp @@ -352,8 +352,10 @@ void search_with_parameters( const IndexIVF* index_ivf = dynamic_cast(index); FAISS_THROW_IF_NOT(index_ivf); - SearchParameters* quantizer_params = (params) ? params->quantizer_params : nullptr; - index_ivf->quantizer->search(n, x, params->nprobe, Dq.data(), Iq.data(), quantizer_params); + SearchParameters* quantizer_params = + (params) ? params->quantizer_params : nullptr; + index_ivf->quantizer->search( + n, x, params->nprobe, Dq.data(), Iq.data(), quantizer_params); if (nb_dis_ptr) { *nb_dis_ptr = count_ndis(index_ivf, n * params->nprobe, Iq.data()); diff --git a/faiss/IndexIVFFastScan.cpp b/faiss/IndexIVFFastScan.cpp index e454e9140d..f913ab5f58 100644 --- a/faiss/IndexIVFFastScan.cpp +++ b/faiss/IndexIVFFastScan.cpp @@ -417,14 +417,19 @@ struct CoarseQuantizedWithBuffer : CoarseQuantized { std::vector dis_buffer; void quantize( - const Index* quantizer, - idx_t n, - const float* x, - const SearchParameters* quantizer_params - ) { + const Index* quantizer, + idx_t n, + const float* x, + const SearchParameters* quantizer_params) { dis_buffer.resize(nprobe * n); ids_buffer.resize(nprobe * n); - quantizer->search(n, x, nprobe, dis_buffer.data(), ids_buffer.data(), quantizer_params); + quantizer->search( + n, + x, + nprobe, + dis_buffer.data(), + ids_buffer.data(), + quantizer_params); dis = dis_buffer.data(); ids = ids_buffer.data(); } @@ -440,7 +445,10 @@ struct CoarseQuantizedSlice : CoarseQuantizedWithBuffer { } } - void quantize_slice(const Index* quantizer, const float* x, const SearchParameters* quantizer_params) { + void quantize_slice( + const Index* quantizer, + const float* x, + const SearchParameters* quantizer_params) { quantize(quantizer, i1 - i0, x + quantizer->d * i0, quantizer_params); } }; @@ -483,7 +491,7 @@ void IndexIVFFastScan::search_dispatch_implem( const idx_t nprobe = params ? params->nprobe : this->nprobe; const IDSelector* sel = (params) ? params->sel : nullptr; const SearchParameters* quantizer_params = - params ? params->quantizer_params : nullptr; + params ? params->quantizer_params : nullptr; bool is_max = !is_similarity_metric(metric_type); using RH = SIMDResultHandlerToFloat; @@ -585,7 +593,8 @@ void IndexIVFFastScan::search_dispatch_implem( if (impl == 14 || impl == 15) { // this might require slicing if there are too // many queries (for now we keep this simple) - search_implem_14(n, x, k, distances, labels, cq, impl, scaler, params); + search_implem_14( + n, x, k, distances, labels, cq, impl, scaler, params); } else { #pragma omp parallel for reduction(+ : ndis, nlist_visited) for (int slice = 0; slice < nslice; slice++) { @@ -632,7 +641,7 @@ void IndexIVFFastScan::range_search_dispatch_implem( // const idx_t nprobe = params ? params->nprobe : this->nprobe; const IDSelector* sel = (params) ? params->sel : nullptr; const SearchParameters* quantizer_params = - params ? params->quantizer_params : nullptr; + params ? params->quantizer_params : nullptr; bool is_max = !is_similarity_metric(metric_type); diff --git a/faiss/impl/simd_result_handlers.h b/faiss/impl/simd_result_handlers.h index d5c62b0304..3283be4b72 100644 --- a/faiss/impl/simd_result_handlers.h +++ b/faiss/impl/simd_result_handlers.h @@ -58,7 +58,8 @@ struct SIMDResultHandlerToFloat : SIMDResultHandler { nullptr; // table of biases to add to each query (for IVF L2 search) const float* normalizers = nullptr; // size 2 * nq, to convert - SIMDResultHandlerToFloat(size_t nq, size_t ntotal) : nq(nq), ntotal(ntotal) {} + SIMDResultHandlerToFloat(size_t nq, size_t ntotal) + : nq(nq), ntotal(ntotal) {} virtual void begin(const float* norms) { normalizers = norms; @@ -249,7 +250,12 @@ struct SingleResultHandler : ResultHandlerCompare { float* dis; int64_t* ids; - SingleResultHandler(size_t nq, size_t ntotal, float* dis, int64_t* ids, const IDSelector* sel_in) + SingleResultHandler( + size_t nq, + size_t ntotal, + float* dis, + int64_t* ids, + const IDSelector* sel_in) : RHC(nq, ntotal, sel_in), idis(nq), dis(dis), ids(ids) { for (size_t i = 0; i < nq; i++) { ids[i] = -1; @@ -287,8 +293,7 @@ struct SingleResultHandler : ResultHandlerCompare { } } } - } - else { + } else { while (lt_mask) { // find first non-zero int j = __builtin_ctz(lt_mask); @@ -330,7 +335,13 @@ struct HeapHandler : ResultHandlerCompare { int64_t k; // number of results to keep - HeapHandler(size_t nq, size_t ntotal, int64_t k, float* dis, int64_t* ids, const IDSelector* sel_in) + HeapHandler( + size_t nq, + size_t ntotal, + int64_t k, + float* dis, + int64_t* ids, + const IDSelector* sel_in) : RHC(nq, ntotal, sel_in), idis(nq * k), iids(nq * k), @@ -378,8 +389,7 @@ struct HeapHandler : ResultHandlerCompare { } } } - } - else { + } else { while (lt_mask) { // find first non-zero int j = __builtin_ctz(lt_mask); @@ -444,9 +454,12 @@ struct ReservoirHandler : ResultHandlerCompare { size_t k, size_t cap, float* dis, - int64_t* ids, + int64_t* ids, const IDSelector* sel_in) - : RHC(nq, ntotal, sel_in), capacity((cap + 15) & ~15), dis(dis), ids(ids) { + : RHC(nq, ntotal, sel_in), + capacity((cap + 15) & ~15), + dis(dis), + ids(ids) { assert(capacity % 16 == 0); all_ids.resize(nq * capacity); all_vals.resize(nq * capacity); @@ -486,8 +499,7 @@ struct ReservoirHandler : ResultHandlerCompare { res.add(dis, real_idx); } } - } - else { + } else { while (lt_mask) { // find first non-zero int j = __builtin_ctz(lt_mask); @@ -565,7 +577,11 @@ struct RangeHandler : ResultHandlerCompare { }; std::vector triplets; - RangeHandler(RangeSearchResult& rres, float radius, size_t ntotal, const IDSelector* sel_in) + RangeHandler( + RangeSearchResult& rres, + float radius, + size_t ntotal, + const IDSelector* sel_in) : RHC(rres.nq, ntotal, sel_in), rres(rres), radius(radius) { thresholds.resize(nq); n_per_query.resize(nq + 1); From d7b687706c9716575a28fd89ab66bb8a01f93c34 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Thu, 21 Mar 2024 15:07:46 -0400 Subject: [PATCH 3/6] Fix a problem with MSVC Signed-off-by: Alexandr Guzhva --- faiss/IndexIVFFastScan.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/faiss/IndexIVFFastScan.cpp b/faiss/IndexIVFFastScan.cpp index f913ab5f58..a37800164a 100644 --- a/faiss/IndexIVFFastScan.cpp +++ b/faiss/IndexIVFFastScan.cpp @@ -270,8 +270,9 @@ void IndexIVFFastScan::compute_LUT_uint8( biases.resize(n * nprobe); } + // OMP for MSVC requires i to have signed integral type #pragma omp parallel for if (n > 100) - for (size_t i = 0; i < n; i++) { + for (int64_t i = 0; i < n; i++) { const float* t_in = dis_tables_float.get() + i * dim123; const float* b_in = nullptr; uint8_t* t_out = dis_tables.get() + i * dim123_2; From 82f2d2a00302eb085fd80e1e50bb96d1f7798c6e Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Thu, 28 Mar 2024 10:38:32 -0400 Subject: [PATCH 4/6] code cleanup Signed-off-by: Alexandr Guzhva --- faiss/impl/simd_result_handlers.h | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/faiss/impl/simd_result_handlers.h b/faiss/impl/simd_result_handlers.h index 3283be4b72..03a38b3286 100644 --- a/faiss/impl/simd_result_handlers.h +++ b/faiss/impl/simd_result_handlers.h @@ -221,20 +221,6 @@ struct ResultHandlerCompare : SIMDResultHandlerToFloat { return lt_mask; } - uint32_t get_lt_mask_for_range_search(size_t b) { - uint32_t lt_mask = 0xffffffff; - - uint64_t idx = j0 + b * 32; - if (idx + 32 > ntotal) { - if (idx >= ntotal) { - return 0; - } - int nbit = (ntotal - idx); - lt_mask &= (uint32_t(1) << nbit) - 1; - } - return lt_mask; - } - virtual ~ResultHandlerCompare() {} }; @@ -384,8 +370,7 @@ struct HeapHandler : ResultHandlerCompare { if (this->sel->is_member(real_idx)) { T dis = d32tab[j]; if (C::cmp(heap_dis[0], dis)) { - heap_pop(k, heap_dis, heap_ids); - heap_push(k, heap_dis, heap_ids, dis, real_idx); + heap_replace_top(k, heap_dis, heap_ids, dis, real_idx); } } } @@ -397,8 +382,7 @@ struct HeapHandler : ResultHandlerCompare { T dis = d32tab[j]; if (C::cmp(heap_dis[0], dis)) { int64_t idx = this->adjust_id(b, j); - heap_pop(k, heap_dis, heap_ids); - heap_push(k, heap_dis, heap_ids, dis, idx); + heap_replace_top(k, heap_dis, heap_ids, dis, idx); } } } From af9c83d1660a9511f65f7818c7dcb0d6e752c557 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Thu, 28 Mar 2024 10:46:08 -0400 Subject: [PATCH 5/6] fix formatting Signed-off-by: Alexandr Guzhva --- faiss/impl/simd_result_handlers.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/faiss/impl/simd_result_handlers.h b/faiss/impl/simd_result_handlers.h index 03a38b3286..2fa18fa340 100644 --- a/faiss/impl/simd_result_handlers.h +++ b/faiss/impl/simd_result_handlers.h @@ -370,7 +370,8 @@ struct HeapHandler : ResultHandlerCompare { if (this->sel->is_member(real_idx)) { T dis = d32tab[j]; if (C::cmp(heap_dis[0], dis)) { - heap_replace_top(k, heap_dis, heap_ids, dis, real_idx); + heap_replace_top( + k, heap_dis, heap_ids, dis, real_idx); } } } From 6f6d26eba6d30f4a11b8f59dfbfca256ea147df2 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 2 Apr 2024 10:12:32 -0400 Subject: [PATCH 6/6] add a missing prefetch Signed-off-by: Alexandr Guzhva --- faiss/IndexIVFFastScan.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/faiss/IndexIVFFastScan.cpp b/faiss/IndexIVFFastScan.cpp index a37800164a..19828753d2 100644 --- a/faiss/IndexIVFFastScan.cpp +++ b/faiss/IndexIVFFastScan.cpp @@ -530,6 +530,7 @@ void IndexIVFFastScan::search_dispatch_implem( // sliced over threads (then it is more efficient to have each thread do // its own coarse quantization) cq.quantize(quantizer, n, x, quantizer_params); + invlists->prefetch_lists(cq.ids, n * cq.nprobe); } if (impl == 1) { @@ -672,6 +673,7 @@ void IndexIVFFastScan::range_search_dispatch_implem( if (!multiple_threads && !cq.done()) { cq.quantize(quantizer, n, x, quantizer_params); + invlists->prefetch_lists(cq.ids, n * cq.nprobe); } size_t ndis = 0, nlist_visited = 0;