Skip to content
20 changes: 8 additions & 12 deletions faiss/IndexIVFFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1353,34 +1353,30 @@ void IndexIVFFastScan::reconstruct_from_offset(
int64_t offset,
float* recons) const {
// unpack codes
size_t coarse_size = coarse_code_size();
std::vector<uint8_t> code(coarse_size + code_size, 0);
encode_listno(list_no, code.data());
InvertedLists::ScopedCodes list_codes(invlists, list_no);
std::vector<uint8_t> code(code_size, 0);
BitstringWriter bsw(code.data(), code_size);
BitstringWriter bsw(code.data() + coarse_size, code_size);

for (size_t m = 0; m < M; m++) {
uint8_t c =
pq4_get_packed_element(list_codes.get(), bbs, M2, offset, m);
bsw.write(c, nbits);
}
sa_decode(1, code.data(), recons);

// add centroid to it
if (by_residual) {
std::vector<float> centroid(d);
quantizer->reconstruct(list_no, centroid.data());
for (int i = 0; i < d; ++i) {
recons[i] += centroid[i];
}
}
sa_decode(1, code.data(), recons);
}

void IndexIVFFastScan::reconstruct_orig_invlists() {
FAISS_THROW_IF_NOT(orig_invlists != nullptr);
FAISS_THROW_IF_NOT(orig_invlists->list_size(0) == 0);

#pragma omp parallel for if (nlist > 100)
for (size_t list_no = 0; list_no < nlist; list_no++) {
InvertedLists::ScopedCodes codes(invlists, list_no);
InvertedLists::ScopedIds ids(invlists, list_no);
size_t list_size = orig_invlists->list_size(list_no);
size_t list_size = invlists->list_size(list_no);
std::vector<uint8_t> code(code_size, 0);

for (size_t offset = 0; offset < list_size; offset++) {
Expand Down
1 change: 1 addition & 0 deletions faiss/IndexIVFPQFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
precomputed_table.nbytes());
}

#pragma omp parallel for if (nlist > 100)
for (size_t i = 0; i < nlist; i++) {
size_t nb = orig.invlists->list_size(i);
size_t nb2 = roundup(nb, bbs);
Expand Down
31 changes: 31 additions & 0 deletions tests/test_fast_scan_ivf.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,37 @@ def test_by_residual_odd_dim(self):
self.do_test(by_residual=True, d=30)


class TestReconstruct(unittest.TestCase):

def do_test(self, by_residual=False):
d = 32
metric = faiss.METRIC_L2

ds = datasets.SyntheticDataset(d, 2000, 5000, 200)

index = faiss.IndexIVFPQFastScan(faiss.IndexFlatL2(d), d, 50, d // 2, 4, metric)
index.by_residual = by_residual
index.make_direct_map(True)
index.train(ds.get_train())
index.add(ds.get_database())

# Test reconstruction
index.reconstruct(123) # single id
index.reconstruct_n(123, 10) # single id
index.reconstruct_batch(np.arange(10))

# Test original list reconstruction
index.orig_invlists = faiss.ArrayInvertedLists(index.nlist, index.code_size)
index.reconstruct_orig_invlists()
assert index.orig_invlists.compute_ntotal() == index.ntotal

def test_no_residual(self):
self.do_test(by_residual=False)

def test_by_residual(self):
self.do_test(by_residual=True)


class TestIsTrained(unittest.TestCase):

def test_issue_2019(self):
Expand Down
Loading