diff --git a/demos/index_pq_flat_separate_codes_from_codebook.py b/demos/index_pq_flat_separate_codes_from_codebook.py index 71b29a4931..982c805262 100644 --- a/demos/index_pq_flat_separate_codes_from_codebook.py +++ b/demos/index_pq_flat_separate_codes_from_codebook.py @@ -1,4 +1,4 @@ -#!/usr/bin/env -S grimaldi --kernel bento_kernel_faiss +#!/usr/bin/env -S grimaldi --kernel faiss_binary_local # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the MIT license found in the @@ -39,7 +39,6 @@ def read_ids_codes(): def write_ids_codes(ids, codes): - # print(ids, codes) np.save("/tmp/ids.npy", ids) np.save("/tmp/codes.npy", codes.reshape(len(ids), -1)) @@ -49,46 +48,40 @@ def write_template_index(template_index): def read_template_index_instance(): - pq_index = faiss.read_index("/tmp/template.index") - return pq_index, faiss.IndexIDMap2(pq_index) + return faiss.read_index("/tmp/template.index") """:py""" # at train time -template_index = faiss.IndexPQ(d, M, nbits) +template_index = faiss.index_factory(d, f"IDMap2,PQ{M}x{nbits}") template_index.train(training_data) write_template_index(template_index) """:py""" # New database vector -template_instance_index, id_wrapper_index = read_template_index_instance() -database_vector_id, database_vector_float32 = np.int64( - np.random.rand() * 10000 -), np.random.rand(1, d).astype("float32") +index = read_template_index_instance() +database_vector_id, database_vector_float32 = np.random.randint(10000), np.random.rand(1, d).astype(np.float32) ids, codes = read_ids_codes() -# print(ids, codes) -code = template_instance_index.sa_encode(database_vector_float32) + +code = index.index.sa_encode(database_vector_float32) + if ids is not None and codes is not None: ids = np.concatenate((ids, [database_vector_id])) codes = np.vstack((codes, code)) else: ids = np.array([database_vector_id]) codes = np.array([code]) + write_ids_codes(ids, codes) -""":py '1545041403561975'""" +""":py '331546060044009'""" # then at query time -query_vector_float32 = np.random.rand(1, d).astype("float32") -template_index_instance, id_wrapper_index = read_template_index_instance() +query_vector_float32 = np.random.rand(1, d).astype(np.float32) +id_wrapper_index = read_template_index_instance() ids, codes = read_ids_codes() -for code in codes: - for c in code: - template_index_instance.codes.push_back(int(c)) -template_index_instance.ntotal = len(codes) -for i in ids: - id_wrapper_index.id_map.push_back(int(i)) +id_wrapper_index.add_sa_codes(codes, ids) id_wrapper_index.search(query_vector_float32, k=5) diff --git a/faiss/Index.cpp b/faiss/Index.cpp index 9f83dbe310..3530fcea15 100644 --- a/faiss/Index.cpp +++ b/faiss/Index.cpp @@ -134,6 +134,10 @@ void Index::sa_decode(idx_t, const uint8_t*, float*) const { FAISS_THROW_MSG("standalone codec not implemented for this type of index"); } +void Index::add_sa_codes(idx_t, const uint8_t*, const idx_t*) { + FAISS_THROW_MSG("add_sa_codes not implemented for this type of index"); +} + namespace { // storage that explicitly reconstructs vectors before computing distances diff --git a/faiss/Index.h b/faiss/Index.h index 2c3e62d718..e8bab0dd79 100644 --- a/faiss/Index.h +++ b/faiss/Index.h @@ -307,6 +307,13 @@ struct Index { * trained in the same way and have the same * parameters). Otherwise throw. */ virtual void check_compatible_for_merge(const Index& otherIndex) const; + + /** Add vectors that are computed with the standalone codec + * + * @param codes codes to add size n * sa_code_size() + * @param xids corresponding ids, size n + */ + virtual void add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids); }; } // namespace faiss diff --git a/faiss/IndexBinary.cpp b/faiss/IndexBinary.cpp index 8d3e335aca..3c8165b766 100644 --- a/faiss/IndexBinary.cpp +++ b/faiss/IndexBinary.cpp @@ -103,4 +103,15 @@ void IndexBinary::check_compatible_for_merge( FAISS_THROW_MSG("check_compatible_for_merge() not implemented"); } +size_t IndexBinary::sa_code_size() const { + return code_size; +} + +void IndexBinary::add_sa_codes( + idx_t n, + const uint8_t* codes, + const idx_t* xids) { + add_with_ids(n, codes, xids); +} + } // namespace faiss diff --git a/faiss/IndexBinary.h b/faiss/IndexBinary.h index 2546b04a34..e9801a7db4 100644 --- a/faiss/IndexBinary.h +++ b/faiss/IndexBinary.h @@ -171,6 +171,12 @@ struct IndexBinary { * parameters). Otherwise throw. */ virtual void check_compatible_for_merge( const IndexBinary& otherIndex) const; + + /** size of the produced codes in bytes */ + virtual size_t sa_code_size() const; + + /** Same as add_with_ids for IndexBinary. */ + virtual void add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids); }; } // namespace faiss diff --git a/faiss/IndexFlatCodes.cpp b/faiss/IndexFlatCodes.cpp index ec9a9a9b2d..54c86c6a9b 100644 --- a/faiss/IndexFlatCodes.cpp +++ b/faiss/IndexFlatCodes.cpp @@ -32,6 +32,15 @@ void IndexFlatCodes::add(idx_t n, const float* x) { ntotal += n; } +void IndexFlatCodes::add_sa_codes( + idx_t n, + const uint8_t* codes_in, + const idx_t* /* xids */) { + codes.resize((ntotal + n) * code_size); + memcpy(codes.data() + (ntotal * code_size), codes_in, n * code_size); + ntotal += n; +} + void IndexFlatCodes::reset() { codes.clear(); ntotal = 0; diff --git a/faiss/IndexFlatCodes.h b/faiss/IndexFlatCodes.h index 1dafffc6b3..809862f1e2 100644 --- a/faiss/IndexFlatCodes.h +++ b/faiss/IndexFlatCodes.h @@ -76,6 +76,9 @@ struct IndexFlatCodes : Index { virtual void merge_from(Index& otherIndex, idx_t add_id = 0) override; + virtual void add_sa_codes(idx_t n, const uint8_t* x, const idx_t* xids) + override; + // permute_entries. perm of size ntotal maps new to old positions void permute_entries(const idx_t* perm); }; diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp index 8c70fb2dec..8ad51d7588 100644 --- a/faiss/IndexIDMap.cpp +++ b/faiss/IndexIDMap.cpp @@ -83,6 +83,23 @@ void IndexIDMapTemplate::add_with_ids( this->ntotal = index->ntotal; } +template +size_t IndexIDMapTemplate::sa_code_size() const { + return index->sa_code_size(); +} + +template +void IndexIDMapTemplate::add_sa_codes( + idx_t n, + const uint8_t* codes, + const idx_t* xids) { + index->add_sa_codes(n, codes, xids); + for (idx_t i = 0; i < n; i++) { + id_map.push_back(xids[i]); + } + this->ntotal = index->ntotal; +} + namespace { /// RAII object to reset the IDSelector in the params object diff --git a/faiss/IndexIDMap.h b/faiss/IndexIDMap.h index 7b59d7b161..dd3887ae76 100644 --- a/faiss/IndexIDMap.h +++ b/faiss/IndexIDMap.h @@ -60,6 +60,9 @@ struct IndexIDMapTemplate : IndexT { void merge_from(IndexT& otherIndex, idx_t add_id = 0) override; void check_compatible_for_merge(const IndexT& otherIndex) const override; + size_t sa_code_size() const override; + void add_sa_codes(idx_t n, const uint8_t* x, const idx_t* xids) override; + ~IndexIDMapTemplate() override; IndexIDMapTemplate() { own_fields = false; diff --git a/faiss/IndexIVF.h b/faiss/IndexIVF.h index d8ed048075..9018ac9387 100644 --- a/faiss/IndexIVF.h +++ b/faiss/IndexIVF.h @@ -258,7 +258,8 @@ struct IndexIVF : Index, IndexIVFInterface { * @param codes codes to add size n * sa_code_size() * @param xids corresponding ids, size n */ - void add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids); + void add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids) + override; /** Train the encoder for the vectors. * diff --git a/faiss/python/class_wrappers.py b/faiss/python/class_wrappers.py index 98d995152a..607fdd6d29 100644 --- a/faiss/python/class_wrappers.py +++ b/faiss/python/class_wrappers.py @@ -812,8 +812,7 @@ def replacement_permute_entries(self, perm): replacement_range_search_preassigned, ignore_missing=True) replace_method(the_class, 'sa_encode', replacement_sa_encode) replace_method(the_class, 'sa_decode', replacement_sa_decode) - replace_method(the_class, 'add_sa_codes', replacement_add_sa_codes, - ignore_missing=True) + replace_method(the_class, 'add_sa_codes', replacement_add_sa_codes) replace_method(the_class, 'permute_entries', replacement_permute_entries, ignore_missing=True) diff --git a/tests/test_standalone_codec.py b/tests/test_standalone_codec.py index b6b3158af2..f52fb41cd9 100644 --- a/tests/test_standalone_codec.py +++ b/tests/test_standalone_codec.py @@ -360,6 +360,27 @@ def test_transfer(self): np.testing.assert_array_equal(Dref, Dnew) +class TestIDMap(unittest.TestCase): + def test_idmap(self): + ds = SyntheticDataset(32, 2000, 200, 100) + ids = np.random.randint(10000, size=ds.nb, dtype='int64') + index = faiss.index_factory(ds.d, "IDMap2,PQ8x2") + index.train(ds.get_train()) + index.add_with_ids(ds.get_database(), ids) + Dref, Iref = index.search(ds.get_queries(), 10) + + index.reset() + + index.train(ds.get_train()) + codes = index.index.sa_encode(ds.get_database()) + index.add_sa_codes(codes, ids) + Dnew, Inew = index.search(ds.get_queries(), 10) + + np.testing.assert_array_equal(Iref, Inew) + np.testing.assert_array_equal(Dref, Dnew) + + + class TestRefine(unittest.TestCase): def test_refine(self):