Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions contrib/inspect_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ def get_flat_data(index):
return xb.reshape(index.ntotal, index.d)


def get_flat_codes(index_flat):
    """Return the codes stored in an IndexFlatCodes as a 2D array.

    The index's flat code buffer is converted with faiss.vector_to_array and
    reshaped to (index_flat.ntotal, index_flat.code_size), i.e. one row per
    stored entry and one column per code byte position.
    """
    flat_codes = faiss.vector_to_array(index_flat.codes)
    return flat_codes.reshape(index_flat.ntotal, index_flat.code_size)


def get_NSG_neighbors(nsg):
""" get the neighbor list for the vectors stored in the NSG structure, as
a N-by-K matrix of indices """
Expand Down
164 changes: 159 additions & 5 deletions faiss/IndexFlatCodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/impl/ResultHandler.h>
#include <faiss/utils/extra_distances.h>

namespace faiss {

Expand Down Expand Up @@ -70,11 +72,6 @@ void IndexFlatCodes::reconstruct(idx_t key, float* recons) const {
reconstruct_n(key, 1, recons);
}

// Stub implementation: it does not build a distance computer, it only
// reports "not implemented" through FAISS_THROW_MSG.
FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
const {
FAISS_THROW_MSG("not implemented");
}

void IndexFlatCodes::check_compatible_for_merge(const Index& otherIndex) const {
// minimal sanity checks
const IndexFlatCodes* other =
Expand Down Expand Up @@ -114,4 +111,161 @@ void IndexFlatCodes::permute_entries(const idx_t* perm) {
std::swap(codes, new_codes);
}

namespace {

/// Generic distance computer for an IndexFlatCodes: every distance is obtained
/// by decoding the stored code(s) to floats with codec.sa_decode() and then
/// applying the distance functor vd to the decoded data.
///
/// @tparam VD  distance functor; it is invoked as vd(x, y) on two float
///             pointers and exposes its dimension as vd.d.
template <class VD>
struct GenericFlatCodesDistanceComputer : FlatCodesDistanceComputer {
// index whose sa_decode() is used to decompress the codes
const IndexFlatCodes& codec;
// distance functor, stored by value
const VD vd;
// temp buffers
// code_buffer: sized for 4 codes (code_size * 4 bytes), filled by
// distances_batch_4
std::vector<uint8_t> code_buffer;
// vec_buffer: sized for 4 decoded vectors (d * 4 floats)
std::vector<float> vec_buffer;
// current query as given to set_query(); only the pointer is kept
const float* query = nullptr;

/// Hands the index's code array and code size to the base class and sizes
/// both scratch buffers for four entries.
GenericFlatCodesDistanceComputer(const IndexFlatCodes* codec, const VD& vd)
: FlatCodesDistanceComputer(codec->codes.data(), codec->code_size),
codec(*codec),
vd(vd),
code_buffer(codec->code_size * 4),
vec_buffer(codec->d * 4) {}

/// Store the query pointer used by the subsequent distance calls; the
/// query data itself is not copied.
void set_query(const float* x) override {
query = x;
}

/// Distance between the current query and stored entry i: the code at
/// codes + i * code_size is decoded into the start of vec_buffer, then vd
/// is evaluated on (query, decoded vector).
float operator()(idx_t i) override {
codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
return vd(query, vec_buffer.data());
}

/// Same as operator() but for a code supplied by the caller instead of one
/// looked up by position.
float distance_to_code(const uint8_t* code) override {
codec.sa_decode(1, code, vec_buffer.data());
return vd(query, vec_buffer.data());
}

/// Distance between stored entries i and j; the query is not used.
/// Entry i is decoded at offset 0 of vec_buffer and entry j at offset vd.d.
/// NOTE(review): the offset uses vd.d while the buffer is sized from
/// codec->d; presumably both are the same dimension -- confirm.
float symmetric_dis(idx_t i, idx_t j) override {
codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
codec.sa_decode(1, codes + j * code_size, vec_buffer.data() + vd.d);
return vd(vec_buffer.data(), vec_buffer.data() + vd.d);
}

/// Distances from the current query to the four stored entries idx0..idx3,
/// written to dis0..dis3 in the same order.
void distances_batch_4(
const idx_t idx0,
const idx_t idx1,
const idx_t idx2,
const idx_t idx3,
float& dis0,
float& dis1,
float& dis2,
float& dis3) override {
// gather the four codes back to back in code_buffer so that they can be
// decoded with a single sa_decode call
uint8_t* cp = code_buffer.data();
for (idx_t i : {idx0, idx1, idx2, idx3}) {
memcpy(cp, codes + i * code_size, code_size);
Copy link
Contributor

@alexanderguzhva alexanderguzhva Jul 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's assume that the code size is 1536 floats. This means that every distances_batch_4() function call costs extra 4 * 4 * 1536 bytes read and written. If I understand correctly, it should affect the performance noticeably.

cp += code_size;
}
// potential benefit is if batch decoding is more efficient than 1 by 1
// decoding
codec.sa_decode(4, code_buffer.data(), vec_buffer.data());
// decoded vector number k starts at offset k * vd.d in vec_buffer
dis0 = vd(query, vec_buffer.data());
dis1 = vd(query, vec_buffer.data() + vd.d);
dis2 = vd(query, vec_buffer.data() + 2 * vd.d);
dis3 = vd(query, vec_buffer.data() + 3 * vd.d);
}
};

struct Run_get_distance_computer {
using T = FlatCodesDistanceComputer*;

template <class VD>
FlatCodesDistanceComputer* f(const VD& vd, const IndexFlatCodes* codec) {
return new GenericFlatCodesDistanceComputer<VD>(codec, vd);
}
};

template <class BlockResultHandler>
struct Run_search_with_decompress {
using T = void;

template <class VectorDistance>
void f(VectorDistance& vd,
const IndexFlatCodes* index_ptr,
const float* xq,
BlockResultHandler& res) {
// Note that there seems to be a clang (?) bug that "sometimes" passes
// the const Index & parameters by value, so to be on the safe side,
// it's better to use pointers.
const IndexFlatCodes& index = *index_ptr;
size_t ntotal = index.ntotal;
using SingleResultHandler =
typename BlockResultHandler::SingleResultHandler;
using DC = GenericFlatCodesDistanceComputer<VectorDistance>;
#pragma omp parallel // if (res.nq > 100)
{
std::unique_ptr<DC> dc(new DC(&index, vd));
SingleResultHandler resi(res);
#pragma omp for
for (int64_t q = 0; q < res.nq; q++) {
resi.begin(q);
dc->set_query(xq + vd.d * q);
for (size_t i = 0; i < ntotal; i++) {
if (res.is_in_selection(i)) {
float dis = (*dc)(i);
resi.add_result(dis, i);
}
}
resi.end();
}
}
}
};

struct Run_search_with_decompress_res {
using T = void;

template <class ResultHandler>
void f(ResultHandler& res, const IndexFlatCodes* index, const float* xq) {
Run_search_with_decompress<ResultHandler> r;
dispatch_VectorDistance(
index->d,
index->metric_type,
index->metric_arg,
r,
index,
xq,
res);
}
};

} // anonymous namespace

/// Build a distance computer for this index by dispatching on d, metric_type
/// and metric_arg; the object itself is created by Run_get_distance_computer.
FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
        const {
    Run_get_distance_computer computer_factory;
    return dispatch_VectorDistance(
            d, metric_type, metric_arg, computer_factory, this);
}

/// k-nearest-neighbor search: dispatches on k, metric_type and the optional
/// ID selector, then runs Run_search_with_decompress_res on this index.
///
/// @param n          number of query vectors in x
/// @param x          query vectors
/// @param k          number of results requested per query
/// @param distances  output distances
/// @param labels     output labels
/// @param params     optional search parameters; only params->sel is read
void IndexFlatCodes::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params) const {
    // no selector unless search parameters were supplied
    const IDSelector* sel = nullptr;
    if (params) {
        sel = params->sel;
    }
    Run_search_with_decompress_res search_runner;
    dispatch_knn_ResultHandler(
            n, distances, labels, k, metric_type, sel, search_runner, this, x);
}

/// Range search: dispatches on radius, metric_type and the optional ID
/// selector, then runs Run_search_with_decompress_res on this index.
///
/// @param n       number of query vectors, forwarded to the search functor
/// @param x       query vectors
/// @param radius  search radius
/// @param result  output structure for the range search results
/// @param params  optional search parameters; only params->sel is read
void IndexFlatCodes::range_search(
        idx_t n,
        const float* x,
        float radius,
        RangeSearchResult* result,
        const SearchParameters* params) const {
    // no selector unless search parameters were supplied
    const IDSelector* sel = nullptr;
    if (params) {
        sel = params->sel;
    }
    Run_search_with_decompress_res search_runner;
    dispatch_range_ResultHandler(
            result, radius, metric_type, sel, search_runner, this, x);
}

} // namespace faiss
23 changes: 20 additions & 3 deletions faiss/IndexFlatCodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#pragma once

#include <faiss/Index.h>
Expand Down Expand Up @@ -45,13 +43,32 @@ struct IndexFlatCodes : Index {
* different from the usual ones: the new ids are shifted */
size_t remove_ids(const IDSelector& sel) override;

/** a FlatCodesDistanceComputer offers a distance_to_code method */
/** a FlatCodesDistanceComputer offers a distance_to_code method
*
* The default implementation explicitly decodes the vector with sa_decode.
*/
virtual FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const;

/// DistanceComputer accessor: simply forwards to
/// get_FlatCodesDistanceComputer() and returns its result.
DistanceComputer* get_distance_computer() const override {
return get_FlatCodesDistanceComputer();
}

/** Search implemented by decoding */
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
const SearchParameters* params = nullptr) const override;

void range_search(
idx_t n,
const float* x,
float radius,
RangeSearchResult* result,
const SearchParameters* params = nullptr) const override;

// returns a new instance of a CodePacker
CodePacker* get_CodePacker() const;

Expand Down
20 changes: 1 addition & 19 deletions faiss/IndexLattice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
namespace faiss {

IndexLattice::IndexLattice(idx_t d, int nsq, int scale_nbit, int r2)
: Index(d),
: IndexFlatCodes(0, d, METRIC_L2),
nsq(nsq),
dsq(d / nsq),
zn_sphere_codec(dsq, r2),
Expand Down Expand Up @@ -114,22 +114,4 @@ void IndexLattice::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
}
}

// Adding vectors is not supported here: both arguments are ignored and
// "not implemented" is reported through FAISS_THROW_MSG.
void IndexLattice::add(idx_t, const float*) {
FAISS_THROW_MSG("not implemented");
}

// Searching is not supported here: all arguments are ignored and
// "not implemented" is reported through FAISS_THROW_MSG.
void IndexLattice::search(
idx_t,
const float*,
idx_t,
float*,
idx_t*,
const SearchParameters*) const {
FAISS_THROW_MSG("not implemented");
}

// Resetting is not supported here: "not implemented" is reported through
// FAISS_THROW_MSG.
void IndexLattice::reset() {
FAISS_THROW_MSG("not implemented");
}

} // namespace faiss
25 changes: 3 additions & 22 deletions faiss/IndexLattice.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,18 @@
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#ifndef FAISS_INDEX_LATTICE_H
#define FAISS_INDEX_LATTICE_H
#pragma once

#include <vector>

#include <faiss/IndexIVF.h>
#include <faiss/IndexFlatCodes.h>
#include <faiss/impl/lattice_Zn.h>

namespace faiss {

/** Index that encodes a vector with a series of Zn lattice quantizers
*/
struct IndexLattice : Index {
struct IndexLattice : IndexFlatCodes {
/// number of sub-vectors
int nsq;
/// dimension of sub-vectors
Expand All @@ -30,8 +27,6 @@ struct IndexLattice : Index {

/// nb bits used to encode the scale, per subvector
int scale_nbit, lattice_nbit;
/// total, in bytes
size_t code_size;

/// mins and maxes of the vector norms, per subquantizer
std::vector<float> trained;
Expand All @@ -46,20 +41,6 @@ struct IndexLattice : Index {
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;

void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;

/// not implemented
void add(idx_t n, const float* x) override;
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
const SearchParameters* params = nullptr) const override;

void reset() override;
};

} // namespace faiss

#endif
Loading