Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions faiss/IndexFastScan.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,15 @@ struct IndexFastScan : Index {

void merge_from(Index& otherIndex, idx_t add_id = 0) override;
void check_compatible_for_merge(const Index& otherIndex) const override;

/// standalone codes interface (but the codes are flattened)
/// size in bytes of one flat (contiguous) code; the packed SIMD layout
/// used internally is not exposed through this interface
size_t sa_code_size() const override {
    return code_size;
}

/// encode n vectors x into flat codes (n * code_size bytes);
/// NB: compute_codes takes the output buffer first, then (n, x)
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override {
    compute_codes(bytes, n, x);
}
};

struct FastScanStats {
Expand Down
2 changes: 1 addition & 1 deletion faiss/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ struct IndexIVF : Index, IndexIVFInterface {
size_t sa_code_size() const override;

/** encode a set of vectors
* sa_encode will call encode_vector with include_listno=true
* sa_encode will call encode_vectors with include_listno=true
* @param n nb of vectors to encode
* @param x the vectors to encode
* @param bytes output array for the codes
Expand Down
10 changes: 1 addition & 9 deletions faiss/IndexIVFAdditiveQuantizerFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

#include <faiss/IndexIVFAdditiveQuantizerFastScan.h>

#include <cassert>
#include <cinttypes>
#include <cstdio>

Expand Down Expand Up @@ -67,7 +66,7 @@ void IndexIVFAdditiveQuantizerFastScan::init(
} else {
M = aq->M;
}
init_fastscan(M, 4, nlist, metric, bbs);
init_fastscan(aq, M, 4, nlist, metric, bbs);

max_train_points = 1024 * ksub * M;
by_residual = true;
Expand Down Expand Up @@ -440,13 +439,6 @@ void IndexIVFAdditiveQuantizerFastScan::compute_LUT(
}
}

void IndexIVFAdditiveQuantizerFastScan::sa_decode(
idx_t n,
const uint8_t* bytes,
float* x) const {
aq->decode(bytes, x, n);
}

/********** IndexIVFLocalSearchQuantizerFastScan ************/
IndexIVFLocalSearchQuantizerFastScan::IndexIVFLocalSearchQuantizerFastScan(
Index* quantizer,
Expand Down
2 changes: 0 additions & 2 deletions faiss/IndexIVFAdditiveQuantizerFastScan.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ struct IndexIVFAdditiveQuantizerFastScan : IndexIVFFastScan {
const CoarseQuantized& cq,
AlignedTable<float>& dis_tables,
AlignedTable<float>& biases) const override;

void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
};

struct IndexIVFLocalSearchQuantizerFastScan
Expand Down
50 changes: 37 additions & 13 deletions faiss/IndexIVFFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,24 @@ IndexIVFFastScan::IndexIVFFastScan() {
}

void IndexIVFFastScan::init_fastscan(
Quantizer* fine_quantizer,
size_t M,
size_t nbits_init,
size_t nlist,
MetricType /* metric */,
int bbs_2) {
FAISS_THROW_IF_NOT(bbs_2 % 32 == 0);
FAISS_THROW_IF_NOT(nbits_init == 4);
FAISS_THROW_IF_NOT(fine_quantizer->d == d);

this->fine_quantizer = fine_quantizer;
this->M = M;
this->nbits = nbits_init;
this->bbs = bbs_2;
ksub = (1 << nbits_init);
M2 = roundup(M, 2);
code_size = M2 / 2;
FAISS_THROW_IF_NOT(code_size == fine_quantizer->code_size);

is_trained = false;
replace_invlists(new BlockInvertedLists(nlist, get_CodePacker()), true);
Expand Down Expand Up @@ -1353,34 +1357,30 @@ void IndexIVFFastScan::reconstruct_from_offset(
int64_t offset,
float* recons) const {
// unpack codes
size_t coarse_size = coarse_code_size();
std::vector<uint8_t> code(coarse_size + code_size, 0);
encode_listno(list_no, code.data());
InvertedLists::ScopedCodes list_codes(invlists, list_no);
std::vector<uint8_t> code(code_size, 0);
BitstringWriter bsw(code.data(), code_size);
BitstringWriter bsw(code.data() + coarse_size, code_size);

for (size_t m = 0; m < M; m++) {
uint8_t c =
pq4_get_packed_element(list_codes.get(), bbs, M2, offset, m);
bsw.write(c, nbits);
}
sa_decode(1, code.data(), recons);

// add centroid to it
if (by_residual) {
std::vector<float> centroid(d);
quantizer->reconstruct(list_no, centroid.data());
for (int i = 0; i < d; ++i) {
recons[i] += centroid[i];
}
}
sa_decode(1, code.data(), recons);
}

void IndexIVFFastScan::reconstruct_orig_invlists() {
FAISS_THROW_IF_NOT(orig_invlists != nullptr);
FAISS_THROW_IF_NOT(orig_invlists->list_size(0) == 0);

for (size_t list_no = 0; list_no < nlist; list_no++) {
#pragma omp parallel for if (nlist > 100)
for (idx_t list_no = 0; list_no < nlist; list_no++) {
InvertedLists::ScopedCodes codes(invlists, list_no);
InvertedLists::ScopedIds ids(invlists, list_no);
size_t list_size = orig_invlists->list_size(list_no);
size_t list_size = invlists->list_size(list_no);
std::vector<uint8_t> code(code_size, 0);

for (size_t offset = 0; offset < list_size; offset++) {
Expand All @@ -1400,6 +1400,30 @@ void IndexIVFFastScan::reconstruct_orig_invlists() {
}
}

void IndexIVFFastScan::sa_decode(idx_t n, const uint8_t* codes, float* x)
        const {
    // Each standalone code is laid out as [coarse listno code | fine code],
    // so consecutive codes are spaced by listno_bytes + code_size.
    const size_t listno_bytes = coarse_code_size();
    const size_t stride = listno_bytes + code_size;

#pragma omp parallel if (n > 1)
    {
        // per-thread scratch buffer for the coarse centroid
        std::vector<float> centroid(d);

#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            const uint8_t* code_i = codes + i * stride;
            // decode the list number even when it is not needed below,
            // so malformed coarse codes are caught in both modes
            int64_t list_no = decode_listno(code_i);
            float* out = x + i * d;
            // decode the fine part directly into the output vector
            fine_quantizer->decode(code_i + listno_bytes, out, 1);
            if (by_residual) {
                // codes store residuals: add the coarse centroid back
                quantizer->reconstruct(list_no, centroid.data());
                for (size_t j = 0; j < d; j++) {
                    out[j] += centroid[j];
                }
            }
        }
    }
}

IVFFastScanStats IVFFastScan_stats;

} // namespace faiss
17 changes: 17 additions & 0 deletions faiss/IndexIVFFastScan.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace faiss {

struct NormTableScaler;
struct SIMDResultHandlerToFloat;
struct Quantizer;

/** Fast scan version of IVFPQ and IVFAQ. Works for 4-bit PQ/AQ for now.
*
Expand Down Expand Up @@ -59,6 +60,9 @@ struct IndexIVFFastScan : IndexIVF {
int qbs = 0;
size_t qbs2 = 0;

// quantizer used to pack the codes
Quantizer* fine_quantizer = nullptr;

IndexIVFFastScan(
Index* quantizer,
size_t d,
Expand All @@ -68,7 +72,9 @@ struct IndexIVFFastScan : IndexIVF {

IndexIVFFastScan();

/// called by implementations
void init_fastscan(
Quantizer* fine_quantizer,
size_t M,
size_t nbits,
size_t nlist,
Expand Down Expand Up @@ -225,6 +231,17 @@ struct IndexIVFFastScan : IndexIVF {

// reconstruct orig invlists (for debugging)
void reconstruct_orig_invlists();

/** Decode a set of vectors.
*
* NOTE: The codes in the IndexFastScan object are non-contiguous.
* But this method requires a contiguous representation.
*
* @param n number of vectors
* @param bytes input encoded vectors, size n * code_size
* @param x output vectors, size n * d
*/
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
};

struct IVFFastScanStats {
Expand Down
32 changes: 5 additions & 27 deletions faiss/IndexIVFPQFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(
: IndexIVFFastScan(quantizer, d, nlist, 0, metric), pq(d, M, nbits) {
by_residual = false; // set to false by default because it's faster

init_fastscan(M, nbits, nlist, metric, bbs);
init_fastscan(&pq, M, nbits, nlist, metric, bbs);
}

IndexIVFPQFastScan::IndexIVFPQFastScan() {
Expand All @@ -61,7 +61,8 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
pq(orig.pq) {
FAISS_THROW_IF_NOT(orig.pq.nbits == 4);

init_fastscan(orig.pq.M, orig.pq.nbits, orig.nlist, orig.metric_type, bbs);
init_fastscan(
&pq, orig.pq.M, orig.pq.nbits, orig.nlist, orig.metric_type, bbs);

by_residual = orig.by_residual;
ntotal = orig.ntotal;
Expand All @@ -76,7 +77,8 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
precomputed_table.nbytes());
}

for (size_t i = 0; i < nlist; i++) {
#pragma omp parallel for if (nlist > 100)
for (idx_t i = 0; i < nlist; i++) {
size_t nb = orig.invlists->list_size(i);
size_t nb2 = roundup(nb, bbs);
AlignedTable<uint8_t> tmp(nb2 * M2 / 2);
Expand Down Expand Up @@ -282,28 +284,4 @@ void IndexIVFPQFastScan::compute_LUT(
}
}

void IndexIVFPQFastScan::sa_decode(idx_t n, const uint8_t* codes, float* x)
const {
size_t coarse_size = coarse_code_size();

#pragma omp parallel if (n > 1)
{
std::vector<float> residual(d);

#pragma omp for
for (idx_t i = 0; i < n; i++) {
const uint8_t* code = codes + i * (code_size + coarse_size);
int64_t list_no = decode_listno(code);
float* xi = x + i * d;
pq.decode(code + coarse_size, xi);
if (by_residual) {
quantizer->reconstruct(list_no, residual.data());
for (size_t j = 0; j < d; j++) {
xi[j] += residual[j];
}
}
}
}
}

} // namespace faiss
2 changes: 0 additions & 2 deletions faiss/IndexIVFPQFastScan.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,6 @@ struct IndexIVFPQFastScan : IndexIVFFastScan {
const CoarseQuantized& cq,
AlignedTable<float>& dis_tables,
AlignedTable<float>& biases) const override;

void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
};

} // namespace faiss
9 changes: 0 additions & 9 deletions faiss/IndexPQFastScan.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,6 @@ struct IndexPQFastScan : IndexFastScan {

void compute_float_LUT(float* lut, idx_t n, const float* x) const override;

/** Decode a set of vectors.
*
* NOTE: The codes in the IndexPQFastScan object are non-contiguous.
* But this method requires a contiguous representation.
*
* @param n number of vectors
* @param bytes input encoded vectors, size n * code_size
* @param x output vectors, size n * d
*/
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
};

Expand Down
95 changes: 95 additions & 0 deletions tests/test_fast_scan_ivf.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,101 @@ def test_by_residual_odd_dim(self):
self.do_test(by_residual=True, d=30)


class TestReconstruct(unittest.TestCase):
    """ test reconstruct and sa_encode / sa_decode
    (also for a few additive quantizer variants) """

    def do_test(self, by_residual=False):
        # Build a small IVFPQFastScan index and check that the different
        # reconstruction entry points (reconstruct, reconstruct_n,
        # reconstruct_batch, reconstruct_orig_invlists) agree with each
        # other and with a regular IndexIVFPQ sharing the same codes.
        d = 32
        metric = faiss.METRIC_L2

        ds = datasets.SyntheticDataset(d, 250, 200, 10)

        index = faiss.IndexIVFPQFastScan(
            faiss.IndexFlatL2(d), d, 50, d // 2, 4, metric)
        index.by_residual = by_residual
        # direct map is needed to reconstruct vectors by id
        index.make_direct_map(True)
        index.train(ds.get_train())
        index.add(ds.get_database())

        # Test reconstruction
        v123 = index.reconstruct(123) # single id
        v120_10 = index.reconstruct_n(120, 10)
        np.testing.assert_array_equal(v120_10[3], v123)
        v120_10 = index.reconstruct_batch(np.arange(120, 130))
        np.testing.assert_array_equal(v120_10[3], v123)

        # Test original list reconstruction
        index.orig_invlists = faiss.ArrayInvertedLists(
            index.nlist, index.code_size)
        index.reconstruct_orig_invlists()
        assert index.orig_invlists.compute_ntotal() == index.ntotal

        # compare with non fast-scan index
        index2 = faiss.IndexIVFPQ(
            index.quantizer, d, 50, d // 2, 4, metric)
        index2.by_residual = by_residual
        index2.pq = index.pq
        index2.is_trained = True
        # share the unpacked invlists built above (False = no ownership)
        index2.replace_invlists(index.orig_invlists, False)
        index2.ntotal = index.ntotal
        index2.make_direct_map(True)
        assert np.all(index.reconstruct(123) == index2.reconstruct(123))

    def test_no_residual(self):
        self.do_test(by_residual=False)

    def test_by_residual(self):
        self.do_test(by_residual=True)

    def do_test_generic(self, factory_string,
                        by_residual=False, metric=faiss.METRIC_L2):
        # Same checks as do_test, but for an arbitrary factory string.
        # Also exercises the sa_encode / sa_decode round trip and makes
        # sure the decoded vectors survive (de)serialization.
        d = 32
        ds = datasets.SyntheticDataset(d, 250, 200, 10)
        index = faiss.index_factory(ds.d, factory_string, metric)
        if "IVF" in factory_string:
            index.by_residual = by_residual
            index.make_direct_map(True)
        index.train(ds.get_train())
        index.add(ds.get_database())

        # Test reconstruction
        v123 = index.reconstruct(123) # single id
        v120_10 = index.reconstruct_n(120, 10)
        np.testing.assert_array_equal(v120_10[3], v123)
        v120_10 = index.reconstruct_batch(np.arange(120, 130))
        np.testing.assert_array_equal(v120_10[3], v123)
        # sa_decode of the encoded database vectors must match the
        # reconstructions exactly
        codes = index.sa_encode(ds.get_database()[120:130])
        np.testing.assert_array_equal(index.sa_decode(codes), v120_10)

        # make sure pointers are correct after serialization
        index2 = faiss.deserialize_index(faiss.serialize_index(index))
        codes2 = index2.sa_encode(ds.get_database()[120:130])
        np.testing.assert_array_equal(codes, codes2)


    def test_ivfpq_residual(self):
        self.do_test_generic("IVF20,PQ16x4fs", by_residual=True)

    def test_ivfpq_no_residual(self):
        self.do_test_generic("IVF20,PQ16x4fs", by_residual=False)

    def test_pq(self):
        self.do_test_generic("PQ16x4fs")

    def test_rq(self):
        self.do_test_generic("RQ4x4fs", metric=faiss.METRIC_INNER_PRODUCT)

    def test_ivfprq(self):
        self.do_test_generic("IVF20,PRQ8x2x4fs", by_residual=True, metric=faiss.METRIC_INNER_PRODUCT)

    def test_ivfprq_no_residual(self):
        self.do_test_generic("IVF20,PRQ8x2x4fs", by_residual=False, metric=faiss.METRIC_INNER_PRODUCT)

    def test_prq(self):
        self.do_test_generic("PRQ8x2x4fs", metric=faiss.METRIC_INNER_PRODUCT)


class TestIsTrained(unittest.TestCase):

def test_issue_2019(self):
Expand Down
Loading