From 3ade2d0f9151652dde486b3c9717b3598711fccc Mon Sep 17 00:00:00 2001 From: Matthijs Douze Date: Wed, 30 Aug 2023 08:56:40 -0700 Subject: [PATCH] Clean up batch comments + obey IO_FLAG_SKIP_PRECOMPUTE_TABLE (#3013) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3013 To avoid OOM when loading some RCQs, don't precompute cross product tables when io_flags contains bit IO_FLAG_SKIP_PRECOMPUTE_TABLE Differential Revision: D48448616 fbshipit-source-id: b798e9da3a3df41a56de2bc130d266958232f8cc --- faiss/IndexAdditiveQuantizer.cpp | 312 +++++++++++++++---------------- faiss/impl/AdditiveQuantizer.cpp | 2 + faiss/impl/AdditiveQuantizer.h | 6 +- faiss/impl/ResidualQuantizer.cpp | 67 +++---- faiss/impl/ResidualQuantizer.h | 13 +- faiss/impl/index_read.cpp | 39 ++-- faiss/python/__init__.py | 4 +- tests/test_residual_quantizer.py | 16 +- 8 files changed, 232 insertions(+), 227 deletions(-) diff --git a/faiss/IndexAdditiveQuantizer.cpp b/faiss/IndexAdditiveQuantizer.cpp index 153b766e4a..ffd4c524f6 100644 --- a/faiss/IndexAdditiveQuantizer.cpp +++ b/faiss/IndexAdditiveQuantizer.cpp @@ -5,9 +5,6 @@ * LICENSE file in the root directory of this source tree. 
*/ -// quiet the noise -// clang-format off - #include #include @@ -21,7 +18,6 @@ #include #include - namespace faiss { /************************************************************************************** @@ -29,15 +25,13 @@ namespace faiss { **************************************************************************************/ IndexAdditiveQuantizer::IndexAdditiveQuantizer( - idx_t d, - AdditiveQuantizer* aq, - MetricType metric): - IndexFlatCodes(aq->code_size, d, metric), aq(aq) -{ + idx_t d, + AdditiveQuantizer* aq, + MetricType metric) + : IndexFlatCodes(aq->code_size, d, metric), aq(aq) { FAISS_THROW_IF_NOT(metric == METRIC_INNER_PRODUCT || metric == METRIC_L2); } - namespace { /************************************************************ @@ -45,21 +39,22 @@ namespace { ************************************************************/ template -struct AQDistanceComputerDecompress: FlatCodesDistanceComputer { +struct AQDistanceComputerDecompress : FlatCodesDistanceComputer { std::vector tmp; - const AdditiveQuantizer & aq; + const AdditiveQuantizer& aq; VectorDistance vd; size_t d; - AQDistanceComputerDecompress(const IndexAdditiveQuantizer &iaq, VectorDistance vd): - FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size), - tmp(iaq.d * 2), - aq(*iaq.aq), - vd(vd), - d(iaq.d) - {} + AQDistanceComputerDecompress( + const IndexAdditiveQuantizer& iaq, + VectorDistance vd) + : FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size), + tmp(iaq.d * 2), + aq(*iaq.aq), + vd(vd), + d(iaq.d) {} - const float *q; + const float* q; void set_query(const float* x) final { q = x; } @@ -70,7 +65,7 @@ struct AQDistanceComputerDecompress: FlatCodesDistanceComputer { return vd(tmp.data(), tmp.data() + d); } - float distance_to_code(const uint8_t *code) final { + float distance_to_code(const uint8_t* code) final { aq.decode(code, tmp.data(), 1); return vd(q, tmp.data()); } @@ -78,19 +73,17 @@ struct AQDistanceComputerDecompress: FlatCodesDistanceComputer { virtual 
~AQDistanceComputerDecompress() {} }; - -template -struct AQDistanceComputerLUT: FlatCodesDistanceComputer { +template +struct AQDistanceComputerLUT : FlatCodesDistanceComputer { std::vector LUT; - const AdditiveQuantizer & aq; + const AdditiveQuantizer& aq; size_t d; - explicit AQDistanceComputerLUT(const IndexAdditiveQuantizer &iaq): - FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size), - LUT(iaq.aq->total_codebook_size + iaq.d * 2), - aq(*iaq.aq), - d(iaq.d) - {} + explicit AQDistanceComputerLUT(const IndexAdditiveQuantizer& iaq) + : FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size), + LUT(iaq.aq->total_codebook_size + iaq.d * 2), + aq(*iaq.aq), + d(iaq.d) {} float bias; void set_query(const float* x) final { @@ -104,26 +97,23 @@ struct AQDistanceComputerLUT: FlatCodesDistanceComputer { } float symmetric_dis(idx_t i, idx_t j) final { - float *tmp = LUT.data(); + float* tmp = LUT.data(); aq.decode(codes + i * d, tmp, 1); aq.decode(codes + j * d, tmp + d, 1); return fvec_L2sqr(tmp, tmp + d, d); } - float distance_to_code(const uint8_t *code) final { + float distance_to_code(const uint8_t* code) final { return bias + aq.compute_1_distance_LUT(code, LUT.data()); } virtual ~AQDistanceComputerLUT() {} }; - - /************************************************************ * scanning implementation for search ************************************************************/ - template void search_with_decompress( const IndexAdditiveQuantizer& ir, @@ -133,11 +123,11 @@ void search_with_decompress( const uint8_t* codes = ir.codes.data(); size_t ntotal = ir.ntotal; size_t code_size = ir.code_size; - const AdditiveQuantizer *aq = ir.aq; + const AdditiveQuantizer* aq = ir.aq; using SingleResultHandler = typename ResultHandler::SingleResultHandler; -#pragma omp parallel for if(res.nq > 100) +#pragma omp parallel for if (res.nq > 100) for (int64_t q = 0; q < res.nq; q++) { SingleResultHandler resi(res); resi.begin(q); @@ -152,13 +142,12 @@ void 
search_with_decompress( } } -template +template void search_with_LUT( const IndexAdditiveQuantizer& ir, const float* xq, - ResultHandler& res) -{ - const AdditiveQuantizer & aq = *ir.aq; + ResultHandler& res) { + const AdditiveQuantizer& aq = *ir.aq; const uint8_t* codes = ir.codes.data(); size_t ntotal = ir.ntotal; size_t code_size = aq.code_size; @@ -166,38 +155,34 @@ void search_with_LUT( size_t d = ir.d; using SingleResultHandler = typename ResultHandler::SingleResultHandler; - std::unique_ptr LUT(new float[nq * aq.total_codebook_size]); + std::unique_ptr LUT(new float[nq * aq.total_codebook_size]); aq.compute_LUT(nq, xq, LUT.get()); -#pragma omp parallel for if(nq > 100) +#pragma omp parallel for if (nq > 100) for (int64_t q = 0; q < nq; q++) { SingleResultHandler resi(res); resi.begin(q); std::vector tmp(aq.d); - const float *LUT_q = LUT.get() + aq.total_codebook_size * q; + const float* LUT_q = LUT.get() + aq.total_codebook_size * q; float bias = 0; - if (!is_IP) { // the LUT function returns ||y||^2 - 2 * , need to add ||x||^2 + if (!is_IP) { // the LUT function returns ||y||^2 - 2 * , need to + // add ||x||^2 bias = fvec_norm_L2sqr(xq + q * d, d); } for (size_t i = 0; i < ntotal; i++) { float dis = aq.compute_1_distance_LUT( - codes + i * code_size, - LUT_q - ); + codes + i * code_size, LUT_q); resi.add_result(dis + bias, i); } resi.end(); } - } - } // anonymous namespace - -FlatCodesDistanceComputer * IndexAdditiveQuantizer::get_FlatCodesDistanceComputer() const { - +FlatCodesDistanceComputer* IndexAdditiveQuantizer:: + get_FlatCodesDistanceComputer() const { if (aq->search_type == AdditiveQuantizer::ST_decompress) { if (metric_type == METRIC_L2) { using VD = VectorDistance; @@ -212,34 +197,36 @@ FlatCodesDistanceComputer * IndexAdditiveQuantizer::get_FlatCodesDistanceCompute } } else { if (metric_type == METRIC_INNER_PRODUCT) { - return new AQDistanceComputerLUT(*this); + return new AQDistanceComputerLUT< + true, + 
AdditiveQuantizer::ST_LUT_nonorm>(*this); } else { - switch(aq->search_type) { -#define DISPATCH(st) \ - case AdditiveQuantizer::st: \ - return new AQDistanceComputerLUT (*this);\ - break; - DISPATCH(ST_norm_float) - DISPATCH(ST_LUT_nonorm) - DISPATCH(ST_norm_qint8) - DISPATCH(ST_norm_qint4) - DISPATCH(ST_norm_cqint4) - case AdditiveQuantizer::ST_norm_cqint8: - case AdditiveQuantizer::ST_norm_lsq2x4: - case AdditiveQuantizer::ST_norm_rq2x4: - return new AQDistanceComputerLUT (*this);\ - break; + switch (aq->search_type) { +#define DISPATCH(st) \ + case AdditiveQuantizer::st: \ + return new AQDistanceComputerLUT(*this); \ + break; + DISPATCH(ST_norm_float) + DISPATCH(ST_LUT_nonorm) + DISPATCH(ST_norm_qint8) + DISPATCH(ST_norm_qint4) + DISPATCH(ST_norm_cqint4) + case AdditiveQuantizer::ST_norm_cqint8: + case AdditiveQuantizer::ST_norm_lsq2x4: + case AdditiveQuantizer::ST_norm_rq2x4: + return new AQDistanceComputerLUT< + false, + AdditiveQuantizer::ST_norm_cqint8>(*this); + break; #undef DISPATCH - default: - FAISS_THROW_FMT("search type %d not supported", aq->search_type); + default: + FAISS_THROW_FMT( + "search type %d not supported", aq->search_type); } } } } - - - void IndexAdditiveQuantizer::search( idx_t n, const float* x, @@ -247,8 +234,8 @@ void IndexAdditiveQuantizer::search( float* distances, idx_t* labels, const SearchParameters* params) const { - - FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index"); + FAISS_THROW_IF_NOT_MSG( + !params, "search params not supported for this index"); if (aq->search_type == AdditiveQuantizer::ST_decompress) { if (metric_type == METRIC_L2) { @@ -264,45 +251,46 @@ void IndexAdditiveQuantizer::search( } } else { if (metric_type == METRIC_INNER_PRODUCT) { - HeapResultHandler > rh(n, distances, labels, k); - search_with_LUT (*this, x, rh); + HeapResultHandler> rh(n, distances, labels, k); + search_with_LUT( + *this, x, rh); } else { - HeapResultHandler > rh(n, distances, labels, k); - 
switch(aq->search_type) { -#define DISPATCH(st) \ - case AdditiveQuantizer::st: \ - search_with_LUT (*this, x, rh);\ - break; - DISPATCH(ST_norm_float) - DISPATCH(ST_LUT_nonorm) - DISPATCH(ST_norm_qint8) - DISPATCH(ST_norm_qint4) - DISPATCH(ST_norm_cqint4) - case AdditiveQuantizer::ST_norm_cqint8: - case AdditiveQuantizer::ST_norm_lsq2x4: - case AdditiveQuantizer::ST_norm_rq2x4: - search_with_LUT (*this, x, rh); - break; + HeapResultHandler> rh(n, distances, labels, k); + switch (aq->search_type) { +#define DISPATCH(st) \ + case AdditiveQuantizer::st: \ + search_with_LUT(*this, x, rh); \ + break; + DISPATCH(ST_norm_float) + DISPATCH(ST_LUT_nonorm) + DISPATCH(ST_norm_qint8) + DISPATCH(ST_norm_qint4) + DISPATCH(ST_norm_cqint4) + case AdditiveQuantizer::ST_norm_cqint8: + case AdditiveQuantizer::ST_norm_lsq2x4: + case AdditiveQuantizer::ST_norm_rq2x4: + search_with_LUT( + *this, x, rh); + break; #undef DISPATCH - default: - FAISS_THROW_FMT("search type %d not supported", aq->search_type); + default: + FAISS_THROW_FMT( + "search type %d not supported", aq->search_type); } } - } } -void IndexAdditiveQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes) const { +void IndexAdditiveQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes) + const { return aq->compute_codes(x, bytes, n); } -void IndexAdditiveQuantizer::sa_decode(idx_t n, const uint8_t* bytes, float* x) const { +void IndexAdditiveQuantizer::sa_decode(idx_t n, const uint8_t* bytes, float* x) + const { return aq->decode(bytes, x, n); } - - - /************************************************************************************** * IndexResidualQuantizer **************************************************************************************/ @@ -313,8 +301,11 @@ IndexResidualQuantizer::IndexResidualQuantizer( size_t nbits, ///< number of bit per subvector index MetricType metric, Search_type_t search_type) - : IndexResidualQuantizer(d, std::vector(M, nbits), metric, search_type) { -} + : 
IndexResidualQuantizer( + d, + std::vector(M, nbits), + metric, + search_type) {} IndexResidualQuantizer::IndexResidualQuantizer( int d, @@ -326,14 +317,14 @@ IndexResidualQuantizer::IndexResidualQuantizer( is_trained = false; } -IndexResidualQuantizer::IndexResidualQuantizer() : IndexResidualQuantizer(0, 0, 0) {} +IndexResidualQuantizer::IndexResidualQuantizer() + : IndexResidualQuantizer(0, 0, 0) {} void IndexResidualQuantizer::train(idx_t n, const float* x) { rq.train(n, x); is_trained = true; } - /************************************************************************************** * IndexLocalSearchQuantizer **************************************************************************************/ @@ -344,31 +335,33 @@ IndexLocalSearchQuantizer::IndexLocalSearchQuantizer( size_t nbits, ///< number of bit per subvector index MetricType metric, Search_type_t search_type) - : IndexAdditiveQuantizer(d, &lsq, metric), lsq(d, M, nbits, search_type) { + : IndexAdditiveQuantizer(d, &lsq, metric), + lsq(d, M, nbits, search_type) { code_size = lsq.code_size; is_trained = false; } -IndexLocalSearchQuantizer::IndexLocalSearchQuantizer() : IndexLocalSearchQuantizer(0, 0, 0) {} +IndexLocalSearchQuantizer::IndexLocalSearchQuantizer() + : IndexLocalSearchQuantizer(0, 0, 0) {} void IndexLocalSearchQuantizer::train(idx_t n, const float* x) { lsq.train(n, x); is_trained = true; } - /************************************************************************************** * IndexProductResidualQuantizer **************************************************************************************/ IndexProductResidualQuantizer::IndexProductResidualQuantizer( - int d, ///< dimensionality of the input vectors + int d, ///< dimensionality of the input vectors size_t nsplits, ///< number of residual quantizers - size_t Msub, ///< number of subquantizers per RQ - size_t nbits, ///< number of bit per subvector index + size_t Msub, ///< number of subquantizers per RQ + size_t nbits, ///< number of 
bit per subvector index MetricType metric, Search_type_t search_type) - : IndexAdditiveQuantizer(d, &prq, metric), prq(d, nsplits, Msub, nbits, search_type) { + : IndexAdditiveQuantizer(d, &prq, metric), + prq(d, nsplits, Msub, nbits, search_type) { code_size = prq.code_size; is_trained = false; } @@ -381,19 +374,19 @@ void IndexProductResidualQuantizer::train(idx_t n, const float* x) { is_trained = true; } - /************************************************************************************** * IndexProductLocalSearchQuantizer **************************************************************************************/ IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer( - int d, ///< dimensionality of the input vectors + int d, ///< dimensionality of the input vectors size_t nsplits, ///< number of local search quantizers - size_t Msub, ///< number of subquantizers per LSQ - size_t nbits, ///< number of bit per subvector index + size_t Msub, ///< number of subquantizers per LSQ + size_t nbits, ///< number of bit per subvector index MetricType metric, Search_type_t search_type) - : IndexAdditiveQuantizer(d, &plsq, metric), plsq(d, nsplits, Msub, nbits, search_type) { + : IndexAdditiveQuantizer(d, &plsq, metric), + plsq(d, nsplits, Msub, nbits, search_type) { code_size = plsq.code_size; is_trained = false; } @@ -406,17 +399,15 @@ void IndexProductLocalSearchQuantizer::train(idx_t n, const float* x) { is_trained = true; } - /************************************************************************************** * AdditiveCoarseQuantizer **************************************************************************************/ AdditiveCoarseQuantizer::AdditiveCoarseQuantizer( - idx_t d, - AdditiveQuantizer* aq, - MetricType metric): - Index(d, metric), aq(aq) -{} + idx_t d, + AdditiveQuantizer* aq, + MetricType metric) + : Index(d, metric), aq(aq) {} void AdditiveCoarseQuantizer::add(idx_t, const float*) { FAISS_THROW_MSG("not applicable"); @@ -430,17 +421,16 
@@ void AdditiveCoarseQuantizer::reset() { FAISS_THROW_MSG("not applicable"); } - void AdditiveCoarseQuantizer::train(idx_t n, const float* x) { if (verbose) { - printf("AdditiveCoarseQuantizer::train: training on %zd vectors\n", size_t(n)); + printf("AdditiveCoarseQuantizer::train: training on %zd vectors\n", + size_t(n)); } size_t norms_size = sizeof(float) << aq->tot_bits; - FAISS_THROW_IF_NOT_MSG ( - norms_size <= aq->max_mem_distances, - "the RCQ norms matrix will become too large, please reduce the number of quantization steps" - ); + FAISS_THROW_IF_NOT_MSG( + norms_size <= aq->max_mem_distances, + "the RCQ norms matrix will become too large, please reduce the number of quantization steps"); aq->train(n, x); is_trained = true; @@ -448,7 +438,8 @@ void AdditiveCoarseQuantizer::train(idx_t n, const float* x) { if (metric_type == METRIC_L2) { if (verbose) { - printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n", size_t(ntotal)); + printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n", + size_t(ntotal)); } // this is not necessary for the residualcoarsequantizer when // using beam search. 
We'll see if the memory overhead is too high @@ -463,16 +454,15 @@ void AdditiveCoarseQuantizer::search( idx_t k, float* distances, idx_t* labels, - const SearchParameters * params) const { - - FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index"); + const SearchParameters* params) const { + FAISS_THROW_IF_NOT_MSG( + !params, "search params not supported for this index"); if (metric_type == METRIC_INNER_PRODUCT) { aq->knn_centroids_inner_product(n, x, k, distances, labels); } else if (metric_type == METRIC_L2) { FAISS_THROW_IF_NOT(centroid_norms.size() == ntotal); - aq->knn_centroids_L2( - n, x, k, distances, labels, centroid_norms.data()); + aq->knn_centroids_L2(n, x, k, distances, labels, centroid_norms.data()); } } @@ -481,7 +471,7 @@ void AdditiveCoarseQuantizer::search( **************************************************************************************/ ResidualCoarseQuantizer::ResidualCoarseQuantizer( - int d, ///< dimensionality of the input vectors + int d, ///< dimensionality of the input vectors const std::vector& nbits, MetricType metric) : AdditiveCoarseQuantizer(d, &rq, metric), rq(d, nbits) { @@ -496,21 +486,30 @@ ResidualCoarseQuantizer::ResidualCoarseQuantizer( MetricType metric) : ResidualCoarseQuantizer(d, std::vector(M, nbits), metric) {} -ResidualCoarseQuantizer::ResidualCoarseQuantizer(): ResidualCoarseQuantizer(0, 0, 0) {} - - +ResidualCoarseQuantizer::ResidualCoarseQuantizer() + : ResidualCoarseQuantizer(0, 0, 0) {} void ResidualCoarseQuantizer::set_beam_factor(float new_beam_factor) { beam_factor = new_beam_factor; if (new_beam_factor > 0) { FAISS_THROW_IF_NOT(new_beam_factor >= 1.0); + if (rq.codebook_cross_products.size() == 0) { + rq.compute_codebook_tables(); + } return; - } else if (metric_type == METRIC_L2 && ntotal != centroid_norms.size()) { - if (verbose) { - printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n", size_t(ntotal)); + } else { + // new_beam_factor = -1: 
exhaustive computation. + // Does not use the cross_products + rq.codebook_cross_products.resize(0); + // but the centroid norms are necessary! + if (metric_type == METRIC_L2 && ntotal != centroid_norms.size()) { + if (verbose) { + printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n", + size_t(ntotal)); + } + centroid_norms.resize(ntotal); + aq->compute_centroid_norms(centroid_norms.data()); } - centroid_norms.resize(ntotal); - aq->compute_centroid_norms(centroid_norms.data()); } } @@ -520,13 +519,15 @@ void ResidualCoarseQuantizer::search( idx_t k, float* distances, idx_t* labels, - const SearchParameters * params_in - ) const { - + const SearchParameters* params_in) const { float beam_factor = this->beam_factor; if (params_in) { - auto params = dynamic_cast(params_in); - FAISS_THROW_IF_NOT_MSG(params, "need SearchParametersResidualCoarseQuantizer parameters"); + auto params = + dynamic_cast( + params_in); + FAISS_THROW_IF_NOT_MSG( + params, + "need SearchParametersResidualCoarseQuantizer parameters"); beam_factor = params->beam_factor; } @@ -571,6 +572,7 @@ void ResidualCoarseQuantizer::search( rq.refine_beam( n, 1, x, beam_size, codes.data(), nullptr, beam_distances.data()); + // pack int32 table #pragma omp parallel for if (n > 4000) for (idx_t i = 0; i < n; i++) { memcpy(distances + i * k, @@ -590,7 +592,8 @@ void ResidualCoarseQuantizer::search( } } -void ResidualCoarseQuantizer::initialize_from(const ResidualCoarseQuantizer &other) { +void ResidualCoarseQuantizer::initialize_from( + const ResidualCoarseQuantizer& other) { FAISS_THROW_IF_NOT(rq.M <= other.rq.M); rq.initialize_from(other.rq); set_beam_factor(other.beam_factor); @@ -598,7 +601,6 @@ void ResidualCoarseQuantizer::initialize_from(const ResidualCoarseQuantizer &oth ntotal = (idx_t)1 << aq->tot_bits; } - /************************************************************************************** * LocalSearchCoarseQuantizer 
**************************************************************************************/ @@ -613,12 +615,8 @@ LocalSearchCoarseQuantizer::LocalSearchCoarseQuantizer( is_trained = false; } - LocalSearchCoarseQuantizer::LocalSearchCoarseQuantizer() { aq = &lsq; } - - - } // namespace faiss diff --git a/faiss/impl/AdditiveQuantizer.cpp b/faiss/impl/AdditiveQuantizer.cpp index ff6eb4a98a..c39d870e6d 100644 --- a/faiss/impl/AdditiveQuantizer.cpp +++ b/faiss/impl/AdditiveQuantizer.cpp @@ -370,6 +370,8 @@ void AdditiveQuantizer::compute_LUT( namespace { +/* compute inner products of one query with all centroids, given a look-up + * table of all inner products with codebook entries */ void compute_inner_prod_with_LUT( const AdditiveQuantizer& aq, const float* LUT, diff --git a/faiss/impl/AdditiveQuantizer.h b/faiss/impl/AdditiveQuantizer.h index 289d205299..054b5c7677 100644 --- a/faiss/impl/AdditiveQuantizer.h +++ b/faiss/impl/AdditiveQuantizer.h @@ -49,11 +49,13 @@ struct AdditiveQuantizer : Quantizer { /// encode a norm into norm_bits bits uint64_t encode_norm(float norm) const; + /// encode norm by non-uniform scalar quantization uint32_t encode_qcint( - float x) const; ///< encode norm by non-uniform scalar quantization + float x) const; + /// decode norm by non-uniform scalar quantization float decode_qcint(uint32_t c) - const; ///< decode norm by non-uniform scalar quantization + const; /// Encodes how search is performed and how vectors are encoded enum Search_type_t { diff --git a/faiss/impl/ResidualQuantizer.cpp b/faiss/impl/ResidualQuantizer.cpp index fa8248a60a..fdfcc8e364 100644 --- a/faiss/impl/ResidualQuantizer.cpp +++ b/faiss/impl/ResidualQuantizer.cpp @@ -125,6 +125,10 @@ void ResidualQuantizer::initialize_from( } } +/**************************************************************** + * Encoding steps, used both for training and search + */ + void beam_search_encode_step( size_t d, size_t K, @@ -277,6 +281,10 @@ void beam_search_encode_step( } } 
+/**************************************************************** + * Training + ****************************************************************/ + void ResidualQuantizer::train(size_t n, const float* x) { codebooks.resize(d * codebook_offsets.back()); @@ -568,7 +576,12 @@ size_t ResidualQuantizer::memory_per_point(int beam_size) const { return mem; } -// a namespace full of preallocated buffers +/**************************************************************** + * Encoding + ****************************************************************/ + +// a namespace full of preallocated buffers. This speeds up +// computations, instead of re-allocating them at every encoding step namespace { // Preallocated memory chunk for refine_beam_mp() call @@ -609,8 +622,6 @@ struct ComputeCodesAddCentroidsLUT1MemoryPool { RefineBeamLUTMemoryPool refine_beam_lut_pool; }; -} // namespace - // forward declaration void refine_beam_mp( const ResidualQuantizer& rq, @@ -743,6 +754,8 @@ void compute_codes_add_centroids_mp_lut1( centroids); } +} // namespace + void ResidualQuantizer::compute_codes_add_centroids( const float* x, uint8_t* codes_out, @@ -769,11 +782,6 @@ void ResidualQuantizer::compute_codes_add_centroids( cent = centroids + i0 * d; } - // compute_codes_add_centroids( - // x + i0 * d, - // codes_out + i0 * code_size, - // i1 - i0, - // cent); if (use_beam_LUT == 0) { compute_codes_add_centroids_mp_lut0( *this, @@ -794,6 +802,8 @@ void ResidualQuantizer::compute_codes_add_centroids( } } +namespace { + void refine_beam_mp( const ResidualQuantizer& rq, size_t n, @@ -873,15 +883,11 @@ void refine_beam_mp( codebooks_m, n, cur_beam_size, - // residuals.data(), residuals_ptr, m, - // codes.data(), codes_ptr, new_beam_size, - // new_codes.data(), new_codes_ptr, - // new_residuals.data(), new_residuals_ptr, pool.distances.data(), assign_index.get(), @@ -896,9 +902,6 @@ void refine_beam_mp( if (rq.verbose) { float sum_distances = 0; - // for (int j = 0; j < distances.size(); j++) { - // 
sum_distances += distances[j]; - // } for (int j = 0; j < distances_size; j++) { sum_distances += pool.distances[j]; } @@ -914,27 +917,22 @@ void refine_beam_mp( } if (out_codes) { - // memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0])); memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr)); } if (out_residuals) { - // memcpy(out_residuals, - // residuals.data(), - // residuals.size() * sizeof(residuals[0])); memcpy(out_residuals, residuals_ptr, residuals_size * sizeof(*residuals_ptr)); } if (out_distances) { - // memcpy(out_distances, - // distances.data(), - // distances.size() * sizeof(distances[0])); memcpy(out_distances, pool.distances.data(), distances_size * sizeof(pool.distances[0])); } } +} // anonymous namespace + void ResidualQuantizer::refine_beam( size_t n, size_t beam_size, @@ -1165,7 +1163,7 @@ void accum_and_finalize_tab( } } -} // namespace +} // anonymous namespace void beam_search_encode_step_tab( size_t K, @@ -1390,6 +1388,8 @@ void beam_search_encode_step_tab( } } +namespace { + // void refine_beam_LUT_mp( const ResidualQuantizer& rq, @@ -1443,13 +1443,9 @@ void refine_beam_LUT_mp( for (int m = 0; m < rq.M; m++) { int K = 1 << rq.nbits[m]; - // it is guaranteed that (new_beam_size <= than max_beam_size) == - // true + // it is guaranteed that (new_beam_size <= max_beam_size) int new_beam_size = std::min(beam_size * K, out_beam_size); - // std::vector new_codes(n * new_beam_size * (m + 1)); - // std::vector new_distances(n * new_beam_size); - codes_size = n * new_beam_size * (m + 1); distances_size = n * new_beam_size; @@ -1464,29 +1460,20 @@ void refine_beam_LUT_mp( rq.total_codebook_size, rq.cent_norms.data() + rq.codebook_offsets[m], m, - // codes.data(), codes_ptr, - // distances.data(), distances_ptr, new_beam_size, - // new_codes.data(), new_codes_ptr, - // new_distances.data() new_distances_ptr, rq.approx_topk_mode); - // codes.swap(new_codes); std::swap(codes_ptr, new_codes_ptr); - // distances.swap(new_distances); 
std::swap(distances_ptr, new_distances_ptr); beam_size = new_beam_size; if (rq.verbose) { float sum_distances = 0; - // for (int j = 0; j < distances.size(); j++) { - // sum_distances += distances[j]; - // } for (int j = 0; j < distances_size; j++) { sum_distances += distances_ptr[j]; } @@ -1501,19 +1488,17 @@ void refine_beam_LUT_mp( } if (out_codes) { - // memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0])); memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr)); } if (out_distances) { - // memcpy(out_distances, - // distances.data(), - // distances.size() * sizeof(distances[0])); memcpy(out_distances, distances_ptr, distances_size * sizeof(*distances_ptr)); } } +} // namespace + void ResidualQuantizer::refine_beam_LUT( size_t n, const float* query_norms, // size n diff --git a/faiss/impl/ResidualQuantizer.h b/faiss/impl/ResidualQuantizer.h index 042f96d232..c3ea1ff298 100644 --- a/faiss/impl/ResidualQuantizer.h +++ b/faiss/impl/ResidualQuantizer.h @@ -144,9 +144,7 @@ struct ResidualQuantizer : AdditiveQuantizer { */ size_t memory_per_point(int beam_size = -1) const; - /** Cross products used in codebook tables - * - * These are used to keep trak of norms of centroids. + /** Cross products used in codebook tables used for beam_LUT = 1 */ void compute_codebook_tables(); @@ -194,6 +192,15 @@ void beam_search_encode_step( /** Encode a set of vectors using their dot products with the codebooks * + * @param K number of vectors in the codebook + * @param n nb of vectors to encode + * @param beam_size input beam size + * @param codebook_cross_norms inner product of this codebook with the m + * previously encoded codebooks + * @param codebook_offsets offsets into codebook_cross_norms for each + * previous codebook + * @param query_cp dot products of query vectors with ??? 
+ * @param cent_norms_i norms of centroids */ void beam_search_encode_step_tab( size_t K, diff --git a/faiss/impl/index_read.cpp b/faiss/impl/index_read.cpp index 423b22a9cc..4802e60ede 100644 --- a/faiss/impl/index_read.cpp +++ b/faiss/impl/index_read.cpp @@ -292,11 +292,17 @@ static void read_AdditiveQuantizer(AdditiveQuantizer* aq, IOReader* f) { aq->set_derived_values(); } -static void read_ResidualQuantizer(ResidualQuantizer* rq, IOReader* f) { +static void read_ResidualQuantizer( + ResidualQuantizer* rq, + IOReader* f, + int io_flags) { read_AdditiveQuantizer(rq, f); READ1(rq->train_type); READ1(rq->max_beam_size); - if (!(rq->train_type & ResidualQuantizer::Skip_codebook_tables)) { + if ((rq->train_type & ResidualQuantizer::Skip_codebook_tables) || + (io_flags & IO_FLAG_SKIP_PRECOMPUTE_TABLE)) { + // don't precompute the tables + } else { rq->compute_codebook_tables(); } } @@ -325,12 +331,13 @@ static void read_ProductAdditiveQuantizer( static void read_ProductResidualQuantizer( ProductResidualQuantizer* prq, - IOReader* f) { + IOReader* f, + int io_flags) { read_ProductAdditiveQuantizer(prq, f); for (size_t i = 0; i < prq->nsplits; i++) { auto rq = new ResidualQuantizer(); - read_ResidualQuantizer(rq, f); + read_ResidualQuantizer(rq, f, io_flags); prq->quantizers.push_back(rq); } } @@ -601,7 +608,7 @@ Index* read_index(IOReader* f, int io_flags) { if (h == fourcc("IxRQ")) { read_ResidualQuantizer_old(&idxr->rq, f); } else { - read_ResidualQuantizer(&idxr->rq, f); + read_ResidualQuantizer(&idxr->rq, f, io_flags); } READ1(idxr->code_size); READVECTOR(idxr->codes); @@ -616,7 +623,7 @@ Index* read_index(IOReader* f, int io_flags) { } else if (h == fourcc("IxPR")) { auto idxpr = new IndexProductResidualQuantizer(); read_index_header(idxpr, f); - read_ProductResidualQuantizer(&idxpr->prq, f); + read_ProductResidualQuantizer(&idxpr->prq, f, io_flags); READ1(idxpr->code_size); READVECTOR(idxpr->codes); idx = idxpr; @@ -630,8 +637,13 @@ Index* read_index(IOReader* 
f, int io_flags) { } else if (h == fourcc("ImRQ")) { ResidualCoarseQuantizer* idxr = new ResidualCoarseQuantizer(); read_index_header(idxr, f); - read_ResidualQuantizer(&idxr->rq, f); + read_ResidualQuantizer(&idxr->rq, f, io_flags); READ1(idxr->beam_factor); + if (io_flags & IO_FLAG_SKIP_PRECOMPUTE_TABLE) { + // then we force the beam factor to -1 + // which skips the table precomputation. + idxr->beam_factor = -1; + } idxr->set_beam_factor(idxr->beam_factor); idx = idxr; } else if ( @@ -656,13 +668,14 @@ Index* read_index(IOReader* f, int io_flags) { if (is_LSQ) { read_LocalSearchQuantizer((LocalSearchQuantizer*)idxaqfs->aq, f); } else if (is_RQ) { - read_ResidualQuantizer((ResidualQuantizer*)idxaqfs->aq, f); + read_ResidualQuantizer( + (ResidualQuantizer*)idxaqfs->aq, f, io_flags); } else if (is_PLSQ) { read_ProductLocalSearchQuantizer( (ProductLocalSearchQuantizer*)idxaqfs->aq, f); } else { read_ProductResidualQuantizer( - (ProductResidualQuantizer*)idxaqfs->aq, f); + (ProductResidualQuantizer*)idxaqfs->aq, f, io_flags); } READ1(idxaqfs->implem); @@ -704,13 +717,13 @@ Index* read_index(IOReader* f, int io_flags) { if (is_LSQ) { read_LocalSearchQuantizer((LocalSearchQuantizer*)ivaqfs->aq, f); } else if (is_RQ) { - read_ResidualQuantizer((ResidualQuantizer*)ivaqfs->aq, f); + read_ResidualQuantizer((ResidualQuantizer*)ivaqfs->aq, f, io_flags); } else if (is_PLSQ) { read_ProductLocalSearchQuantizer( (ProductLocalSearchQuantizer*)ivaqfs->aq, f); } else { read_ProductResidualQuantizer( - (ProductResidualQuantizer*)ivaqfs->aq, f); + (ProductResidualQuantizer*)ivaqfs->aq, f, io_flags); } READ1(ivaqfs->by_residual); @@ -832,13 +845,13 @@ Index* read_index(IOReader* f, int io_flags) { if (is_LSQ) { read_LocalSearchQuantizer((LocalSearchQuantizer*)iva->aq, f); } else if (is_RQ) { - read_ResidualQuantizer((ResidualQuantizer*)iva->aq, f); + read_ResidualQuantizer((ResidualQuantizer*)iva->aq, f, io_flags); } else if (is_PLSQ) { read_ProductLocalSearchQuantizer( 
(ProductLocalSearchQuantizer*)iva->aq, f); } else { read_ProductResidualQuantizer( - (ProductResidualQuantizer*)iva->aq, f); + (ProductResidualQuantizer*)iva->aq, f, io_flags); } READ1(iva->by_residual); READ1(iva->use_precomputed_table); diff --git a/faiss/python/__init__.py b/faiss/python/__init__.py index d650033096..427cb31625 100644 --- a/faiss/python/__init__.py +++ b/faiss/python/__init__.py @@ -298,10 +298,10 @@ def serialize_index(index): return vector_to_array(writer.data) -def deserialize_index(data): +def deserialize_index(data, io_flags=0): reader = VectorIOReader() copy_array_to_vector(data, reader.data) - return read_index(reader) + return read_index(reader, io_flags) def serialize_index_binary(index): diff --git a/tests/test_residual_quantizer.py b/tests/test_residual_quantizer.py index bb853c09d5..e2330fcb9e 100644 --- a/tests/test_residual_quantizer.py +++ b/tests/test_residual_quantizer.py @@ -642,15 +642,13 @@ def test_rcq_LUT(self): np.testing.assert_array_almost_equal(CDref, CDnew, decimal=5) np.testing.assert_array_equal(CIref, CInew) - def test_norms_oom(self): - "check if allocating too large norms tables raises an exception" - index = faiss.index_factory(32, "RQ20x8") - try: - index.train(np.zeros((100, 32), dtype="float32")) - except RuntimeError: - pass # ok - else: - self.assertFalse() + # check that you can load the index without computing the tables + quantizer.set_beam_factor(2.0) + self.assertNotEqual(quantizer.rq.codebook_cross_products.size(), 0) + quantizer3 = faiss.deserialize_index( + faiss.serialize_index(quantizer), faiss.IO_FLAG_SKIP_PRECOMPUTE_TABLE) + self.assertEqual(quantizer3.rq.codebook_cross_products.size(), 0) + CD3, CI3 = quantizer3.search(ds.get_queries(), 10) ###########################################################