Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 13 additions & 20 deletions demos/index_pq_flat_separate_codes_from_codebook.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/usr/bin/env -S grimaldi --kernel bento_kernel_faiss
#!/usr/bin/env -S grimaldi --kernel faiss_binary_local
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
Expand Down Expand Up @@ -39,7 +39,6 @@ def read_ids_codes():


def write_ids_codes(ids, codes):
    """Save the ids and their codes as .npy files under /tmp.

    ids is written as-is to /tmp/ids.npy; codes is reshaped to one row
    per id (len(ids) rows) before being written to /tmp/codes.npy.
    """
    ids_path = "/tmp/ids.npy"
    codes_path = "/tmp/codes.npy"
    np.save(ids_path, ids)
    # One row of code bytes per id; the row width is inferred by reshape.
    codes_per_id = codes.reshape(len(ids), -1)
    np.save(codes_path, codes_per_id)

Expand All @@ -49,46 +48,40 @@ def write_template_index(template_index):


def read_template_index_instance():
pq_index = faiss.read_index("/tmp/template.index")
return pq_index, faiss.IndexIDMap2(pq_index)
return faiss.read_index("/tmp/template.index")

""":py"""
# at train time

template_index = faiss.IndexPQ(d, M, nbits)
template_index = faiss.index_factory(d, f"IDMap2,PQ{M}x{nbits}")
template_index.train(training_data)
write_template_index(template_index)

""":py"""
# New database vector

template_instance_index, id_wrapper_index = read_template_index_instance()
database_vector_id, database_vector_float32 = np.int64(
np.random.rand() * 10000
), np.random.rand(1, d).astype("float32")
index = read_template_index_instance()
database_vector_id, database_vector_float32 = np.random.randint(10000), np.random.rand(1, d).astype(np.float32)
ids, codes = read_ids_codes()
# print(ids, codes)
code = template_instance_index.sa_encode(database_vector_float32)

code = index.index.sa_encode(database_vector_float32)

if ids is not None and codes is not None:
ids = np.concatenate((ids, [database_vector_id]))
codes = np.vstack((codes, code))
else:
ids = np.array([database_vector_id])
codes = np.array([code])

write_ids_codes(ids, codes)

""":py '1545041403561975'"""
""":py '331546060044009'"""
# then at query time
query_vector_float32 = np.random.rand(1, d).astype("float32")
template_index_instance, id_wrapper_index = read_template_index_instance()
query_vector_float32 = np.random.rand(1, d).astype(np.float32)
id_wrapper_index = read_template_index_instance()
ids, codes = read_ids_codes()

for code in codes:
for c in code:
template_index_instance.codes.push_back(int(c))
template_index_instance.ntotal = len(codes)
for i in ids:
id_wrapper_index.id_map.push_back(int(i))
id_wrapper_index.add_sa_codes(codes, ids)

id_wrapper_index.search(query_vector_float32, k=5)

Expand Down
4 changes: 4 additions & 0 deletions faiss/Index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ void Index::sa_decode(idx_t, const uint8_t*, float*) const {
FAISS_THROW_MSG("standalone codec not implemented for this type of index");
}

/// Base-class fallback for add_sa_codes: none of the arguments are used and
/// the call is rejected through FAISS_THROW_MSG.
void Index::add_sa_codes(
        idx_t /* n */,
        const uint8_t* /* codes */,
        const idx_t* /* xids */) {
    FAISS_THROW_MSG("add_sa_codes not implemented for this type of index");
}

namespace {

// storage that explicitly reconstructs vectors before computing distances
Expand Down
7 changes: 7 additions & 0 deletions faiss/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,13 @@ struct Index {
* trained in the same way and have the same
* parameters). Otherwise throw. */
virtual void check_compatible_for_merge(const Index& otherIndex) const;

/** Add vectors that are computed with the standalone codec
*
* The implementation in Index.cpp rejects the call with the message
* "add_sa_codes not implemented for this type of index"; index types
* that accept pre-encoded codes override this method.
*
* @param n number of codes (and ids) to add
* @param codes codes to add size n * sa_code_size()
* @param xids corresponding ids, size n
*/
virtual void add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids);
};

} // namespace faiss
Expand Down
11 changes: 11 additions & 0 deletions faiss/IndexBinary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,15 @@ void IndexBinary::check_compatible_for_merge(
FAISS_THROW_MSG("check_compatible_for_merge() not implemented");
}

size_t IndexBinary::sa_code_size() const {
return code_size;
}

/// Add n codes with their ids. The arguments are handed unchanged to
/// add_with_ids(n, codes, xids); no other work is done here.
void IndexBinary::add_sa_codes(
        idx_t n,
        const uint8_t* codes,
        const idx_t* xids) {
    this->add_with_ids(n, codes, xids);
}

} // namespace faiss
6 changes: 6 additions & 0 deletions faiss/IndexBinary.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ struct IndexBinary {
* parameters). Otherwise throw. */
virtual void check_compatible_for_merge(
const IndexBinary& otherIndex) const;

/** Size of the produced codes in bytes.
*
* The IndexBinary implementation returns code_size.
*/
virtual size_t sa_code_size() const;

/** Same as add_with_ids for IndexBinary.
*
* The IndexBinary implementation forwards all three arguments unchanged
* to add_with_ids(n, codes, xids).
*
* @param n number of codes to add
* @param codes codes to add
* @param xids ids passed along with the codes
*/
virtual void add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids);
};

} // namespace faiss
Expand Down
9 changes: 9 additions & 0 deletions faiss/IndexFlatCodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ void IndexFlatCodes::add(idx_t n, const float* x) {
ntotal += n;
}

/// Append n already-encoded vectors to the flat code storage.
///
/// @param n         number of codes to append
/// @param codes_in  n * code_size bytes, copied verbatim after the codes
///                  that are already stored
/// @param xids      ignored: this index stores codes by position only
///
/// On return `codes` holds (ntotal + n) * code_size bytes and ntotal has
/// grown by n.
void IndexFlatCodes::add_sa_codes(
        idx_t n,
        const uint8_t* codes_in,
        const idx_t* /* xids */) {
    // Nothing to append. Returning here also avoids calling memcpy with a
    // possibly-null source or destination, which is undefined behavior even
    // for a zero byte count.
    if (n == 0) {
        return;
    }
    const size_t offset = ntotal * code_size; // bytes already stored
    const size_t nbytes = n * code_size;      // bytes being appended
    codes.resize(offset + nbytes);
    memcpy(codes.data() + offset, codes_in, nbytes);
    ntotal += n;
}

void IndexFlatCodes::reset() {
codes.clear();
ntotal = 0;
Expand Down
3 changes: 3 additions & 0 deletions faiss/IndexFlatCodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ struct IndexFlatCodes : Index {

virtual void merge_from(Index& otherIndex, idx_t add_id = 0) override;

// add_sa_codes. appends n codes (n * code_size bytes read from x) after the
// codes already stored and increases ntotal by n; xids is ignored
virtual void add_sa_codes(idx_t n, const uint8_t* x, const idx_t* xids)
override;

// permute_entries. perm of size ntotal maps new to old positions
void permute_entries(const idx_t* perm);
};
Expand Down
17 changes: 17 additions & 0 deletions faiss/IndexIDMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,23 @@ void IndexIDMapTemplate<IndexT>::add_with_ids(
this->ntotal = index->ntotal;
}

/// The id map adds nothing to the encoding, so the code size is whatever
/// the wrapped index reports.
template <typename IndexT>
size_t IndexIDMapTemplate<IndexT>::sa_code_size() const {
    return this->index->sa_code_size();
}

/// Add n pre-encoded vectors together with their ids.
///
/// The codes (and ids) are first handed to the wrapped index; then the n
/// entries of xids are appended to id_map in order, and ntotal is refreshed
/// from the wrapped index.
template <typename IndexT>
void IndexIDMapTemplate<IndexT>::add_sa_codes(
        idx_t n,
        const uint8_t* codes,
        const idx_t* xids) {
    index->add_sa_codes(n, codes, xids);

    // Record one id per added code, in the order given.
    const idx_t* next_id = xids;
    for (idx_t left = n; left > 0; --left) {
        id_map.push_back(*next_id++);
    }

    this->ntotal = index->ntotal;
}

namespace {

/// RAII object to reset the IDSelector in the params object
Expand Down
3 changes: 3 additions & 0 deletions faiss/IndexIDMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ struct IndexIDMapTemplate : IndexT {
void merge_from(IndexT& otherIndex, idx_t add_id = 0) override;
void check_compatible_for_merge(const IndexT& otherIndex) const override;

/// returns the sa_code_size() of the wrapped index
size_t sa_code_size() const override;
/// adds the n codes in x to the wrapped index via its add_sa_codes, then
/// appends the n ids in xids to id_map
void add_sa_codes(idx_t n, const uint8_t* x, const idx_t* xids) override;

~IndexIDMapTemplate() override;
IndexIDMapTemplate() {
own_fields = false;
Expand Down
3 changes: 2 additions & 1 deletion faiss/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ struct IndexIVF : Index, IndexIVFInterface {
* @param codes codes to add size n * sa_code_size()
* @param xids corresponding ids, size n
*/
void add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids);
void add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids)
override;

/** Train the encoder for the vectors.
*
Expand Down
3 changes: 1 addition & 2 deletions faiss/python/class_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,8 +812,7 @@ def replacement_permute_entries(self, perm):
replacement_range_search_preassigned, ignore_missing=True)
replace_method(the_class, 'sa_encode', replacement_sa_encode)
replace_method(the_class, 'sa_decode', replacement_sa_decode)
replace_method(the_class, 'add_sa_codes', replacement_add_sa_codes,
ignore_missing=True)
replace_method(the_class, 'add_sa_codes', replacement_add_sa_codes)
replace_method(the_class, 'permute_entries', replacement_permute_entries,
ignore_missing=True)

Expand Down
21 changes: 21 additions & 0 deletions tests/test_standalone_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,27 @@ def test_transfer(self):
np.testing.assert_array_equal(Dref, Dnew)


class TestIDMap(unittest.TestCase):
    """add_sa_codes on an IDMap2-wrapped PQ index must match add_with_ids."""

    def test_idmap(self):
        """Search results must be identical whether the database is added as
        float vectors with ids, or as codes from the sub-index's sa_encode
        with the same ids."""
        dataset = SyntheticDataset(32, 2000, 200, 100)
        xids = np.random.randint(10000, size=dataset.nb, dtype='int64')
        idmap_index = faiss.index_factory(dataset.d, "IDMap2,PQ8x2")

        # Reference run: regular add of float vectors with explicit ids.
        idmap_index.train(dataset.get_train())
        idmap_index.add_with_ids(dataset.get_database(), xids)
        D_expected, I_expected = idmap_index.search(dataset.get_queries(), 10)

        idmap_index.reset()

        # Second run: encode with the wrapped index, then add the raw codes.
        idmap_index.train(dataset.get_train())
        encoded = idmap_index.index.sa_encode(dataset.get_database())
        idmap_index.add_sa_codes(encoded, xids)
        D_actual, I_actual = idmap_index.search(dataset.get_queries(), 10)

        np.testing.assert_array_equal(I_expected, I_actual)
        np.testing.assert_array_equal(D_expected, D_actual)



class TestRefine(unittest.TestCase):

def test_refine(self):
Expand Down