diff --git a/faiss/IndexAdditiveQuantizer.cpp b/faiss/IndexAdditiveQuantizer.cpp index 719dcafbc9..2e410a79e0 100644 --- a/faiss/IndexAdditiveQuantizer.cpp +++ b/faiss/IndexAdditiveQuantizer.cpp @@ -273,6 +273,7 @@ void IndexAdditiveQuantizer::search( DISPATCH(ST_norm_qint8) DISPATCH(ST_norm_qint4) DISPATCH(ST_norm_cqint4) + DISPATCH(ST_norm_from_LUT) case AdditiveQuantizer::ST_norm_cqint8: case AdditiveQuantizer::ST_norm_lsq2x4: case AdditiveQuantizer::ST_norm_rq2x4: diff --git a/faiss/IndexIVFAdditiveQuantizer.cpp b/faiss/IndexIVFAdditiveQuantizer.cpp index f0fde48b0e..19b23dab99 100644 --- a/faiss/IndexIVFAdditiveQuantizer.cpp +++ b/faiss/IndexIVFAdditiveQuantizer.cpp @@ -275,7 +275,7 @@ InvertedListScanner* IndexIVFAdditiveQuantizer::get_InvertedListScanner( return new AQInvertedListScannerLUT( \ *this, store_pairs); A(ST_LUT_nonorm) - // A(ST_norm_from_LUT) + A(ST_norm_from_LUT) A(ST_norm_float) A(ST_norm_qint8) A(ST_norm_qint4) diff --git a/faiss/impl/AdditiveQuantizer.cpp b/faiss/impl/AdditiveQuantizer.cpp index 42d37f32a9..b7bc5af69e 100644 --- a/faiss/impl/AdditiveQuantizer.cpp +++ b/faiss/impl/AdditiveQuantizer.cpp @@ -152,6 +152,40 @@ void AdditiveQuantizer::train_norm(size_t n, const float* norms) { } } +void AdditiveQuantizer::compute_codebook_tables() { + centroid_norms.resize(total_codebook_size); + fvec_norms_L2sqr( + centroid_norms.data(), codebooks.data(), d, total_codebook_size); + size_t cross_table_size = 0; + for (int m = 0; m < M; m++) { + size_t K = (size_t)1 << nbits[m]; + cross_table_size += K * codebook_offsets[m]; + } + codebook_cross_products.resize(cross_table_size); + size_t ofs = 0; + for (int m = 1; m < M; m++) { + FINTEGER ki = (size_t)1 << nbits[m]; + FINTEGER kk = codebook_offsets[m]; + FINTEGER di = d; + float zero = 0, one = 1; + assert(ofs + ki * kk <= cross_table_size); + sgemm_("Transposed", + "Not transposed", + &ki, + &kk, + &di, + &one, + codebooks.data() + d * kk, + &di, + codebooks.data(), + &di, + &zero, + codebook_cross_products.data() + ofs, + &ki); + ofs += ki * kk; + } +} + namespace { // TODO @@ -471,7 +505,6 @@ namespace { float accumulate_IPs( const AdditiveQuantizer& aq, BitstringReader& bs, - const uint8_t* codes, const float* LUT) { float accu = 0; for (int m = 0; m < aq.M; m++) { @@ -483,6 +516,29 @@ float accumulate_IPs( return accu; } +float compute_norm_from_LUT(const AdditiveQuantizer& aq, BitstringReader& bs) { + float accu = 0; + std::vector idx(aq.M); + const float* c = aq.codebook_cross_products.data(); + for (int m = 0; m < aq.M; m++) { + size_t nbit = aq.nbits[m]; + int i = bs.read(nbit); + size_t K = 1 << nbit; + idx[m] = i; + + accu += aq.centroid_norms[aq.codebook_offsets[m] + i]; + + for (int l = 0; l < m; l++) { + int j = idx[l]; + accu += 2 * c[j * K + i]; + c += (1 << aq.nbits[l]) * K; + } + } + // FAISS_THROW_IF_NOT(c == aq.codebook_cross_products.data() + + // aq.codebook_cross_products.size()); + return accu; +} + } // anonymous namespace template <> @@ -491,7 +547,7 @@ float AdditiveQuantizer:: const uint8_t* codes, const float* LUT) const { BitstringReader bs(codes, code_size); - return accumulate_IPs(*this, bs, codes, LUT); + return accumulate_IPs(*this, bs, LUT); } template <> @@ -500,7 +556,7 @@ float AdditiveQuantizer:: const uint8_t* codes, const float* LUT) const { BitstringReader bs(codes, code_size); - return -accumulate_IPs(*this, bs, codes, LUT); + return -accumulate_IPs(*this, bs, LUT); } template <> @@ -509,7 +565,7 @@ float AdditiveQuantizer:: const uint8_t* codes, const float* LUT) const { BitstringReader bs(codes, code_size); - float accu = accumulate_IPs(*this, bs, codes, LUT); + float accu = accumulate_IPs(*this, bs, LUT); uint32_t norm_i = bs.read(32); float norm2; memcpy(&norm2, &norm_i, 4); @@ -522,7 +578,7 @@ float AdditiveQuantizer:: const uint8_t* codes, const float* LUT) const { BitstringReader bs(codes, code_size); - float accu = accumulate_IPs(*this, bs, codes, LUT); + float accu = accumulate_IPs(*this, bs, LUT); uint32_t norm_i = bs.read(8); float norm2 = decode_qcint(norm_i); return norm2 - 2 * accu; @@ -534,7 +590,7 @@ float AdditiveQuantizer:: const uint8_t* codes, const float* LUT) const { BitstringReader bs(codes, code_size); - float accu = accumulate_IPs(*this, bs, codes, LUT); + float accu = accumulate_IPs(*this, bs, LUT); uint32_t norm_i = bs.read(4); float norm2 = decode_qcint(norm_i); return norm2 - 2 * accu; @@ -546,7 +602,7 @@ float AdditiveQuantizer:: const uint8_t* codes, const float* LUT) const { BitstringReader bs(codes, code_size); - float accu = accumulate_IPs(*this, bs, codes, LUT); + float accu = accumulate_IPs(*this, bs, LUT); uint32_t norm_i = bs.read(8); float norm2 = decode_qint8(norm_i, norm_min, norm_max); return norm2 - 2 * accu; @@ -558,10 +614,23 @@ float AdditiveQuantizer:: const uint8_t* codes, const float* LUT) const { BitstringReader bs(codes, code_size); - float accu = accumulate_IPs(*this, bs, codes, LUT); + float accu = accumulate_IPs(*this, bs, LUT); uint32_t norm_i = bs.read(4); float norm2 = decode_qint4(norm_i, norm_min, norm_max); return norm2 - 2 * accu; } +template <> +float AdditiveQuantizer:: + compute_1_distance_LUT( + const uint8_t* codes, + const float* LUT) const { + FAISS_THROW_IF_NOT(codebook_cross_products.size() > 0); + BitstringReader bs(codes, code_size); + float accu = accumulate_IPs(*this, bs, LUT); + BitstringReader bs2(codes, code_size); + float norm2 = compute_norm_from_LUT(*this, bs2); + return norm2 - 2 * accu; +} + } // namespace faiss diff --git a/faiss/impl/AdditiveQuantizer.h b/faiss/impl/AdditiveQuantizer.h index 8eceabe5d6..0a6cd241ce 100644 --- a/faiss/impl/AdditiveQuantizer.h +++ b/faiss/impl/AdditiveQuantizer.h @@ -29,6 +29,8 @@ struct AdditiveQuantizer : Quantizer { std::vector codebooks; ///< codebooks // derived values + /// codebook #1 is stored in rows codebook_offsets[i]:codebook_offsets[i+1] + /// in the codebooks table of size total_codebook_size by d std::vector codebook_offsets; size_t tot_bits = 0; ///< total number of bits (indexes + norms) size_t norm_bits = 0; ///< bits allocated for the norms @@ -38,9 +40,19 @@ struct AdditiveQuantizer : Quantizer { bool verbose = false; ///< verbose during training? bool is_trained = false; ///< is trained or not - IndexFlat1D qnorm; ///< store and search norms - std::vector norm_tabs; ///< store norms of codebook entries for 4-bit - ///< fastscan search + /// auxiliary data for ST_norm_lsq2x4 and ST_norm_rq2x4 + /// store norms of codebook entries for 4-bit fastscan + std::vector norm_tabs; + IndexFlat1D qnorm; ///< store and search norms + + void compute_codebook_tables(); + + /// norms of all codebook entries (size total_codebook_size) + std::vector centroid_norms; + + /// dot products of all codebook entries with the previous codebooks + /// size sum(codebook_offsets[m] * 2^nbits[m], m=0..M-1) + std::vector codebook_cross_products; /// norms and distance matrixes with beam search can get large, so use this /// to control for the amount of memory that can be allocated diff --git a/faiss/impl/ResidualQuantizer.cpp b/faiss/impl/ResidualQuantizer.cpp index 0892296dfd..f474ac64e3 100644 --- a/faiss/impl/ResidualQuantizer.cpp +++ b/faiss/impl/ResidualQuantizer.cpp @@ -492,40 +492,6 @@ void ResidualQuantizer::refine_beam( * Functions using the dot products between codebook entries *******************************************************************/ -void ResidualQuantizer::compute_codebook_tables() { - cent_norms.resize(total_codebook_size); - fvec_norms_L2sqr( - cent_norms.data(), codebooks.data(), d, total_codebook_size); - size_t cross_table_size = 0; - for (int m = 0; m < M; m++) { - size_t K = (size_t)1 << nbits[m]; - cross_table_size += K * codebook_offsets[m]; - } - codebook_cross_products.resize(cross_table_size); - size_t ofs = 0; - for (int m = 1; m < M; m++) { - FINTEGER ki = (size_t)1 << nbits[m]; - FINTEGER kk = codebook_offsets[m]; - FINTEGER di = d; - float zero = 0, one = 1; - assert(ofs + ki * kk <= cross_table_size); - sgemm_("Transposed", - "Not transposed", - &ki, - &kk, - &di, - &one, - codebooks.data() + d * kk, - &di, - codebooks.data(), - &di, - &zero, - codebook_cross_products.data() + ofs, - &ki); - ofs += ki * kk; - } -} - void ResidualQuantizer::refine_beam_LUT( size_t n, const float* query_norms, // size n diff --git a/faiss/impl/ResidualQuantizer.h b/faiss/impl/ResidualQuantizer.h index 004f7cabd4..5a1ea63e66 100644 --- a/faiss/impl/ResidualQuantizer.h +++ b/faiss/impl/ResidualQuantizer.h @@ -143,16 +143,6 @@ struct ResidualQuantizer : AdditiveQuantizer { * @param beam_size if != -1, override the beam size */ size_t memory_per_point(int beam_size = -1) const; - - /** Cross products used in codebook tables used for beam_LUT = 1 - */ - void compute_codebook_tables(); - - /// dot products of all codebook entries with the previous codebooks - /// size sum(codebook_offsets[m] * 2^nbits[m], m=0..M-1) - std::vector codebook_cross_products; - /// norms of all codebook entries (size total_codebook_size) - std::vector cent_norms; }; } // namespace faiss diff --git a/faiss/impl/residual_quantizer_encode_steps.cpp b/faiss/impl/residual_quantizer_encode_steps.cpp index c52988ab24..9fcdd9e1d2 100644 --- a/faiss/impl/residual_quantizer_encode_steps.cpp +++ b/faiss/impl/residual_quantizer_encode_steps.cpp @@ -809,7 +809,7 @@ void refine_beam_LUT_mp( rq.codebook_offsets.data(), query_cp + rq.codebook_offsets[m], rq.total_codebook_size, - rq.cent_norms.data() + rq.codebook_offsets[m], + rq.centroid_norms.data() + rq.codebook_offsets[m], m, codes_ptr, distances_ptr, diff --git a/tests/test_residual_quantizer.py b/tests/test_residual_quantizer.py index f4381607e1..6079ca75e1 100644 --- a/tests/test_residual_quantizer.py +++ b/tests/test_residual_quantizer.py @@ -457,6 +457,37 @@ def test_search_decompress(self): # recalls are {1: 0.05, 10: 0.37, 100: 0.37} self.assertGreater(recalls[10], 0.35) + def do_exact_search_equiv(self, norm_type): + """ searching with this normalization should yield + exactly the same results as decompression (because the + norms are exact) """ + ds = datasets.SyntheticDataset(32, 1000, 1000, 100) + + # decompresses by default + ir = faiss.IndexResidualQuantizer(ds.d, 3, 6) + ir.rq.train_type = faiss.ResidualQuantizer.Train_default + ir.train(ds.get_train()) + ir.add(ds.get_database()) + Dref, Iref = ir.search(ds.get_queries(), 10) + + ir2 = faiss.IndexResidualQuantizer( + ds.d, 3, 6, faiss.METRIC_L2, norm_type) + + # assumes training is reproducible + ir2.rq.train_type = faiss.ResidualQuantizer.Train_default + ir2.train(ds.get_train()) + ir2.add(ds.get_database()) + D, I = ir2.search(ds.get_queries(), 10) + + np.testing.assert_allclose(D, Dref, atol=1e-5) + np.testing.assert_array_equal(I, Iref) + + def test_exact_equiv_norm_float(self): + self.do_exact_search_equiv(faiss.AdditiveQuantizer.ST_norm_float) + + def test_exact_equiv_norm_from_LUT(self): + self.do_exact_search_equiv(faiss.AdditiveQuantizer.ST_norm_from_LUT) + def test_reestimate_codebook(self): ds = datasets.SyntheticDataset(32, 1000, 1000, 100) @@ -858,6 +889,9 @@ def test_norm_cqint(self): self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_cqint8) self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_cqint4) + def test_norm_from_LUT(self): + self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_from_LUT) + def test_factory(self): index = faiss.index_factory(12, "IVF1024,RQ8x8_Nfloat") self.assertEqual(index.nlist, 1024) @@ -1105,7 +1139,7 @@ def test_precomp(self): ofs += kk * K np.testing.assert_allclose(py_table, cpp_table, atol=1e-5) - cent_norms = faiss.vector_to_array(rq.cent_norms) + cent_norms = faiss.vector_to_array(rq.centroid_norms) np.testing.assert_array_almost_equal( np.hstack(cent_norms_ref), cent_norms, decimal=5)