diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt index 1fea676ca9..5e635a53e8 100644 --- a/faiss/CMakeLists.txt +++ b/faiss/CMakeLists.txt @@ -190,6 +190,7 @@ set(FAISS_HEADERS utils/hamming.h utils/ordered_key_value.h utils/partitioning.h + utils/prefetch.h utils/quantize_lut.h utils/random.h utils/simdlib.h diff --git a/faiss/IndexFlat.cpp b/faiss/IndexFlat.cpp index 0fa3b82062..f606f8e621 100644 --- a/faiss/IndexFlat.cpp +++ b/faiss/IndexFlat.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -122,6 +123,39 @@ struct FlatL2Dis : FlatCodesDistanceComputer { void set_query(const float* x) override { q = x; } + + // compute four distances + void distances_batch_4( + const idx_t idx0, + const idx_t idx1, + const idx_t idx2, + const idx_t idx3, + float& dis0, + float& dis1, + float& dis2, + float& dis3) final override { + ndis += 4; + + // compute first, assign next + const float* __restrict y0 = + reinterpret_cast(codes + idx0 * code_size); + const float* __restrict y1 = + reinterpret_cast(codes + idx1 * code_size); + const float* __restrict y2 = + reinterpret_cast(codes + idx2 * code_size); + const float* __restrict y3 = + reinterpret_cast(codes + idx3 * code_size); + + float dp0 = 0; + float dp1 = 0; + float dp2 = 0; + float dp3 = 0; + fvec_L2sqr_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3); + dis0 = dp0; + dis1 = dp1; + dis2 = dp2; + dis3 = dp3; + } }; struct FlatIPDis : FlatCodesDistanceComputer { @@ -131,13 +165,13 @@ struct FlatIPDis : FlatCodesDistanceComputer { const float* b; size_t ndis; - float symmetric_dis(idx_t i, idx_t j) override { + float symmetric_dis(idx_t i, idx_t j) final override { return fvec_inner_product(b + j * d, b + i * d, d); } - float distance_to_code(const uint8_t* code) final { + float distance_to_code(const uint8_t* code) final override { ndis++; - return fvec_inner_product(q, (float*)code, d); + return fvec_inner_product(q, (const float*)code, d); } explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr) @@ -153,6 +187,39 @@ struct FlatIPDis : FlatCodesDistanceComputer { void set_query(const float* x) override { q = x; } + + // compute four distances + void distances_batch_4( + const idx_t idx0, + const idx_t idx1, + const idx_t idx2, + const idx_t idx3, + float& dis0, + float& dis1, + float& dis2, + float& dis3) final override { + ndis += 4; + + // compute first, assign next + const float* __restrict y0 = + reinterpret_cast(codes + idx0 * code_size); + const float* __restrict y1 = + reinterpret_cast(codes + idx1 * code_size); + const float* __restrict y2 = + reinterpret_cast(codes + idx2 * code_size); + const float* __restrict y3 = + reinterpret_cast(codes + idx3 * code_size); + + float dp0 = 0; + float dp1 = 0; + float dp2 = 0; + float dp3 = 0; + fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3); + dis0 = dp0; + dis1 = dp1; + dis2 = dp2; + dis3 = dp3; + } }; } // namespace @@ -184,6 +251,131 @@ void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const { } } +/*************************************************** + * IndexFlatL2 + ***************************************************/ + +namespace { +struct FlatL2WithNormsDis : FlatCodesDistanceComputer { + size_t d; + idx_t nb; + const float* q; + const float* b; + size_t ndis; + + const float* l2norms; + float query_l2norm; + + float distance_to_code(const uint8_t* code) final override { + ndis++; + return fvec_L2sqr(q, (float*)code, d); + } + + float operator()(const idx_t i) final override { + const float* __restrict y = + reinterpret_cast(codes + i * code_size); + + prefetch_L2(l2norms + i); + const float dp0 = fvec_inner_product(q, y, d); + return query_l2norm + l2norms[i] - 2 * dp0; + } + + float symmetric_dis(idx_t i, idx_t j) final override { + const float* __restrict yi = + reinterpret_cast(codes + i * code_size); + const float* __restrict yj = + reinterpret_cast(codes + j * code_size); + + prefetch_L2(l2norms + i); + prefetch_L2(l2norms + j); + const float dp0 = fvec_inner_product(yi, yj, d); + return l2norms[i] + l2norms[j] - 2 * dp0; + } + + explicit FlatL2WithNormsDis( + const IndexFlatL2& storage, + const float* q = nullptr) + : FlatCodesDistanceComputer( + storage.codes.data(), + storage.code_size), + d(storage.d), + nb(storage.ntotal), + q(q), + b(storage.get_xb()), + ndis(0), + l2norms(storage.cached_l2norms.data()), + query_l2norm(0) {} + + void set_query(const float* x) override { + q = x; + query_l2norm = fvec_norm_L2sqr(q, d); + } + + // compute four distances + void distances_batch_4( + const idx_t idx0, + const idx_t idx1, + const idx_t idx2, + const idx_t idx3, + float& dis0, + float& dis1, + float& dis2, + float& dis3) final override { + ndis += 4; + + // compute first, assign next + const float* __restrict y0 = + reinterpret_cast(codes + idx0 * code_size); + const float* __restrict y1 = + reinterpret_cast(codes + idx1 * code_size); + const float* __restrict y2 = + reinterpret_cast(codes + idx2 * code_size); + const float* __restrict y3 = + reinterpret_cast(codes + idx3 * code_size); + + prefetch_L2(l2norms + idx0); + prefetch_L2(l2norms + idx1); + prefetch_L2(l2norms + idx2); + prefetch_L2(l2norms + idx3); + + float dp0 = 0; + float dp1 = 0; + float dp2 = 0; + float dp3 = 0; + fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3); + dis0 = query_l2norm + l2norms[idx0] - 2 * dp0; + dis1 = query_l2norm + l2norms[idx1] - 2 * dp1; + dis2 = query_l2norm + l2norms[idx2] - 2 * dp2; + dis3 = query_l2norm + l2norms[idx3] - 2 * dp3; + } +}; + +} // namespace + +void IndexFlatL2::sync_l2norms() { + cached_l2norms.resize(ntotal); + fvec_norms_L2sqr( + cached_l2norms.data(), + reinterpret_cast(codes.data()), + d, + ntotal); +} + +void IndexFlatL2::clear_l2norms() { + cached_l2norms.clear(); + cached_l2norms.shrink_to_fit(); +} + +FlatCodesDistanceComputer* IndexFlatL2::get_FlatCodesDistanceComputer() const { + if (metric_type == METRIC_L2) { + if (!cached_l2norms.empty()) { + return new FlatL2WithNormsDis(*this); + } + } + + return IndexFlat::get_FlatCodesDistanceComputer(); +} + /*************************************************** * IndexFlat1D ***************************************************/ diff --git a/faiss/IndexFlat.h b/faiss/IndexFlat.h index ef9910edb6..c2f6eafed7 100644 --- a/faiss/IndexFlat.h +++ b/faiss/IndexFlat.h @@ -76,8 +76,22 @@ struct IndexFlatIP : IndexFlat { }; struct IndexFlatL2 : IndexFlat { + // Special cache for L2 norms. + // If this cache is set, then get_distance_computer() returns + // a special version that computes the distance using dot products + // and l2 norms. + std::vector cached_l2norms; + explicit IndexFlatL2(idx_t d) : IndexFlat(d, METRIC_L2) {} IndexFlatL2() {} + + // override for l2 norms cache. + FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override; + + // compute L2 norms + void sync_l2norms(); + // clear L2 norms + void clear_l2norms(); }; /// optimized version for 1D "vectors". diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 75b8ea9133..78787753e1 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -872,7 +872,10 @@ IndexHNSWFlat::IndexHNSWFlat() { } IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric) - : IndexHNSW(new IndexFlat(d, metric), M) { + : IndexHNSW( + (metric == METRIC_L2) ? new IndexFlatL2(d) + : new IndexFlat(d, metric), + M) { own_fields = true; is_trained = true; } diff --git a/faiss/impl/DistanceComputer.h b/faiss/impl/DistanceComputer.h index 9e7f248d03..dc46d113fb 100644 --- a/faiss/impl/DistanceComputer.h +++ b/faiss/impl/DistanceComputer.h @@ -30,6 +30,29 @@ struct DistanceComputer { /// compute distance of vector i to current query virtual float operator()(idx_t i) = 0; + /// compute distances of current query to 4 stored vectors. + /// certain DistanceComputer implementations may benefit + /// heavily from this. + virtual void distances_batch_4( + const idx_t idx0, + const idx_t idx1, + const idx_t idx2, + const idx_t idx3, + float& dis0, + float& dis1, + float& dis2, + float& dis3) { + // compute first, assign next + const float d0 = this->operator()(idx0); + const float d1 = this->operator()(idx1); + const float d2 = this->operator()(idx2); + const float d3 = this->operator()(idx3); + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; + } + /// compute distance between two stored vectors virtual float symmetric_dis(idx_t i, idx_t j) = 0; @@ -49,7 +72,7 @@ struct FlatCodesDistanceComputer : DistanceComputer { FlatCodesDistanceComputer() : codes(nullptr), code_size(0) {} - float operator()(idx_t i) final { + float operator()(idx_t i) override { return distance_to_code(codes + i * code_size); } diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index de70d05b48..cd7f9a0d91 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -14,6 +14,7 @@ #include #include #include +#include namespace faiss { @@ -563,24 +564,85 @@ int search_from_candidates( size_t begin, end; hnsw.neighbor_range(v0, level, &begin, &end); + // // baseline version + // for (size_t j = begin; j < end; j++) { + // int v1 = hnsw.neighbors[j]; + // if (v1 < 0) + // break; + // if (vt.get(v1)) { + // continue; + // } + // vt.set(v1); + // ndis++; + // float d = qdis(v1); + // if (!sel || sel->is_member(v1)) { + // if (nres < k) { + // faiss::maxheap_push(++nres, D, I, d, v1); + // } else if (d < D[0]) { + // faiss::maxheap_replace_top(nres, D, I, d, v1); + // } + // } + // candidates.push(v1, d); + // } + + // the following version processes 4 neighbors at a time + size_t jmax = begin; for (size_t j = begin; j < end; j++) { int v1 = hnsw.neighbors[j]; if (v1 < 0) break; - if (vt.get(v1)) { - continue; + + prefetch_L2(vt.visited.data() + v1); + jmax += 1; + } + + int counter = 0; + size_t saved_j[4]; + + ndis += jmax - begin; + + auto add_to_heap = [&](const size_t idx, const float dis) { + if (!sel || sel->is_member(idx)) { + if (nres < k) { + faiss::maxheap_push(++nres, D, I, dis, idx); + } else if (dis < D[0]) { + faiss::maxheap_replace_top(nres, D, I, dis, idx); + } } + candidates.push(idx, dis); + }; + + for (size_t j = begin; j < jmax; j++) { + int v1 = hnsw.neighbors[j]; + + bool vget = vt.get(v1); vt.set(v1); - ndis++; - float d = qdis(v1); - if (!sel || sel->is_member(v1)) { - if (nres < k) { - faiss::maxheap_push(++nres, D, I, d, v1); - } else if (d < D[0]) { - faiss::maxheap_replace_top(nres, D, I, d, v1); + saved_j[counter] = v1; + counter += vget ? 0 : 1; + + if (counter == 4) { + float dis[4]; + qdis.distances_batch_4( + saved_j[0], + saved_j[1], + saved_j[2], + saved_j[3], + dis[0], + dis[1], + dis[2], + dis[3]); + + for (size_t id4 = 0; id4 < 4; id4++) { + add_to_heap(saved_j[id4], dis[id4]); } + + counter = 0; } - candidates.push(v1, d); + } + + for (size_t icnt = 0; icnt < counter; icnt++) { + float dis = qdis(saved_j[icnt]); + add_to_heap(saved_j[icnt], dis); } nstep++; @@ -630,29 +692,92 @@ std::priority_queue search_from_candidate_unbounded( size_t begin, end; hnsw.neighbor_range(v0, 0, &begin, &end); - for (size_t j = begin; j < end; ++j) { + // // baseline version + // for (size_t j = begin; j < end; ++j) { + // int v1 = hnsw.neighbors[j]; + // + // if (v1 < 0) { + // break; + // } + // if (vt->get(v1)) { + // continue; + // } + // + // vt->set(v1); + // + // float d1 = qdis(v1); + // ++ndis; + // + // if (top_candidates.top().first > d1 || + // top_candidates.size() < ef) { + // candidates.emplace(d1, v1); + // top_candidates.emplace(d1, v1); + // + // if (top_candidates.size() > ef) { + // top_candidates.pop(); + // } + // } + // } + + // the following version processes 4 neighbors at a time + size_t jmax = begin; + for (size_t j = begin; j < end; j++) { int v1 = hnsw.neighbors[j]; - - if (v1 < 0) { + if (v1 < 0) break; - } - if (vt->get(v1)) { - continue; - } - vt->set(v1); + prefetch_L2(vt->visited.data() + v1); + jmax += 1; + } - float d1 = qdis(v1); - ++ndis; + int counter = 0; + size_t saved_j[4]; - if (top_candidates.top().first > d1 || top_candidates.size() < ef) { - candidates.emplace(d1, v1); - top_candidates.emplace(d1, v1); + ndis += jmax - begin; + + auto add_to_heap = [&](const size_t idx, const float dis) { + if (top_candidates.top().first > dis || + top_candidates.size() < ef) { + candidates.emplace(dis, idx); + top_candidates.emplace(dis, idx); if (top_candidates.size() > ef) { top_candidates.pop(); } } + }; + + for (size_t j = begin; j < jmax; j++) { + int v1 = hnsw.neighbors[j]; + + bool vget = vt->get(v1); + vt->set(v1); + saved_j[counter] = v1; + counter += vget ? 0 : 1; + + if (counter == 4) { + float dis[4]; + qdis.distances_batch_4( + saved_j[0], + saved_j[1], + saved_j[2], + saved_j[3], + dis[0], + dis[1], + dis[2], + dis[3]); + + for (size_t id4 = 0; id4 < 4; id4++) { + add_to_heap(saved_j[id4], dis[id4]); + } + + counter = 0; + } + } + + for (size_t icnt = 0; icnt < counter; icnt++) { + float dis = qdis(saved_j[icnt]); + add_to_heap(saved_j[icnt], dis); } } diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig index 04906f69fe..11b31e98e9 100644 --- a/faiss/python/swigfaiss.swig +++ b/faiss/python/swigfaiss.swig @@ -601,6 +601,8 @@ void gpu_sync_all_devices() DOWNCAST ( IndexIVFFlatDedup ) DOWNCAST ( IndexIVFFlat ) DOWNCAST ( IndexIVF ) + DOWNCAST ( IndexFlatIP ) + DOWNCAST ( IndexFlatL2 ) DOWNCAST ( IndexFlat ) DOWNCAST ( IndexRefineFlat ) DOWNCAST ( IndexRefine ) diff --git a/faiss/utils/distances.cpp b/faiss/utils/distances.cpp index 5e025f93d6..b50231b2b0 100644 --- a/faiss/utils/distances.cpp +++ b/faiss/utils/distances.cpp @@ -64,7 +64,7 @@ void fvec_norms_L2( const float* __restrict x, size_t d, size_t nx) { -#pragma omp parallel for +#pragma omp parallel for schedule(guided) for (int64_t i = 0; i < nx; i++) { nr[i] = sqrtf(fvec_norm_L2sqr(x + i * d, d)); } @@ -75,13 +75,13 @@ void fvec_norms_L2sqr( const float* __restrict x, size_t d, size_t nx) { -#pragma omp parallel for +#pragma omp parallel for schedule(guided) for (int64_t i = 0; i < nx; i++) nr[i] = fvec_norm_L2sqr(x + i * d, d); } void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) { -#pragma omp parallel for +#pragma omp parallel for schedule(guided) for (int64_t i = 0; i < nx; i++) { float* __restrict xi = x + i * d; diff --git a/faiss/utils/distances.h b/faiss/utils/distances.h index fa40c5b072..a5bb2ed96d 100644 --- a/faiss/utils/distances.h +++ b/faiss/utils/distances.h @@ -36,6 +36,34 @@ float fvec_L1(const float* x, const float* y, size_t d); /// infinity distance float fvec_Linf(const float* x, const float* y, size_t d); +/// Special version of inner product that computes 4 distances +/// between x and yi, which is performance oriented. +void fvec_inner_product_batch_4( + const float* x, + const float* y0, + const float* y1, + const float* y2, + const float* y3, + const size_t d, + float& dis0, + float& dis1, + float& dis2, + float& dis3); + +/// Special version of L2sqr that computes 4 distances +/// between x and yi, which is performance oriented. +void fvec_L2sqr_batch_4( + const float* x, + const float* y0, + const float* y1, + const float* y2, + const float* y3, + const size_t d, + float& dis0, + float& dis1, + float& dis2, + float& dis3); + /** Compute pairwise distances between sets of vectors * * @param d dimension of the vectors diff --git a/faiss/utils/distances_simd.cpp b/faiss/utils/distances_simd.cpp index 66bdbcf17a..0e7a970999 100644 --- a/faiss/utils/distances_simd.cpp +++ b/faiss/utils/distances_simd.cpp @@ -223,6 +223,76 @@ float fvec_L2sqr(const float* x, const float* y, size_t d) { } FAISS_PRAGMA_IMPRECISE_FUNCTION_END +/// Special version of inner product that computes 4 distances +/// between x and yi +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +void fvec_inner_product_batch_4( + const float* __restrict x, + const float* __restrict y0, + const float* __restrict y1, + const float* __restrict y2, + const float* __restrict y3, + const size_t d, + float& dis0, + float& dis1, + float& dis2, + float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; ++i) { + d0 += x[i] * y0[i]; + d1 += x[i] * y1[i]; + d2 += x[i] * y2[i]; + d3 += x[i] * y3[i]; + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +/// Special version of L2sqr that computes 4 distances +/// between x and yi, which is performance oriented. +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +void fvec_L2sqr_batch_4( + const float* x, + const float* y0, + const float* y1, + const float* y2, + const float* y3, + const size_t d, + float& dis0, + float& dis1, + float& dis2, + float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; ++i) { + const float q0 = x[i] - y0[i]; + const float q1 = x[i] - y1[i]; + const float q2 = x[i] - y2[i]; + const float q3 = x[i] - y3[i]; + d0 += q0 * q0; + d1 += q1 * q1; + d2 += q2 * q2; + d3 += q3 * q3; + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + /********************************************************* * SSE and AVX implementations */ diff --git a/faiss/utils/prefetch.h b/faiss/utils/prefetch.h new file mode 100644 index 0000000000..9549eb3441 --- /dev/null +++ b/faiss/utils/prefetch.h @@ -0,0 +1,77 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +// prefetches + +#ifdef __AVX__ + +// AVX + +#include + +inline void prefetch_L1(const void* address) { + _mm_prefetch((const char*)address, _MM_HINT_T0); +} +inline void prefetch_L2(const void* address) { + _mm_prefetch((const char*)address, _MM_HINT_T1); +} +inline void prefetch_L3(const void* address) { + _mm_prefetch((const char*)address, _MM_HINT_T2); +} + +#elif defined(__aarch64__) + +// ARM64 + +#ifdef _MSC_VER + +// todo: arm on MSVC +inline void prefetch_L1(const void* address) {} +inline void prefetch_L2(const void* address) {} +inline void prefetch_L3(const void* address) {} + +#else +// arm on non-MSVC + +inline void prefetch_L1(const void* address) { + __builtin_prefetch(address, 0, 3); +} +inline void prefetch_L2(const void* address) { + __builtin_prefetch(address, 0, 2); +} +inline void prefetch_L3(const void* address) { + __builtin_prefetch(address, 0, 1); +} +#endif + +#else + +// a generic platform + +#ifdef _MSC_VER + +inline void prefetch_L1(const void* address) {} +inline void prefetch_L2(const void* address) {} +inline void prefetch_L3(const void* address) {} + +#else + +inline void prefetch_L1(const void* address) { + __builtin_prefetch(address, 0, 3); +} +inline void prefetch_L2(const void* address) { + __builtin_prefetch(address, 0, 2); +} +inline void prefetch_L3(const void* address) { + __builtin_prefetch(address, 0, 1); +} + +#endif + +#endif diff --git a/tests/test_index.py b/tests/test_index.py index 45eba06b83..bc9392d7ea 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -110,7 +110,39 @@ def test_with_blas_reservoir_ip(self): self.do_test(200, faiss.METRIC_INNER_PRODUCT, k=150) +class TestIndexFlatL2(unittest.TestCase): + def test_indexflat_l2_sync_norms_1(self): + d = 32 + nb = 10000 + nt = 0 + nq = 16 + k = 10 + (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) + + # instantiate IndexHNSWFlat + index = faiss.IndexHNSWFlat(d, 32) + index.hnsw.efConstruction = 40 + + index.add(xb) + D1, I1 = index.search(xq, k) + + index_l2 = faiss.downcast_index(index.storage) + index_l2.sync_l2norms() + D2, I2 = index.search(xq, k) + + index_l2.clear_l2norms() + D3, I3 = index.search(xq, k) + + # not too many elements are off. + self.assertLessEqual((I2 != I1).sum(), 1) + # np.testing.assert_equal(Iref, I1) + np.testing.assert_almost_equal(D2, D1, decimal=5) + + # not too many elements are off. + self.assertLessEqual((I3 != I1).sum(), 0) + # np.testing.assert_equal(Iref, I1) + np.testing.assert_equal(D3, D1) class EvalIVFPQAccuracy(unittest.TestCase):