diff --git a/faiss/IndexFastScan.h b/faiss/IndexFastScan.h index b9787bf7cd..a0f5c592f0 100644 --- a/faiss/IndexFastScan.h +++ b/faiss/IndexFastScan.h @@ -133,6 +133,15 @@ struct IndexFastScan : Index { void merge_from(Index& otherIndex, idx_t add_id = 0) override; void check_compatible_for_merge(const Index& otherIndex) const override; + + /// standalone codes interface (but the codes are flattened) + size_t sa_code_size() const override { + return code_size; + } + + void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override { + compute_codes(bytes, n, x); + } }; struct FastScanStats { diff --git a/faiss/IndexIVF.h b/faiss/IndexIVF.h index 9018ac9387..a9fb2afb6f 100644 --- a/faiss/IndexIVF.h +++ b/faiss/IndexIVF.h @@ -436,7 +436,7 @@ struct IndexIVF : Index, IndexIVFInterface { size_t sa_code_size() const override; /** encode a set of vectors - * sa_encode will call encode_vector with include_listno=true + * sa_encode will call encode_vectors with include_listno=true * @param n nb of vectors to encode * @param x the vectors to encode * @param bytes output array for the codes diff --git a/faiss/IndexIVFAdditiveQuantizerFastScan.cpp b/faiss/IndexIVFAdditiveQuantizerFastScan.cpp index 91135a813f..93fad18636 100644 --- a/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +++ b/faiss/IndexIVFAdditiveQuantizerFastScan.cpp @@ -7,7 +7,6 @@ #include -#include #include #include @@ -67,7 +66,7 @@ void IndexIVFAdditiveQuantizerFastScan::init( } else { M = aq->M; } - init_fastscan(M, 4, nlist, metric, bbs); + init_fastscan(aq, M, 4, nlist, metric, bbs); max_train_points = 1024 * ksub * M; by_residual = true; @@ -440,13 +439,6 @@ void IndexIVFAdditiveQuantizerFastScan::compute_LUT( } } -void IndexIVFAdditiveQuantizerFastScan::sa_decode( - idx_t n, - const uint8_t* bytes, - float* x) const { - aq->decode(bytes, x, n); -} - /********** IndexIVFLocalSearchQuantizerFastScan ************/ IndexIVFLocalSearchQuantizerFastScan::IndexIVFLocalSearchQuantizerFastScan( Index* 
quantizer, diff --git a/faiss/IndexIVFAdditiveQuantizerFastScan.h b/faiss/IndexIVFAdditiveQuantizerFastScan.h index 77f535c923..75ec1d199a 100644 --- a/faiss/IndexIVFAdditiveQuantizerFastScan.h +++ b/faiss/IndexIVFAdditiveQuantizerFastScan.h @@ -96,8 +96,6 @@ struct IndexIVFAdditiveQuantizerFastScan : IndexIVFFastScan { const CoarseQuantized& cq, AlignedTable& dis_tables, AlignedTable& biases) const override; - - void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override; }; struct IndexIVFLocalSearchQuantizerFastScan diff --git a/faiss/IndexIVFFastScan.cpp b/faiss/IndexIVFFastScan.cpp index f95ad354a7..f031a51bba 100644 --- a/faiss/IndexIVFFastScan.cpp +++ b/faiss/IndexIVFFastScan.cpp @@ -55,6 +55,7 @@ IndexIVFFastScan::IndexIVFFastScan() { } void IndexIVFFastScan::init_fastscan( + Quantizer* fine_quantizer, size_t M, size_t nbits_init, size_t nlist, @@ -62,13 +63,16 @@ void IndexIVFFastScan::init_fastscan( int bbs_2) { FAISS_THROW_IF_NOT(bbs_2 % 32 == 0); FAISS_THROW_IF_NOT(nbits_init == 4); + FAISS_THROW_IF_NOT(fine_quantizer->d == d); + this->fine_quantizer = fine_quantizer; this->M = M; this->nbits = nbits_init; this->bbs = bbs_2; ksub = (1 << nbits_init); M2 = roundup(M, 2); code_size = M2 / 2; + FAISS_THROW_IF_NOT(code_size == fine_quantizer->code_size); is_trained = false; replace_invlists(new BlockInvertedLists(nlist, get_CodePacker()), true); @@ -1353,34 +1357,30 @@ void IndexIVFFastScan::reconstruct_from_offset( int64_t offset, float* recons) const { // unpack codes + size_t coarse_size = coarse_code_size(); + std::vector code(coarse_size + code_size, 0); + encode_listno(list_no, code.data()); InvertedLists::ScopedCodes list_codes(invlists, list_no); - std::vector code(code_size, 0); - BitstringWriter bsw(code.data(), code_size); + BitstringWriter bsw(code.data() + coarse_size, code_size); + for (size_t m = 0; m < M; m++) { uint8_t c = pq4_get_packed_element(list_codes.get(), bbs, M2, offset, m); bsw.write(c, nbits); } - sa_decode(1, 
code.data(), recons); - // add centroid to it - if (by_residual) { - std::vector centroid(d); - quantizer->reconstruct(list_no, centroid.data()); - for (int i = 0; i < d; ++i) { - recons[i] += centroid[i]; - } - } + sa_decode(1, code.data(), recons); } void IndexIVFFastScan::reconstruct_orig_invlists() { FAISS_THROW_IF_NOT(orig_invlists != nullptr); FAISS_THROW_IF_NOT(orig_invlists->list_size(0) == 0); - for (size_t list_no = 0; list_no < nlist; list_no++) { +#pragma omp parallel for if (nlist > 100) + for (idx_t list_no = 0; list_no < nlist; list_no++) { InvertedLists::ScopedCodes codes(invlists, list_no); InvertedLists::ScopedIds ids(invlists, list_no); - size_t list_size = orig_invlists->list_size(list_no); + size_t list_size = invlists->list_size(list_no); std::vector code(code_size, 0); for (size_t offset = 0; offset < list_size; offset++) { @@ -1400,6 +1400,30 @@ void IndexIVFFastScan::reconstruct_orig_invlists() { } } +void IndexIVFFastScan::sa_decode(idx_t n, const uint8_t* codes, float* x) + const { + size_t coarse_size = coarse_code_size(); + +#pragma omp parallel if (n > 1) + { + std::vector residual(d); + +#pragma omp for + for (idx_t i = 0; i < n; i++) { + const uint8_t* code = codes + i * (code_size + coarse_size); + int64_t list_no = decode_listno(code); + float* xi = x + i * d; + fine_quantizer->decode(code + coarse_size, xi, 1); + if (by_residual) { + quantizer->reconstruct(list_no, residual.data()); + for (size_t j = 0; j < d; j++) { + xi[j] += residual[j]; + } + } + } + } +} + IVFFastScanStats IVFFastScan_stats; } // namespace faiss diff --git a/faiss/IndexIVFFastScan.h b/faiss/IndexIVFFastScan.h index 054c803a18..48d6dafa1e 100644 --- a/faiss/IndexIVFFastScan.h +++ b/faiss/IndexIVFFastScan.h @@ -16,6 +16,7 @@ namespace faiss { struct NormTableScaler; struct SIMDResultHandlerToFloat; +struct Quantizer; /** Fast scan version of IVFPQ and IVFAQ. Works for 4-bit PQ/AQ for now. 
* @@ -59,6 +60,9 @@ struct IndexIVFFastScan : IndexIVF { int qbs = 0; size_t qbs2 = 0; + // quantizer used to pack the codes + Quantizer* fine_quantizer = nullptr; + IndexIVFFastScan( Index* quantizer, size_t d, @@ -68,7 +72,9 @@ struct IndexIVFFastScan : IndexIVF { IndexIVFFastScan(); + /// called by implementations void init_fastscan( + Quantizer* fine_quantizer, size_t M, size_t nbits, size_t nlist, @@ -225,6 +231,17 @@ struct IndexIVFFastScan : IndexIVF { // reconstruct orig invlists (for debugging) void reconstruct_orig_invlists(); + + /** Decode a set of vectors. + * + * NOTE: The codes in the IndexIVFFastScan object are non-contiguous. + * But this method requires a contiguous representation. + * + * @param n number of vectors + * @param bytes input encoded vectors, size n * (coarse_code_size() + code_size) + * @param x output vectors, size n * d + */ + void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override; }; struct IVFFastScanStats { diff --git a/faiss/IndexIVFPQFastScan.cpp b/faiss/IndexIVFPQFastScan.cpp index 9d1cdfcae3..95efaaaf89 100644 --- a/faiss/IndexIVFPQFastScan.cpp +++ b/faiss/IndexIVFPQFastScan.cpp @@ -42,7 +42,7 @@ IndexIVFPQFastScan::IndexIVFPQFastScan( : IndexIVFFastScan(quantizer, d, nlist, 0, metric), pq(d, M, nbits) { by_residual = false; // set to false by default because it's faster - init_fastscan(M, nbits, nlist, metric, bbs); + init_fastscan(&pq, M, nbits, nlist, metric, bbs); } IndexIVFPQFastScan::IndexIVFPQFastScan() { @@ -61,7 +61,8 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs) pq(orig.pq) { FAISS_THROW_IF_NOT(orig.pq.nbits == 4); - init_fastscan(orig.pq.M, orig.pq.nbits, orig.nlist, orig.metric_type, bbs); + init_fastscan( + &pq, orig.pq.M, orig.pq.nbits, orig.nlist, orig.metric_type, bbs); by_residual = orig.by_residual; ntotal = orig.ntotal; @@ -76,7 +77,8 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs) precomputed_table.nbytes()); } - for (size_t i = 0; i < nlist; i++) { +#pragma 
omp parallel for if (nlist > 100) + for (idx_t i = 0; i < nlist; i++) { size_t nb = orig.invlists->list_size(i); size_t nb2 = roundup(nb, bbs); AlignedTable tmp(nb2 * M2 / 2); @@ -282,28 +284,4 @@ void IndexIVFPQFastScan::compute_LUT( } } -void IndexIVFPQFastScan::sa_decode(idx_t n, const uint8_t* codes, float* x) - const { - size_t coarse_size = coarse_code_size(); - -#pragma omp parallel if (n > 1) - { - std::vector residual(d); - -#pragma omp for - for (idx_t i = 0; i < n; i++) { - const uint8_t* code = codes + i * (code_size + coarse_size); - int64_t list_no = decode_listno(code); - float* xi = x + i * d; - pq.decode(code + coarse_size, xi); - if (by_residual) { - quantizer->reconstruct(list_no, residual.data()); - for (size_t j = 0; j < d; j++) { - xi[j] += residual[j]; - } - } - } - } -} - } // namespace faiss diff --git a/faiss/IndexIVFPQFastScan.h b/faiss/IndexIVFPQFastScan.h index a2cce3266b..f2d722fea9 100644 --- a/faiss/IndexIVFPQFastScan.h +++ b/faiss/IndexIVFPQFastScan.h @@ -80,8 +80,6 @@ struct IndexIVFPQFastScan : IndexIVFFastScan { const CoarseQuantized& cq, AlignedTable& dis_tables, AlignedTable& biases) const override; - - void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override; }; } // namespace faiss diff --git a/faiss/IndexPQFastScan.h b/faiss/IndexPQFastScan.h index 495aad47e2..be16239331 100644 --- a/faiss/IndexPQFastScan.h +++ b/faiss/IndexPQFastScan.h @@ -47,15 +47,6 @@ struct IndexPQFastScan : IndexFastScan { void compute_float_LUT(float* lut, idx_t n, const float* x) const override; - /** Decode a set of vectors. - * - * NOTE: The codes in the IndexPQFastScan object are non-contiguous. - * But this method requires a contiguous representation. 
- * - * @param n number of vectors - * @param bytes input encoded vectors, size n * code_size - * @param x output vectors, size n * d - */ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override; }; diff --git a/tests/test_fast_scan_ivf.py b/tests/test_fast_scan_ivf.py index 55de784ad6..75c9500f82 100644 --- a/tests/test_fast_scan_ivf.py +++ b/tests/test_fast_scan_ivf.py @@ -543,6 +543,101 @@ def test_by_residual_odd_dim(self): self.do_test(by_residual=True, d=30) +class TestReconstruct(unittest.TestCase): + """ test reconstruct and sa_encode / sa_decode + (also for a few additive quantizer variants) """ + + def do_test(self, by_residual=False): + d = 32 + metric = faiss.METRIC_L2 + + ds = datasets.SyntheticDataset(d, 250, 200, 10) + + index = faiss.IndexIVFPQFastScan( + faiss.IndexFlatL2(d), d, 50, d // 2, 4, metric) + index.by_residual = by_residual + index.make_direct_map(True) + index.train(ds.get_train()) + index.add(ds.get_database()) + + # Test reconstruction + v123 = index.reconstruct(123) # single id + v120_10 = index.reconstruct_n(120, 10) + np.testing.assert_array_equal(v120_10[3], v123) + v120_10 = index.reconstruct_batch(np.arange(120, 130)) + np.testing.assert_array_equal(v120_10[3], v123) + + # Test original list reconstruction + index.orig_invlists = faiss.ArrayInvertedLists( + index.nlist, index.code_size) + index.reconstruct_orig_invlists() + assert index.orig_invlists.compute_ntotal() == index.ntotal + + # compare with non fast-scan index + index2 = faiss.IndexIVFPQ( + index.quantizer, d, 50, d // 2, 4, metric) + index2.by_residual = by_residual + index2.pq = index.pq + index2.is_trained = True + index2.replace_invlists(index.orig_invlists, False) + index2.ntotal = index.ntotal + index2.make_direct_map(True) + assert np.all(index.reconstruct(123) == index2.reconstruct(123)) + + def test_no_residual(self): + self.do_test(by_residual=False) + + def test_by_residual(self): + self.do_test(by_residual=True) + + def do_test_generic(self, 
factory_string, + by_residual=False, metric=faiss.METRIC_L2): + d = 32 + ds = datasets.SyntheticDataset(d, 250, 200, 10) + index = faiss.index_factory(ds.d, factory_string, metric) + if "IVF" in factory_string: + index.by_residual = by_residual + index.make_direct_map(True) + index.train(ds.get_train()) + index.add(ds.get_database()) + + # Test reconstruction + v123 = index.reconstruct(123) # single id + v120_10 = index.reconstruct_n(120, 10) + np.testing.assert_array_equal(v120_10[3], v123) + v120_10 = index.reconstruct_batch(np.arange(120, 130)) + np.testing.assert_array_equal(v120_10[3], v123) + codes = index.sa_encode(ds.get_database()[120:130]) + np.testing.assert_array_equal(index.sa_decode(codes), v120_10) + + # make sure pointers are correct after serialization + index2 = faiss.deserialize_index(faiss.serialize_index(index)) + codes2 = index2.sa_encode(ds.get_database()[120:130]) + np.testing.assert_array_equal(codes, codes2) + + + def test_ivfpq_residual(self): + self.do_test_generic("IVF20,PQ16x4fs", by_residual=True) + + def test_ivfpq_no_residual(self): + self.do_test_generic("IVF20,PQ16x4fs", by_residual=False) + + def test_pq(self): + self.do_test_generic("PQ16x4fs") + + def test_rq(self): + self.do_test_generic("RQ4x4fs", metric=faiss.METRIC_INNER_PRODUCT) + + def test_ivfprq(self): + self.do_test_generic("IVF20,PRQ8x2x4fs", by_residual=True, metric=faiss.METRIC_INNER_PRODUCT) + + def test_ivfprq_no_residual(self): + self.do_test_generic("IVF20,PRQ8x2x4fs", by_residual=False, metric=faiss.METRIC_INNER_PRODUCT) + + def test_prq(self): + self.do_test_generic("PRQ8x2x4fs", metric=faiss.METRIC_INNER_PRODUCT) + + class TestIsTrained(unittest.TestCase): def test_issue_2019(self):