Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions faiss/IndexAdditiveQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ void IndexAdditiveQuantizer::search(
DISPATCH(ST_norm_qint8)
DISPATCH(ST_norm_qint4)
DISPATCH(ST_norm_cqint4)
DISPATCH(ST_norm_from_LUT)
case AdditiveQuantizer::ST_norm_cqint8:
case AdditiveQuantizer::ST_norm_lsq2x4:
case AdditiveQuantizer::ST_norm_rq2x4:
Expand Down
2 changes: 1 addition & 1 deletion faiss/IndexIVFAdditiveQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ InvertedListScanner* IndexIVFAdditiveQuantizer::get_InvertedListScanner(
return new AQInvertedListScannerLUT<false, AdditiveQuantizer::st>( \
*this, store_pairs);
A(ST_LUT_nonorm)
// A(ST_norm_from_LUT)
A(ST_norm_from_LUT)
A(ST_norm_float)
A(ST_norm_qint8)
A(ST_norm_qint4)
Expand Down
85 changes: 77 additions & 8 deletions faiss/impl/AdditiveQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,40 @@ void AdditiveQuantizer::train_norm(size_t n, const float* norms) {
}
}

void AdditiveQuantizer::compute_codebook_tables() {
centroid_norms.resize(total_codebook_size);
fvec_norms_L2sqr(
centroid_norms.data(), codebooks.data(), d, total_codebook_size);
size_t cross_table_size = 0;
for (int m = 0; m < M; m++) {
size_t K = (size_t)1 << nbits[m];
cross_table_size += K * codebook_offsets[m];
}
codebook_cross_products.resize(cross_table_size);
size_t ofs = 0;
for (int m = 1; m < M; m++) {
FINTEGER ki = (size_t)1 << nbits[m];
FINTEGER kk = codebook_offsets[m];
FINTEGER di = d;
float zero = 0, one = 1;
assert(ofs + ki * kk <= cross_table_size);
sgemm_("Transposed",
"Not transposed",
&ki,
&kk,
&di,
&one,
codebooks.data() + d * kk,
&di,
codebooks.data(),
&di,
&zero,
codebook_cross_products.data() + ofs,
&ki);
ofs += ki * kk;
}
}

namespace {

// TODO
Expand Down Expand Up @@ -471,7 +505,6 @@ namespace {
float accumulate_IPs(
const AdditiveQuantizer& aq,
BitstringReader& bs,
const uint8_t* codes,
const float* LUT) {
float accu = 0;
for (int m = 0; m < aq.M; m++) {
Expand All @@ -483,6 +516,29 @@ float accumulate_IPs(
return accu;
}

float compute_norm_from_LUT(const AdditiveQuantizer& aq, BitstringReader& bs) {
float accu = 0;
std::vector<int> idx(aq.M);
const float* c = aq.codebook_cross_products.data();
for (int m = 0; m < aq.M; m++) {
size_t nbit = aq.nbits[m];
int i = bs.read(nbit);
size_t K = 1 << nbit;
idx[m] = i;

accu += aq.centroid_norms[aq.codebook_offsets[m] + i];

for (int l = 0; l < m; l++) {
int j = idx[l];
accu += 2 * c[j * K + i];
c += (1 << aq.nbits[l]) * K;
}
}
// FAISS_THROW_IF_NOT(c == aq.codebook_cross_products.data() +
// aq.codebook_cross_products.size());
return accu;
}

} // anonymous namespace

template <>
Expand All @@ -491,7 +547,7 @@ float AdditiveQuantizer::
const uint8_t* codes,
const float* LUT) const {
BitstringReader bs(codes, code_size);
return accumulate_IPs(*this, bs, codes, LUT);
return accumulate_IPs(*this, bs, LUT);
}

template <>
Expand All @@ -500,7 +556,7 @@ float AdditiveQuantizer::
const uint8_t* codes,
const float* LUT) const {
BitstringReader bs(codes, code_size);
return -accumulate_IPs(*this, bs, codes, LUT);
return -accumulate_IPs(*this, bs, LUT);
}

template <>
Expand All @@ -509,7 +565,7 @@ float AdditiveQuantizer::
const uint8_t* codes,
const float* LUT) const {
BitstringReader bs(codes, code_size);
float accu = accumulate_IPs(*this, bs, codes, LUT);
float accu = accumulate_IPs(*this, bs, LUT);
uint32_t norm_i = bs.read(32);
float norm2;
memcpy(&norm2, &norm_i, 4);
Expand All @@ -522,7 +578,7 @@ float AdditiveQuantizer::
const uint8_t* codes,
const float* LUT) const {
BitstringReader bs(codes, code_size);
float accu = accumulate_IPs(*this, bs, codes, LUT);
float accu = accumulate_IPs(*this, bs, LUT);
uint32_t norm_i = bs.read(8);
float norm2 = decode_qcint(norm_i);
return norm2 - 2 * accu;
Expand All @@ -534,7 +590,7 @@ float AdditiveQuantizer::
const uint8_t* codes,
const float* LUT) const {
BitstringReader bs(codes, code_size);
float accu = accumulate_IPs(*this, bs, codes, LUT);
float accu = accumulate_IPs(*this, bs, LUT);
uint32_t norm_i = bs.read(4);
float norm2 = decode_qcint(norm_i);
return norm2 - 2 * accu;
Expand All @@ -546,7 +602,7 @@ float AdditiveQuantizer::
const uint8_t* codes,
const float* LUT) const {
BitstringReader bs(codes, code_size);
float accu = accumulate_IPs(*this, bs, codes, LUT);
float accu = accumulate_IPs(*this, bs, LUT);
uint32_t norm_i = bs.read(8);
float norm2 = decode_qint8(norm_i, norm_min, norm_max);
return norm2 - 2 * accu;
Expand All @@ -558,10 +614,23 @@ float AdditiveQuantizer::
const uint8_t* codes,
const float* LUT) const {
BitstringReader bs(codes, code_size);
float accu = accumulate_IPs(*this, bs, codes, LUT);
float accu = accumulate_IPs(*this, bs, LUT);
uint32_t norm_i = bs.read(4);
float norm2 = decode_qint4(norm_i, norm_min, norm_max);
return norm2 - 2 * accu;
}

template <>
float AdditiveQuantizer::
compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_from_LUT>(
const uint8_t* codes,
const float* LUT) const {
FAISS_THROW_IF_NOT(codebook_cross_products.size() > 0);
BitstringReader bs(codes, code_size);
float accu = accumulate_IPs(*this, bs, LUT);
BitstringReader bs2(codes, code_size);
float norm2 = compute_norm_from_LUT(*this, bs2);
return norm2 - 2 * accu;
}

} // namespace faiss
18 changes: 15 additions & 3 deletions faiss/impl/AdditiveQuantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ struct AdditiveQuantizer : Quantizer {
std::vector<float> codebooks; ///< codebooks

// derived values
/// codebook #1 is stored in rows codebook_offsets[i]:codebook_offsets[i+1]
/// in the codebooks table of size total_codebook_size by d
std::vector<uint64_t> codebook_offsets;
size_t tot_bits = 0; ///< total number of bits (indexes + norms)
size_t norm_bits = 0; ///< bits allocated for the norms
Expand All @@ -38,9 +40,19 @@ struct AdditiveQuantizer : Quantizer {
bool verbose = false; ///< verbose during training?
bool is_trained = false; ///< is trained or not

IndexFlat1D qnorm; ///< store and search norms
std::vector<float> norm_tabs; ///< store norms of codebook entries for 4-bit
///< fastscan search
/// auxiliary data for ST_norm_lsq2x4 and ST_norm_rq2x4
/// store norms of codebook entries for 4-bit fastscan
std::vector<float> norm_tabs;
IndexFlat1D qnorm; ///< store and search norms

void compute_codebook_tables();

/// norms of all codebook entries (size total_codebook_size)
std::vector<float> centroid_norms;

/// dot products of all codebook entries with the previous codebooks
/// size sum(codebook_offsets[m] * 2^nbits[m], m=0..M-1)
std::vector<float> codebook_cross_products;

/// norms and distance matrixes with beam search can get large, so use this
/// to control for the amount of memory that can be allocated
Expand Down
34 changes: 0 additions & 34 deletions faiss/impl/ResidualQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,40 +492,6 @@ void ResidualQuantizer::refine_beam(
* Functions using the dot products between codebook entries
*******************************************************************/

void ResidualQuantizer::compute_codebook_tables() {
cent_norms.resize(total_codebook_size);
fvec_norms_L2sqr(
cent_norms.data(), codebooks.data(), d, total_codebook_size);
size_t cross_table_size = 0;
for (int m = 0; m < M; m++) {
size_t K = (size_t)1 << nbits[m];
cross_table_size += K * codebook_offsets[m];
}
codebook_cross_products.resize(cross_table_size);
size_t ofs = 0;
for (int m = 1; m < M; m++) {
FINTEGER ki = (size_t)1 << nbits[m];
FINTEGER kk = codebook_offsets[m];
FINTEGER di = d;
float zero = 0, one = 1;
assert(ofs + ki * kk <= cross_table_size);
sgemm_("Transposed",
"Not transposed",
&ki,
&kk,
&di,
&one,
codebooks.data() + d * kk,
&di,
codebooks.data(),
&di,
&zero,
codebook_cross_products.data() + ofs,
&ki);
ofs += ki * kk;
}
}

void ResidualQuantizer::refine_beam_LUT(
size_t n,
const float* query_norms, // size n
Expand Down
10 changes: 0 additions & 10 deletions faiss/impl/ResidualQuantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,6 @@ struct ResidualQuantizer : AdditiveQuantizer {
* @param beam_size if != -1, override the beam size
*/
size_t memory_per_point(int beam_size = -1) const;

/** Cross products used in codebook tables used for beam_LUT = 1
*/
void compute_codebook_tables();

/// dot products of all codebook entries with the previous codebooks
/// size sum(codebook_offsets[m] * 2^nbits[m], m=0..M-1)
std::vector<float> codebook_cross_products;
/// norms of all codebook entries (size total_codebook_size)
std::vector<float> cent_norms;
};

} // namespace faiss
2 changes: 1 addition & 1 deletion faiss/impl/residual_quantizer_encode_steps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,7 @@ void refine_beam_LUT_mp(
rq.codebook_offsets.data(),
query_cp + rq.codebook_offsets[m],
rq.total_codebook_size,
rq.cent_norms.data() + rq.codebook_offsets[m],
rq.centroid_norms.data() + rq.codebook_offsets[m],
m,
codes_ptr,
distances_ptr,
Expand Down
36 changes: 35 additions & 1 deletion tests/test_residual_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,37 @@ def test_search_decompress(self):
# recalls are {1: 0.05, 10: 0.37, 100: 0.37}
self.assertGreater(recalls[10], 0.35)

def do_exact_search_equiv(self, norm_type):
""" searching with this normalization should yield
exactly the same results as decompression (because the
norms are exact) """
ds = datasets.SyntheticDataset(32, 1000, 1000, 100)

# decompresses by default
ir = faiss.IndexResidualQuantizer(ds.d, 3, 6)
ir.rq.train_type = faiss.ResidualQuantizer.Train_default
ir.train(ds.get_train())
ir.add(ds.get_database())
Dref, Iref = ir.search(ds.get_queries(), 10)

ir2 = faiss.IndexResidualQuantizer(
ds.d, 3, 6, faiss.METRIC_L2, norm_type)

# assumes training is reproducible
ir2.rq.train_type = faiss.ResidualQuantizer.Train_default
ir2.train(ds.get_train())
ir2.add(ds.get_database())
D, I = ir2.search(ds.get_queries(), 10)

np.testing.assert_allclose(D, Dref, atol=1e-5)
np.testing.assert_array_equal(I, Iref)

def test_exact_equiv_norm_float(self):
self.do_exact_search_equiv(faiss.AdditiveQuantizer.ST_norm_float)

def test_exact_equiv_norm_from_LUT(self):
self.do_exact_search_equiv(faiss.AdditiveQuantizer.ST_norm_from_LUT)

def test_reestimate_codebook(self):
ds = datasets.SyntheticDataset(32, 1000, 1000, 100)

Expand Down Expand Up @@ -858,6 +889,9 @@ def test_norm_cqint(self):
self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_cqint8)
self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_cqint4)

def test_norm_from_LUT(self):
self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_from_LUT)

def test_factory(self):
index = faiss.index_factory(12, "IVF1024,RQ8x8_Nfloat")
self.assertEqual(index.nlist, 1024)
Expand Down Expand Up @@ -1105,7 +1139,7 @@ def test_precomp(self):
ofs += kk * K
np.testing.assert_allclose(py_table, cpp_table, atol=1e-5)

cent_norms = faiss.vector_to_array(rq.cent_norms)
cent_norms = faiss.vector_to_array(rq.centroid_norms)
np.testing.assert_array_almost_equal(
np.hstack(cent_norms_ref), cent_norms, decimal=5)

Expand Down