Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ set(FAISS_HEADERS
utils/hamming.h
utils/ordered_key_value.h
utils/partitioning.h
utils/prefetch.h
utils/quantize_lut.h
utils/random.h
utils/simdlib.h
Expand Down
198 changes: 195 additions & 3 deletions faiss/IndexFlat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <faiss/utils/Heap.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/extra_distances.h>
#include <faiss/utils/prefetch.h>
#include <faiss/utils/sorting.h>
#include <faiss/utils/utils.h>
#include <cstring>
Expand Down Expand Up @@ -122,6 +123,39 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
void set_query(const float* x) override {
q = x;
}

// compute four distances
void distances_batch_4(
const idx_t idx0,
const idx_t idx1,
const idx_t idx2,
const idx_t idx3,
float& dis0,
float& dis1,
float& dis2,
float& dis3) final override {
ndis += 4;

// compute first, assign next
const float* __restrict y0 =
reinterpret_cast<const float*>(codes + idx0 * code_size);
const float* __restrict y1 =
reinterpret_cast<const float*>(codes + idx1 * code_size);
const float* __restrict y2 =
reinterpret_cast<const float*>(codes + idx2 * code_size);
const float* __restrict y3 =
reinterpret_cast<const float*>(codes + idx3 * code_size);

float dp0 = 0;
float dp1 = 0;
float dp2 = 0;
float dp3 = 0;
fvec_L2sqr_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
dis0 = dp0;
dis1 = dp1;
dis2 = dp2;
dis3 = dp3;
}
};

struct FlatIPDis : FlatCodesDistanceComputer {
Expand All @@ -131,13 +165,13 @@ struct FlatIPDis : FlatCodesDistanceComputer {
const float* b;
size_t ndis;

float symmetric_dis(idx_t i, idx_t j) override {
float symmetric_dis(idx_t i, idx_t j) final override {
return fvec_inner_product(b + j * d, b + i * d, d);
}

float distance_to_code(const uint8_t* code) final {
float distance_to_code(const uint8_t* code) final override {
ndis++;
return fvec_inner_product(q, (float*)code, d);
return fvec_inner_product(q, (const float*)code, d);
}

explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
Expand All @@ -153,6 +187,39 @@ struct FlatIPDis : FlatCodesDistanceComputer {
void set_query(const float* x) override {
q = x;
}

// Batched variant: evaluates the query against four stored vectors through
// fvec_inner_product_batch_4 and adds 4 to ndis.
void distances_batch_4(
        const idx_t idx0,
        const idx_t idx1,
        const idx_t idx2,
        const idx_t idx3,
        float& dis0,
        float& dis1,
        float& dis2,
        float& dis3) final override {
    ndis += 4;

    // each stored vector starts at codes + index * code_size
    const auto* base = codes;
    const auto stride = code_size;
    const float* __restrict vec_a =
            reinterpret_cast<const float*>(base + idx0 * stride);
    const float* __restrict vec_b =
            reinterpret_cast<const float*>(base + idx1 * stride);
    const float* __restrict vec_c =
            reinterpret_cast<const float*>(base + idx2 * stride);
    const float* __restrict vec_d =
            reinterpret_cast<const float*>(base + idx3 * stride);

    // results land in temporaries first, outputs are written afterwards
    float ip_a = 0, ip_b = 0, ip_c = 0, ip_d = 0;
    fvec_inner_product_batch_4(
            q, vec_a, vec_b, vec_c, vec_d, d, ip_a, ip_b, ip_c, ip_d);

    dis0 = ip_a;
    dis1 = ip_b;
    dis2 = ip_c;
    dis3 = ip_d;
}
};

} // namespace
Expand Down Expand Up @@ -184,6 +251,131 @@ void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
}
}

/***************************************************
* IndexFlatL2
***************************************************/

namespace {
/// Distance computer for IndexFlatL2 that relies on the index's cached
/// per-vector norms (storage.cached_l2norms). Index-based distances are
/// evaluated as
///     query_l2norm + l2norms[i] - 2 * <q, y_i>
/// so only a dot product has to be computed per stored vector.
/// distance_to_code() has no index to look a norm up with and therefore
/// falls back to fvec_L2sqr.
struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
    size_t d;
    idx_t nb;
    const float* q;
    const float* b;
    /// number of query-to-vector distances evaluated so far
    size_t ndis;

    /// points into storage.cached_l2norms; indexed by vector id
    const float* l2norms;
    /// fvec_norm_L2sqr of the current query (0 while no query is set)
    float query_l2norm;

    /// Distance from the query to an arbitrary code (no cached norm used).
    float distance_to_code(const uint8_t* code) final override {
        ndis++;
        return fvec_L2sqr(q, reinterpret_cast<const float*>(code), d);
    }

    /// Distance from the query to stored vector i, via the cached norm.
    float operator()(const idx_t i) final override {
        // counted like the other query-to-vector paths
        // (distance_to_code, distances_batch_4)
        ndis++;

        const float* __restrict y =
                reinterpret_cast<const float*>(codes + i * code_size);

        // request the norm entry before the dot product; it is read
        // right after
        prefetch_L2(l2norms + i);
        const float dp0 = fvec_inner_product(q, y, d);
        return query_l2norm + l2norms[i] - 2 * dp0;
    }

    /// Distance between stored vectors i and j, via both cached norms.
    float symmetric_dis(idx_t i, idx_t j) final override {
        const float* __restrict yi =
                reinterpret_cast<const float*>(codes + i * code_size);
        const float* __restrict yj =
                reinterpret_cast<const float*>(codes + j * code_size);

        prefetch_L2(l2norms + i);
        prefetch_L2(l2norms + j);
        const float dp0 = fvec_inner_product(yi, yj, d);
        return l2norms[i] + l2norms[j] - 2 * dp0;
    }

    /// @param storage  index providing codes and cached_l2norms
    /// @param q        optional initial query; when given, its norm is
    ///                 computed here so the computer is usable without a
    ///                 later set_query() call
    explicit FlatL2WithNormsDis(
            const IndexFlatL2& storage,
            const float* q = nullptr)
            : FlatCodesDistanceComputer(
                      storage.codes.data(),
                      storage.code_size),
              d(storage.d),
              nb(storage.ntotal),
              q(q),
              b(storage.get_xb()),
              ndis(0),
              l2norms(storage.cached_l2norms.data()),
              query_l2norm(
                      q != nullptr ? fvec_norm_L2sqr(q, storage.d) : 0) {}

    /// Set the query and refresh its cached norm.
    void set_query(const float* x) override {
        q = x;
        query_l2norm = fvec_norm_L2sqr(q, d);
    }

    /// Distances from the query to four stored vectors at once.
    void distances_batch_4(
            const idx_t idx0,
            const idx_t idx1,
            const idx_t idx2,
            const idx_t idx3,
            float& dis0,
            float& dis1,
            float& dis2,
            float& dis3) final override {
        ndis += 4;

        // compute first, assign next
        const float* __restrict y0 =
                reinterpret_cast<const float*>(codes + idx0 * code_size);
        const float* __restrict y1 =
                reinterpret_cast<const float*>(codes + idx1 * code_size);
        const float* __restrict y2 =
                reinterpret_cast<const float*>(codes + idx2 * code_size);
        const float* __restrict y3 =
                reinterpret_cast<const float*>(codes + idx3 * code_size);

        prefetch_L2(l2norms + idx0);
        prefetch_L2(l2norms + idx1);
        prefetch_L2(l2norms + idx2);
        prefetch_L2(l2norms + idx3);

        float dp0 = 0;
        float dp1 = 0;
        float dp2 = 0;
        float dp3 = 0;
        fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
        dis0 = query_l2norm + l2norms[idx0] - 2 * dp0;
        dis1 = query_l2norm + l2norms[idx1] - 2 * dp1;
        dis2 = query_l2norm + l2norms[idx2] - 2 * dp2;
        dis3 = query_l2norm + l2norms[idx3] - 2 * dp3;
    }
};

} // namespace

// Rebuild the norm cache: one entry per stored vector, filled by
// fvec_norms_L2sqr from the raw code buffer viewed as floats.
void IndexFlatL2::sync_l2norms() {
    cached_l2norms.resize(ntotal);

    float* norms_out = cached_l2norms.data();
    const float* vectors = reinterpret_cast<const float*>(codes.data());
    fvec_norms_L2sqr(norms_out, vectors, d, ntotal);
}

// Empty the norm cache and give its storage back by swapping it with a
// fresh, empty vector.
void IndexFlatL2::clear_l2norms() {
    std::vector<float>().swap(cached_l2norms);
}

// Returns a newly allocated distance computer for this index.
//
// The norm-based computer (FlatL2WithNormsDis) is only handed out when the
// metric is L2 and the cache holds exactly one entry per stored vector.
// That computer indexes the cache by vector id, so a cache whose size no
// longer equals ntotal would be read out of bounds or against the wrong
// vectors; in that case, and whenever the cache is empty, the generic
// IndexFlat computer is returned instead.
FlatCodesDistanceComputer* IndexFlatL2::get_FlatCodesDistanceComputer() const {
    const bool cache_in_sync = !cached_l2norms.empty() &&
            cached_l2norms.size() == static_cast<size_t>(ntotal);

    if (metric_type == METRIC_L2 && cache_in_sync) {
        return new FlatL2WithNormsDis(*this);
    }

    return IndexFlat::get_FlatCodesDistanceComputer();
}

/***************************************************
* IndexFlat1D
***************************************************/
Expand Down
14 changes: 14 additions & 0 deletions faiss/IndexFlat.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,22 @@ struct IndexFlatIP : IndexFlat {
};

struct IndexFlatL2 : IndexFlat {
    explicit IndexFlatL2(idx_t d) : IndexFlat(d, METRIC_L2) {}
    IndexFlatL2() {}

    /// Optional cache of per-vector L2 norms.
    /// If this cache is set, get_FlatCodesDistanceComputer() returns a
    /// special computer that derives distances from dot products and
    /// these norms.
    /// NOTE(review): nothing in this declaration refreshes the cache when
    /// the stored vectors change; presumably callers must call
    /// sync_l2norms() again -- confirm.
    std::vector<float> cached_l2norms;

    /// Override that takes the L2 norm cache into account.
    FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;

    /// Discard the cached L2 norms.
    void clear_l2norms();
    /// Compute the L2 norms and store them in the cache.
    void sync_l2norms();
};

/// optimized version for 1D "vectors".
Expand Down
5 changes: 4 additions & 1 deletion faiss/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,10 @@ IndexHNSWFlat::IndexHNSWFlat() {
}

IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
: IndexHNSW(new IndexFlat(d, metric), M) {
: IndexHNSW(
(metric == METRIC_L2) ? new IndexFlatL2(d)
: new IndexFlat(d, metric),
M) {
own_fields = true;
is_trained = true;
}
Expand Down
25 changes: 24 additions & 1 deletion faiss/impl/DistanceComputer.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,29 @@ struct DistanceComputer {
/// compute distance of vector i to current query
virtual float operator()(idx_t i) = 0;

/// Distances from the current query to four stored vectors.
/// This default simply evaluates operator() four times, in argument
/// order; implementations with a dedicated batched kernel may override
/// it and benefit heavily.
virtual void distances_batch_4(
        const idx_t idx0,
        const idx_t idx1,
        const idx_t idx2,
        const idx_t idx3,
        float& dis0,
        float& dis1,
        float& dis2,
        float& dis3) {
    // All four values are computed before any output is written.
    // Braced initialisation evaluates its elements left to right, so the
    // calls happen in the order idx0, idx1, idx2, idx3.
    const float results[4] = {
            (*this)(idx0), (*this)(idx1), (*this)(idx2), (*this)(idx3)};

    dis0 = results[0];
    dis1 = results[1];
    dis2 = results[2];
    dis3 = results[3];
}

/// compute distance between two stored vectors
virtual float symmetric_dis(idx_t i, idx_t j) = 0;

Expand All @@ -49,7 +72,7 @@ struct FlatCodesDistanceComputer : DistanceComputer {

FlatCodesDistanceComputer() : codes(nullptr), code_size(0) {}

float operator()(idx_t i) final {
float operator()(idx_t i) override {
return distance_to_code(codes + i * code_size);
}

Expand Down
Loading