diff --git a/contrib/inspect_tools.py b/contrib/inspect_tools.py index cc22ff5368..0aef5ac96d 100644 --- a/contrib/inspect_tools.py +++ b/contrib/inspect_tools.py @@ -98,6 +98,12 @@ def get_flat_data(index): return xb.reshape(index.ntotal, index.d) +def get_flat_codes(index_flat): + """ get the codes from an indexFlatCodes as an array """ + return faiss.vector_to_array(index_flat.codes).reshape( + index_flat.ntotal, index_flat.code_size) + + def get_NSG_neighbors(nsg): """ get the neighbor list for the vectors stored in the NSG structure, as a N-by-K matrix of indices """ diff --git a/faiss/IndexFlatCodes.cpp b/faiss/IndexFlatCodes.cpp index caff90ff9c..919a51d2f8 100644 --- a/faiss/IndexFlatCodes.cpp +++ b/faiss/IndexFlatCodes.cpp @@ -12,6 +12,8 @@ #include #include #include +#include +#include namespace faiss { @@ -70,11 +72,6 @@ void IndexFlatCodes::reconstruct(idx_t key, float* recons) const { reconstruct_n(key, 1, recons); } -FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer() - const { - FAISS_THROW_MSG("not implemented"); -} - void IndexFlatCodes::check_compatible_for_merge(const Index& otherIndex) const { // minimal sanity checks const IndexFlatCodes* other = @@ -114,4 +111,161 @@ void IndexFlatCodes::permute_entries(const idx_t* perm) { std::swap(codes, new_codes); } +namespace { + +template +struct GenericFlatCodesDistanceComputer : FlatCodesDistanceComputer { + const IndexFlatCodes& codec; + const VD vd; + // temp buffers + std::vector code_buffer; + std::vector vec_buffer; + const float* query = nullptr; + + GenericFlatCodesDistanceComputer(const IndexFlatCodes* codec, const VD& vd) + : FlatCodesDistanceComputer(codec->codes.data(), codec->code_size), + codec(*codec), + vd(vd), + code_buffer(codec->code_size * 4), + vec_buffer(codec->d * 4) {} + + void set_query(const float* x) override { + query = x; + } + + float operator()(idx_t i) override { + codec.sa_decode(1, codes + i * code_size, vec_buffer.data()); + return vd(query, 
vec_buffer.data()); + } + + float distance_to_code(const uint8_t* code) override { + codec.sa_decode(1, code, vec_buffer.data()); + return vd(query, vec_buffer.data()); + } + + float symmetric_dis(idx_t i, idx_t j) override { + codec.sa_decode(1, codes + i * code_size, vec_buffer.data()); + codec.sa_decode(1, codes + j * code_size, vec_buffer.data() + vd.d); + return vd(vec_buffer.data(), vec_buffer.data() + vd.d); + } + + void distances_batch_4( + const idx_t idx0, + const idx_t idx1, + const idx_t idx2, + const idx_t idx3, + float& dis0, + float& dis1, + float& dis2, + float& dis3) override { + uint8_t* cp = code_buffer.data(); + for (idx_t i : {idx0, idx1, idx2, idx3}) { + memcpy(cp, codes + i * code_size, code_size); + cp += code_size; + } + // potential benefit is if batch decoding is more efficient than 1 by 1 + // decoding + codec.sa_decode(4, code_buffer.data(), vec_buffer.data()); + dis0 = vd(query, vec_buffer.data()); + dis1 = vd(query, vec_buffer.data() + vd.d); + dis2 = vd(query, vec_buffer.data() + 2 * vd.d); + dis3 = vd(query, vec_buffer.data() + 3 * vd.d); + } +}; + +struct Run_get_distance_computer { + using T = FlatCodesDistanceComputer*; + + template + FlatCodesDistanceComputer* f(const VD& vd, const IndexFlatCodes* codec) { + return new GenericFlatCodesDistanceComputer(codec, vd); + } +}; + +template +struct Run_search_with_decompress { + using T = void; + + template + void f(VectorDistance& vd, + const IndexFlatCodes* index_ptr, + const float* xq, + BlockResultHandler& res) { + // Note that there seems to be a clang (?) bug that "sometimes" passes + // the const Index & parameters by value, so to be on the safe side, + // it's better to use pointers. 
+ const IndexFlatCodes& index = *index_ptr; + size_t ntotal = index.ntotal; + using SingleResultHandler = + typename BlockResultHandler::SingleResultHandler; + using DC = GenericFlatCodesDistanceComputer; +#pragma omp parallel // if (res.nq > 100) + { + std::unique_ptr dc(new DC(&index, vd)); + SingleResultHandler resi(res); +#pragma omp for + for (int64_t q = 0; q < res.nq; q++) { + resi.begin(q); + dc->set_query(xq + vd.d * q); + for (size_t i = 0; i < ntotal; i++) { + if (res.is_in_selection(i)) { + float dis = (*dc)(i); + resi.add_result(dis, i); + } + } + resi.end(); + } + } + } +}; + +struct Run_search_with_decompress_res { + using T = void; + + template + void f(ResultHandler& res, const IndexFlatCodes* index, const float* xq) { + Run_search_with_decompress r; + dispatch_VectorDistance( + index->d, + index->metric_type, + index->metric_arg, + r, + index, + xq, + res); + } +}; + +} // anonymous namespace + +FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer() + const { + Run_get_distance_computer r; + return dispatch_VectorDistance(d, metric_type, metric_arg, r, this); +} + +void IndexFlatCodes::search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + const SearchParameters* params) const { + Run_search_with_decompress_res r; + const IDSelector* sel = params ? params->sel : nullptr; + dispatch_knn_ResultHandler( + n, distances, labels, k, metric_type, sel, r, this, x); +} + +void IndexFlatCodes::range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + const SearchParameters* params) const { + const IDSelector* sel = params ? 
params->sel : nullptr; + Run_search_with_decompress_res r; + dispatch_range_ResultHandler(result, radius, metric_type, sel, r, this, x); +} + } // namespace faiss diff --git a/faiss/IndexFlatCodes.h b/faiss/IndexFlatCodes.h index cd43e47ecc..787e2bb2d3 100644 --- a/faiss/IndexFlatCodes.h +++ b/faiss/IndexFlatCodes.h @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. */ -// -*- c++ -*- - #pragma once #include @@ -45,13 +43,32 @@ struct IndexFlatCodes : Index { * different from the usual ones: the new ids are shifted */ size_t remove_ids(const IDSelector& sel) override; - /** a FlatCodesDistanceComputer offers a distance_to_code method */ + /** a FlatCodesDistanceComputer offers a distance_to_code method + * + * The default implementation explicitly decodes the vector with sa_decode. + */ virtual FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const; DistanceComputer* get_distance_computer() const override { return get_FlatCodesDistanceComputer(); } + /** Search implemented by decoding */ + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + const SearchParameters* params = nullptr) const override; + + void range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + const SearchParameters* params = nullptr) const override; + // returns a new instance of a CodePacker CodePacker* get_CodePacker() const; diff --git a/faiss/IndexLattice.cpp b/faiss/IndexLattice.cpp index 4e0448e299..4da9c9ff61 100644 --- a/faiss/IndexLattice.cpp +++ b/faiss/IndexLattice.cpp @@ -15,7 +15,7 @@ namespace faiss { IndexLattice::IndexLattice(idx_t d, int nsq, int scale_nbit, int r2) - : Index(d), + : IndexFlatCodes(0, d, METRIC_L2), nsq(nsq), dsq(d / nsq), zn_sphere_codec(dsq, r2), @@ -114,22 +114,4 @@ void IndexLattice::sa_decode(idx_t n, const uint8_t* codes, float* x) const { } } -void IndexLattice::add(idx_t, const float*) { - FAISS_THROW_MSG("not implemented"); -} - -void 
IndexLattice::search( - idx_t, - const float*, - idx_t, - float*, - idx_t*, - const SearchParameters*) const { - FAISS_THROW_MSG("not implemented"); -} - -void IndexLattice::reset() { - FAISS_THROW_MSG("not implemented"); -} - } // namespace faiss diff --git a/faiss/IndexLattice.h b/faiss/IndexLattice.h index a9eb62b6d9..f814c4aff1 100644 --- a/faiss/IndexLattice.h +++ b/faiss/IndexLattice.h @@ -5,21 +5,18 @@ * LICENSE file in the root directory of this source tree. */ -// -*- c++ -*- - -#ifndef FAISS_INDEX_LATTICE_H -#define FAISS_INDEX_LATTICE_H +#pragma once #include -#include +#include #include namespace faiss { /** Index that encodes a vector with a series of Zn lattice quantizers */ -struct IndexLattice : Index { +struct IndexLattice : IndexFlatCodes { /// number of sub-vectors int nsq; /// dimension of sub-vectors @@ -30,8 +27,6 @@ struct IndexLattice : Index { /// nb bits used to encode the scale, per subvector int scale_nbit, lattice_nbit; - /// total, in bytes - size_t code_size; /// mins and maxes of the vector norms, per subquantizer std::vector trained; @@ -46,20 +41,6 @@ struct IndexLattice : Index { void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override; void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override; - - /// not implemented - void add(idx_t n, const float* x) override; - void search( - idx_t n, - const float* x, - idx_t k, - float* distances, - idx_t* labels, - const SearchParameters* params = nullptr) const override; - - void reset() override; }; } // namespace faiss - -#endif diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h index 713fe8e49f..511af10f79 100644 --- a/faiss/impl/ResultHandler.h +++ b/faiss/impl/ResultHandler.h @@ -13,8 +13,10 @@ #include #include +#include #include #include +#include #include namespace faiss { @@ -26,16 +28,21 @@ namespace faiss { * - by instanciating a SingleResultHandler that tracks results for a single * query * - with 
begin_multiple/add_results/end_multiple calls where a whole block of - * resutls is submitted + * results is submitted * All classes are templated on C which to define wheter the min or the max of - * results is to be kept. + * results is to be kept, and on sel, so that the codepaths for with / without + * selector can be separated at compile time. *****************************************************************/ -template +template struct BlockResultHandler { size_t nq; // number of queries for which we search + const IDSelector* sel; - explicit BlockResultHandler(size_t nq) : nq(nq) {} + explicit BlockResultHandler(size_t nq, const IDSelector* sel = nullptr) + : nq(nq), sel(sel) { + assert(!use_sel || sel); + } // currently handled query range size_t i0 = 0, i1 = 0; @@ -53,13 +60,17 @@ struct BlockResultHandler { virtual void end_multiple() {} virtual ~BlockResultHandler() {} + + bool is_in_selection(idx_t i) const { + return !use_sel || sel->is_member(i); + } }; // handler for a single query template struct ResultHandler { // if not better than threshold, then not necessary to call add_result - typename C::T threshold = 0; + typename C::T threshold = C::neutral(); // return whether threshold was updated virtual bool add_result(typename C::T dis, typename C::TI idx) = 0; @@ -73,20 +84,26 @@ struct ResultHandler { * some temporary data in memory. 
*****************************************************************/ -template -struct Top1BlockResultHandler : BlockResultHandler { +template +struct Top1BlockResultHandler : BlockResultHandler { using T = typename C::T; using TI = typename C::TI; - using BlockResultHandler::i0; - using BlockResultHandler::i1; + using BlockResultHandler::i0; + using BlockResultHandler::i1; // contains exactly nq elements T* dis_tab; // contains exactly nq elements TI* ids_tab; - Top1BlockResultHandler(size_t nq, T* dis_tab, TI* ids_tab) - : BlockResultHandler(nq), dis_tab(dis_tab), ids_tab(ids_tab) {} + Top1BlockResultHandler( + size_t nq, + T* dis_tab, + TI* ids_tab, + const IDSelector* sel = nullptr) + : BlockResultHandler(nq, sel), + dis_tab(dis_tab), + ids_tab(ids_tab) {} struct SingleResultHandler : ResultHandler { Top1BlockResultHandler& hr; @@ -165,12 +182,12 @@ struct Top1BlockResultHandler : BlockResultHandler { * Heap based result handler *****************************************************************/ -template -struct HeapBlockResultHandler : BlockResultHandler { +template +struct HeapBlockResultHandler : BlockResultHandler { using T = typename C::T; using TI = typename C::TI; - using BlockResultHandler::i0; - using BlockResultHandler::i1; + using BlockResultHandler::i0; + using BlockResultHandler::i1; T* heap_dis_tab; TI* heap_ids_tab; @@ -181,8 +198,9 @@ struct HeapBlockResultHandler : BlockResultHandler { size_t nq, T* heap_dis_tab, TI* heap_ids_tab, - size_t k) - : BlockResultHandler(nq), + size_t k, + const IDSelector* sel = nullptr) + : BlockResultHandler(nq, sel), heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k) {} @@ -347,12 +365,12 @@ struct ReservoirTopN : ResultHandler { } }; -template -struct ReservoirBlockResultHandler : BlockResultHandler { +template +struct ReservoirBlockResultHandler : BlockResultHandler { using T = typename C::T; using TI = typename C::TI; - using BlockResultHandler::i0; - using BlockResultHandler::i1; + using 
BlockResultHandler::i0; + using BlockResultHandler::i1; T* heap_dis_tab; TI* heap_ids_tab; @@ -364,8 +382,9 @@ struct ReservoirBlockResultHandler : BlockResultHandler { size_t nq, T* heap_dis_tab, TI* heap_ids_tab, - size_t k) - : BlockResultHandler(nq), + size_t k, + const IDSelector* sel = nullptr) + : BlockResultHandler(nq, sel), heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k) { @@ -460,18 +479,23 @@ struct ReservoirBlockResultHandler : BlockResultHandler { * Result handler for range searches *****************************************************************/ -template -struct RangeSearchBlockResultHandler : BlockResultHandler { +template +struct RangeSearchBlockResultHandler : BlockResultHandler { using T = typename C::T; using TI = typename C::TI; - using BlockResultHandler::i0; - using BlockResultHandler::i1; + using BlockResultHandler::i0; + using BlockResultHandler::i1; RangeSearchResult* res; T radius; - RangeSearchBlockResultHandler(RangeSearchResult* res, float radius) - : BlockResultHandler(res->nq), res(res), radius(radius) {} + RangeSearchBlockResultHandler( + RangeSearchResult* res, + float radius, + const IDSelector* sel = nullptr) + : BlockResultHandler(res->nq, sel), + res(res), + radius(radius) {} /****************************************************** * API for 1 result at a time (each SingleResultHandler is @@ -582,4 +606,81 @@ struct RangeSearchBlockResultHandler : BlockResultHandler { } }; +/***************************************************************** + * Dispatcher function to choose the right knn result handler depending on k + *****************************************************************/ + +// declared in distances.cpp +FAISS_API extern int distance_compute_min_k_reservoir; + +template +typename Consumer::T dispatch_knn_ResultHandler( + size_t nx, + float* vals, + int64_t* ids, + size_t k, + MetricType metric, + const IDSelector* sel, + Consumer& consumer, + Types... 
args) { +#define DISPATCH_C_SEL(C, use_sel) \ + if (k == 1) { \ + Top1BlockResultHandler res(nx, vals, ids, sel); \ + return consumer.template f<>(res, args...); \ + } else if (k < distance_compute_min_k_reservoir) { \ + HeapBlockResultHandler res(nx, vals, ids, k, sel); \ + return consumer.template f<>(res, args...); \ + } else { \ + ReservoirBlockResultHandler res(nx, vals, ids, k, sel); \ + return consumer.template f<>(res, args...); \ + } + + if (is_similarity_metric(metric)) { + using C = CMin; + if (sel) { + DISPATCH_C_SEL(C, true); + } else { + DISPATCH_C_SEL(C, false); + } + } else { + using C = CMax; + if (sel) { + DISPATCH_C_SEL(C, true); + } else { + DISPATCH_C_SEL(C, false); + } + } +#undef DISPATCH_C_SEL +} + +template +typename Consumer::T dispatch_range_ResultHandler( + RangeSearchResult* res, + float radius, + MetricType metric, + const IDSelector* sel, + Consumer& consumer, + Types... args) { +#define DISPATCH_C_SEL(C, use_sel) \ + RangeSearchBlockResultHandler resb(res, radius, sel); \ + return consumer.template f<>(resb, args...); + + if (is_similarity_metric(metric)) { + using C = CMin; + if (sel) { + DISPATCH_C_SEL(C, true); + } else { + DISPATCH_C_SEL(C, false); + } + } else { + using C = CMax; + if (sel) { + DISPATCH_C_SEL(C, true); + } else { + DISPATCH_C_SEL(C, false); + } + } +#undef DISPATCH_C_SEL +} + } // namespace faiss diff --git a/faiss/utils/distances.cpp b/faiss/utils/distances.cpp index 74b56bcc87..1506bee5cf 100644 --- a/faiss/utils/distances.cpp +++ b/faiss/utils/distances.cpp @@ -130,21 +130,18 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) { namespace { /* Find the nearest neighbors for nx queries in a set of ny vectors */ -template +template void exhaustive_inner_product_seq( const float* x, const float* y, size_t d, size_t nx, size_t ny, - BlockResultHandler& res, - const IDSelector* sel = nullptr) { + BlockResultHandler& res) { using SingleResultHandler = typename BlockResultHandler::SingleResultHandler; 
[[maybe_unused]] int nt = std::min(int(nx), omp_get_max_threads()); - FAISS_ASSERT(use_sel == (sel != nullptr)); - #pragma omp parallel num_threads(nt) { SingleResultHandler resi(res); @@ -156,7 +153,7 @@ void exhaustive_inner_product_seq( resi.begin(i); for (size_t j = 0; j < ny; j++, y_j += d) { - if (use_sel && !sel->is_member(j)) { + if (!res.is_in_selection(j)) { continue; } float ip = fvec_inner_product(x_i, y_j, d); @@ -167,21 +164,18 @@ void exhaustive_inner_product_seq( } } -template +template void exhaustive_L2sqr_seq( const float* x, const float* y, size_t d, size_t nx, size_t ny, - BlockResultHandler& res, - const IDSelector* sel = nullptr) { + BlockResultHandler& res) { using SingleResultHandler = typename BlockResultHandler::SingleResultHandler; [[maybe_unused]] int nt = std::min(int(nx), omp_get_max_threads()); - FAISS_ASSERT(use_sel == (sel != nullptr)); - #pragma omp parallel num_threads(nt) { SingleResultHandler resi(res); @@ -191,7 +185,7 @@ void exhaustive_L2sqr_seq( const float* y_j = y; resi.begin(i); for (size_t j = 0; j < ny; j++, y_j += d) { - if (use_sel && !sel->is_member(j)) { + if (!res.is_in_selection(j)) { continue; } float disij = fvec_L2sqr(x_i, y_j, d); @@ -326,6 +320,9 @@ void exhaustive_L2sqr_blas_default_impl( float ip = *ip_line; float dis = x_norms[i] + y_norms[j] - 2 * ip; + if (!res.is_in_selection(j)) { + dis = HUGE_VALF; + } // negative values can occur for identical vectors // due to roundoff errors if (dis < 0) @@ -601,44 +598,40 @@ void exhaustive_L2sqr_blas>>( #endif } -template -void knn_L2sqr_select( - const float* x, - const float* y, - size_t d, - size_t nx, - size_t ny, - BlockResultHandler& res, - const float* y_norm2, - const IDSelector* sel) { - if (sel) { - exhaustive_L2sqr_seq( - x, y, d, nx, ny, res, sel); - } else if (nx < distance_compute_blas_threshold) { - exhaustive_L2sqr_seq(x, y, d, nx, ny, res); - } else { - exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2); +struct Run_search_inner_product { + 
using T = void; + template + void f(BlockResultHandler& res, + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny) { + if (res.sel || nx < distance_compute_blas_threshold) { + exhaustive_inner_product_seq(x, y, d, nx, ny, res); + } else { + exhaustive_inner_product_blas(x, y, d, nx, ny, res); + } } -} - -template -void knn_inner_product_select( - const float* x, - const float* y, - size_t d, - size_t nx, - size_t ny, - BlockResultHandler& res, - const IDSelector* sel) { - if (sel) { - exhaustive_inner_product_seq( - x, y, d, nx, ny, res, sel); - } else if (nx < distance_compute_blas_threshold) { - exhaustive_inner_product_seq(x, y, d, nx, ny, res); - } else { - exhaustive_inner_product_blas(x, y, d, nx, ny, res); +}; + +struct Run_search_L2sqr { + using T = void; + template + void f(BlockResultHandler& res, + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + const float* y_norm2) { + if (res.sel || nx < distance_compute_blas_threshold) { + exhaustive_L2sqr_seq(x, y, d, nx, ny, res); + } else { + exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2); + } } -} +}; } // anonymous namespace @@ -675,16 +668,9 @@ void knn_inner_product( return; } - if (k == 1) { - Top1BlockResultHandler> res(nx, vals, ids); - knn_inner_product_select(x, y, d, nx, ny, res, sel); - } else if (k < distance_compute_min_k_reservoir) { - HeapBlockResultHandler> res(nx, vals, ids, k); - knn_inner_product_select(x, y, d, nx, ny, res, sel); - } else { - ReservoirBlockResultHandler> res(nx, vals, ids, k); - knn_inner_product_select(x, y, d, nx, ny, res, sel); - } + Run_search_inner_product r; + dispatch_knn_ResultHandler( + nx, vals, ids, k, METRIC_INNER_PRODUCT, sel, r, x, y, d, nx, ny); if (imin != 0) { for (size_t i = 0; i < nx * k; i++) { @@ -730,16 +716,11 @@ void knn_L2sqr( knn_L2sqr_by_idx(x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0); return; } - if (k == 1) { - Top1BlockResultHandler> res(nx, vals, ids); - knn_L2sqr_select(x, y, d, nx, 
ny, res, y_norm2, sel); - } else if (k < distance_compute_min_k_reservoir) { - HeapBlockResultHandler> res(nx, vals, ids, k); - knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel); - } else { - ReservoirBlockResultHandler> res(nx, vals, ids, k); - knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel); - } + + Run_search_L2sqr r; + dispatch_knn_ResultHandler( + nx, vals, ids, k, METRIC_L2, sel, r, x, y, d, nx, ny, y_norm2); + if (imin != 0) { for (size_t i = 0; i < nx * k; i++) { if (ids[i] >= 0) { @@ -766,6 +747,7 @@ void knn_L2sqr( * Range search ***************************************************************************/ +// TODO accept a y_norm2 as well void range_search_L2sqr( const float* x, const float* y, @@ -775,15 +757,9 @@ void range_search_L2sqr( float radius, RangeSearchResult* res, const IDSelector* sel) { - using RH = RangeSearchBlockResultHandler>; - RH resh(res, radius); - if (sel) { - exhaustive_L2sqr_seq(x, y, d, nx, ny, resh, sel); - } else if (nx < distance_compute_blas_threshold) { - exhaustive_L2sqr_seq(x, y, d, nx, ny, resh, sel); - } else { - exhaustive_L2sqr_blas(x, y, d, nx, ny, resh); - } + Run_search_L2sqr r; + dispatch_range_ResultHandler( + res, radius, METRIC_L2, sel, r, x, y, d, nx, ny, nullptr); } void range_search_inner_product( @@ -795,15 +771,9 @@ void range_search_inner_product( float radius, RangeSearchResult* res, const IDSelector* sel) { - using RH = RangeSearchBlockResultHandler>; - RH resh(res, radius); - if (sel) { - exhaustive_inner_product_seq(x, y, d, nx, ny, resh, sel); - } else if (nx < distance_compute_blas_threshold) { - exhaustive_inner_product_seq(x, y, d, nx, ny, resh); - } else { - exhaustive_inner_product_blas(x, y, d, nx, ny, resh); - } + Run_search_inner_product r; + dispatch_range_ResultHandler( + res, radius, METRIC_INNER_PRODUCT, sel, r, x, y, d, nx, ny); } /*************************************************************************** diff --git a/faiss/utils/extra_distances-inl.h 
b/faiss/utils/extra_distances-inl.h index 3171580f8c..b96c2d3136 100644 --- a/faiss/utils/extra_distances-inl.h +++ b/faiss/utils/extra_distances-inl.h @@ -162,4 +162,39 @@ inline float VectorDistance::operator()( return accu; } +/*************************************************************************** + * Dispatching function that takes a metric type and a consumer object + * the consumer object should contain a return type T and an operation template + * function f() that is called to perform the operation. The first argument + * of the function is the VectorDistance object. The rest are passed in as is. + **************************************************************************/ + +template +typename Consumer::T dispatch_VectorDistance( + size_t d, + MetricType metric, + float metric_arg, + Consumer& consumer, + Types... args) { + switch (metric) { +#define DISPATCH_VD(mt) \ + case mt: { \ + VectorDistance vd = {d, metric_arg}; \ + return consumer.template f>(vd, args...); \ + } + DISPATCH_VD(METRIC_INNER_PRODUCT); + DISPATCH_VD(METRIC_L2); + DISPATCH_VD(METRIC_L1); + DISPATCH_VD(METRIC_Linf); + DISPATCH_VD(METRIC_Lp); + DISPATCH_VD(METRIC_Canberra); + DISPATCH_VD(METRIC_BrayCurtis); + DISPATCH_VD(METRIC_JensenShannon); + DISPATCH_VD(METRIC_Jaccard); + DISPATCH_VD(METRIC_NaNEuclidean); + DISPATCH_VD(METRIC_ABS_INNER_PRODUCT); + } +#undef DISPATCH_VD +} + } // namespace faiss diff --git a/faiss/utils/extra_distances.cpp b/faiss/utils/extra_distances.cpp index 407057e58e..69ee961c7f 100644 --- a/faiss/utils/extra_distances.cpp +++ b/faiss/utils/extra_distances.cpp @@ -26,72 +26,77 @@ namespace faiss { namespace { -template -void pairwise_extra_distances_template( - VD vd, - int64_t nq, - const float* xq, - int64_t nb, - const float* xb, - float* dis, - int64_t ldq, - int64_t ldb, - int64_t ldd) { +struct Run_pairwise_extra_distances { + using T = void; + + template + void f(VD vd, + int64_t nq, + const float* xq, + int64_t nb, + const float* xb, + float* dis, + 
int64_t ldq, + int64_t ldb, + int64_t ldd) { #pragma omp parallel for if (nq > 10) - for (int64_t i = 0; i < nq; i++) { - const float* xqi = xq + i * ldq; - const float* xbj = xb; - float* disi = dis + ldd * i; - - for (int64_t j = 0; j < nb; j++) { - disi[j] = vd(xqi, xbj); - xbj += ldb; + for (int64_t i = 0; i < nq; i++) { + const float* xqi = xq + i * ldq; + const float* xbj = xb; + float* disi = dis + ldd * i; + + for (int64_t j = 0; j < nb; j++) { + disi[j] = vd(xqi, xbj); + xbj += ldb; + } } } -} - -template -void knn_extra_metrics_template( - VD vd, - const float* x, - const float* y, - size_t nx, - size_t ny, - size_t k, - float* distances, - int64_t* labels) { - size_t d = vd.d; - using C = typename VD::C; - size_t check_period = InterruptCallback::get_period_hint(ny * d); - check_period *= omp_get_max_threads(); +}; - for (size_t i0 = 0; i0 < nx; i0 += check_period) { - size_t i1 = std::min(i0 + check_period, nx); +struct Run_knn_extra_metrics { + using T = void; + template + void f(VD vd, + const float* x, + const float* y, + size_t nx, + size_t ny, + size_t k, + float* distances, + int64_t* labels) { + size_t d = vd.d; + using C = typename VD::C; + size_t check_period = InterruptCallback::get_period_hint(ny * d); + check_period *= omp_get_max_threads(); + + for (size_t i0 = 0; i0 < nx; i0 += check_period) { + size_t i1 = std::min(i0 + check_period, nx); #pragma omp parallel for - for (int64_t i = i0; i < i1; i++) { - const float* x_i = x + i * d; - const float* y_j = y; - size_t j; - float* simi = distances + k * i; - int64_t* idxi = labels + k * i; - - // maxheap_heapify(k, simi, idxi); - heap_heapify(k, simi, idxi); - for (j = 0; j < ny; j++) { - float disij = vd(x_i, y_j); - - if (C::cmp(simi[0], disij)) { - heap_replace_top(k, simi, idxi, disij, j); + for (int64_t i = i0; i < i1; i++) { + const float* x_i = x + i * d; + const float* y_j = y; + size_t j; + float* simi = distances + k * i; + int64_t* idxi = labels + k * i; + + // maxheap_heapify(k, 
simi, idxi); + heap_heapify(k, simi, idxi); + for (j = 0; j < ny; j++) { + float disij = vd(x_i, y_j); + + if (C::cmp(simi[0], disij)) { + heap_replace_top(k, simi, idxi, disij, j); + } + y_j += d; } - y_j += d; + // maxheap_reorder(k, simi, idxi); + heap_reorder(k, simi, idxi); } - // maxheap_reorder(k, simi, idxi); - heap_reorder(k, simi, idxi); + InterruptCallback::check(); } - InterruptCallback::check(); } -} +}; template struct ExtraDistanceComputer : FlatCodesDistanceComputer { @@ -124,6 +129,19 @@ struct ExtraDistanceComputer : FlatCodesDistanceComputer { } }; +struct Run_get_distance_computer { + using T = FlatCodesDistanceComputer*; + + template + FlatCodesDistanceComputer* f( + VD vd, + const float* xb, + size_t nb, + const float* q = nullptr) { + return new ExtraDistanceComputer(vd, xb, nb, q); + } +}; + } // anonymous namespace void pairwise_extra_distances( @@ -147,28 +165,9 @@ void pairwise_extra_distances( if (ldd == -1) ldd = nb; - switch (mt) { -#define HANDLE_VAR(kw) \ - case METRIC_##kw: { \ - VectorDistance vd = {(size_t)d, metric_arg}; \ - pairwise_extra_distances_template( \ - vd, nq, xq, nb, xb, dis, ldq, ldb, ldd); \ - break; \ - } - HANDLE_VAR(L2); - HANDLE_VAR(L1); - HANDLE_VAR(Linf); - HANDLE_VAR(Canberra); - HANDLE_VAR(BrayCurtis); - HANDLE_VAR(JensenShannon); - HANDLE_VAR(Lp); - HANDLE_VAR(Jaccard); - HANDLE_VAR(NaNEuclidean); - HANDLE_VAR(ABS_INNER_PRODUCT); -#undef HANDLE_VAR - default: - FAISS_THROW_MSG("metric type not implemented"); - } + Run_pairwise_extra_distances run; + dispatch_VectorDistance( + d, mt, metric_arg, run, nq, xq, nb, xb, dis, ldq, ldb, ldd); } void knn_extra_metrics( @@ -182,27 +181,9 @@ void knn_extra_metrics( size_t k, float* distances, int64_t* indexes) { - switch (mt) { -#define HANDLE_VAR(kw) \ - case METRIC_##kw: { \ - VectorDistance vd = {(size_t)d, metric_arg}; \ - knn_extra_metrics_template(vd, x, y, nx, ny, k, distances, indexes); \ - break; \ - } - HANDLE_VAR(L2); - HANDLE_VAR(L1); - HANDLE_VAR(Linf); 
- HANDLE_VAR(Canberra); - HANDLE_VAR(BrayCurtis); - HANDLE_VAR(JensenShannon); - HANDLE_VAR(Lp); - HANDLE_VAR(Jaccard); - HANDLE_VAR(NaNEuclidean); - HANDLE_VAR(ABS_INNER_PRODUCT); -#undef HANDLE_VAR - default: - FAISS_THROW_MSG("metric type not implemented"); - } + Run_knn_extra_metrics run; + dispatch_VectorDistance( + d, mt, metric_arg, run, x, y, nx, ny, k, distances, indexes); } FlatCodesDistanceComputer* get_extra_distance_computer( @@ -211,27 +192,8 @@ FlatCodesDistanceComputer* get_extra_distance_computer( float metric_arg, size_t nb, const float* xb) { - switch (mt) { -#define HANDLE_VAR(kw) \ - case METRIC_##kw: { \ - VectorDistance vd = {(size_t)d, metric_arg}; \ - return new ExtraDistanceComputer>( \ - vd, xb, nb); \ - } - HANDLE_VAR(L2); - HANDLE_VAR(L1); - HANDLE_VAR(Linf); - HANDLE_VAR(Canberra); - HANDLE_VAR(BrayCurtis); - HANDLE_VAR(JensenShannon); - HANDLE_VAR(Lp); - HANDLE_VAR(Jaccard); - HANDLE_VAR(NaNEuclidean); - HANDLE_VAR(ABS_INNER_PRODUCT); -#undef HANDLE_VAR - default: - FAISS_THROW_MSG("metric type not implemented"); - } + Run_get_distance_computer run; + return dispatch_VectorDistance(d, mt, metric_arg, run, xb, nb); } } // namespace faiss diff --git a/faiss/utils/hamming_distance/hamdis-inl.h b/faiss/utils/hamming_distance/hamdis-inl.h index b830df38b6..dcd3fe2d12 100644 --- a/faiss/utils/hamming_distance/hamdis-inl.h +++ b/faiss/utils/hamming_distance/hamdis-inl.h @@ -55,7 +55,7 @@ SPECIALIZED_HC(64); /*************************************************************************** * Dispatching function that takes a code size and a consumer object * the consumer object should contain a retun type t and a operation template - * function f() that to be called to perform the operation. + * function f() that must be called to perform the operation. 
**************************************************************************/ template @@ -76,6 +76,7 @@ typename Consumer::T dispatch_HammingComputer( default: return consumer.template f(args...); } +#undef DISPATCH_HC } } // namespace faiss diff --git a/tests/test_standalone_codec.py b/tests/test_standalone_codec.py index 391b88b9dd..643e769f0e 100644 --- a/tests/test_standalone_codec.py +++ b/tests/test_standalone_codec.py @@ -14,6 +14,7 @@ from faiss.contrib.datasets import SyntheticDataset from faiss.contrib.inspect_tools import get_additive_quantizer_codebooks + class TestEncodeDecode(unittest.TestCase): def do_encode_twice(self, factory_key): @@ -263,6 +264,19 @@ def test_ZnSphereCodecAlt32(self): def test_ZnSphereCodecAlt24(self): self.run_ZnSphereCodecAlt(24, 14) + def test_lattice_index(self): + index = faiss.index_factory(96, "ZnLattice3x10_4") + rs = np.random.RandomState(123) + xq = rs.randn(10, 96).astype('float32') + xb = rs.randn(20, 96).astype('float32') + index.train(xb) + index.add(xb) + D, I = index.search(xq, 5) + for i in range(10): + recons = index.reconstruct_batch(I[i, :]) + ref_dis = ((recons - xq[i]) ** 2).sum(1) + np.testing.assert_allclose(D[i, :], ref_dis, atol=1e-4) + class TestBitstring(unittest.TestCase):