diff --git a/demos/demo_simd_levels.py b/demos/demo_simd_levels.py new file mode 100644 index 0000000000..842188c6d3 --- /dev/null +++ b/demos/demo_simd_levels.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import time +import faiss +import numpy as np +import os +from collections import defaultdict +from faiss.contrib.datasets import SyntheticDataset + + +print("compile options", faiss.get_compile_options()) +print("SIMD level: ", faiss.SIMDConfig.get_level_name()) + + +ds = SyntheticDataset(32, 8000, 10000, 8000) + + +index = faiss.index_factory(ds.d, "PQ16x4fs") +# index = faiss.index_factory(ds.d, "IVF64,PQ16x4fs") +# index = faiss.index_factory(ds.d, "SQ8") + +index.train(ds.get_train()) +index.add(ds.get_database()) + + +if False: + faiss.omp_set_num_threads(1) + print("PID=", os.getpid()) + input("press enter to continue") + # for simd_level in faiss.NONE, faiss.AVX2, faiss.AVX512F: + for simd_level in faiss.AVX2, faiss.AVX512F: + + faiss.SIMDConfig.set_level(simd_level) + print("simd_level=", faiss.SIMDConfig.get_level_name()) + for run in range(1000): + D, I = index.search(ds.get_queries(), 10) + +times = defaultdict(list) + +for run in range(10): + for simd_level in faiss.SIMDLevel_NONE, faiss.SIMDLevel_AVX2, faiss.SIMDLevel_AVX512F: + faiss.SIMDConfig.set_level(simd_level) + + t0 = time.time() + D, I = index.search(ds.get_queries(), 10) + t1 = time.time() + + times[faiss.SIMDConfig.get_level_name()].append(t1 - t0) + +for simd_level in times: + print(f"simd_level={simd_level} search time: {np.mean(times[simd_level])*1000:.3f} ms") diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt index 91294683e4..25b0076f2c 100644 --- a/faiss/CMakeLists.txt +++ b/faiss/CMakeLists.txt @@ -185,11 +185,11 @@ set(FAISS_HEADERS impl/pq4_fast_scan.h impl/residual_quantizer_encode_steps.h impl/simd_result_handlers.h - impl/code_distance/code_distance.h - impl/code_distance/code_distance-generic.h - impl/code_distance/code_distance-avx2.h - impl/code_distance/code_distance-avx512.h - impl/code_distance/code_distance-sve.h + impl/pq_code_distance/code_distance.h + impl/pq_code_distance/code_distance-generic.h + impl/pq_code_distance/code_distance-avx2.h + impl/pq_code_distance/code_distance-avx512.h + impl/pq_code_distance/code_distance-sve.h invlists/BlockInvertedLists.h invlists/DirectMap.h invlists/InvertedLists.h diff --git a/faiss/IndexAdditiveQuantizerFastScan.cpp b/faiss/IndexAdditiveQuantizerFastScan.cpp index 58d72a38b1..b120288f54 100644 --- a/faiss/IndexAdditiveQuantizerFastScan.cpp +++ b/faiss/IndexAdditiveQuantizerFastScan.cpp @@ -7,14 +7,10 @@ #include -#include -#include - #include #include -#include #include -#include +#include #include #include @@ -199,11 +195,10 @@ void IndexAdditiveQuantizerFastScan::search( return; } - NormTableScaler scaler(norm_scale); if (metric_type == METRIC_L2) { - search_dispatch_implem(n, x, k, distances, labels, &scaler); + search_dispatch_implem(n, x, k, distances, labels, norm_scale); } else { - search_dispatch_implem(n, x, k, distances, labels, &scaler); + search_dispatch_implem(n, x, k, distances, labels, norm_scale); } } diff --git a/faiss/IndexFastScan.cpp b/faiss/IndexFastScan.cpp index b18d15bc17..68e958d08d 100644 --- a/faiss/IndexFastScan.cpp +++ b/faiss/IndexFastScan.cpp @@ -7,20 +7,16 @@ #include -#include -#include -#include - #include #include #include -#include #include +#include #include -#include -#include +#include +#include #include namespace faiss { @@ -163,7 +159,7 @@ void estimators_from_tables_generic( size_t k, typename C::T* heap_dis, int64_t* heap_ids, - const NormTableScaler* scaler) { + const Scaler2x4bit* scaler) { using accu_t = typename C::T; for (size_t j = 0; j < ncodes; ++j) { @@ -193,28 +189,6 @@ void estimators_from_tables_generic( } } -template -ResultHandlerCompare* make_knn_handler( - int impl, - idx_t n, - idx_t k, - size_t ntotal, - float* distances, - idx_t* labels, - const IDSelector* sel = nullptr) { - using HeapHC = HeapHandler; - using ReservoirHC = ReservoirHandler; - using SingleResultHC = SingleResultHandler; - - if (k == 1) { - return new SingleResultHC(n, ntotal, distances, labels, sel); - } else if (impl % 2 == 0) { - return new HeapHC(n, ntotal, k, distances, labels, sel); - } else /* if (impl % 2 == 1) */ { - return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel); - } -} - } // anonymous namespace using namespace quantize_lut; @@ -264,9 +238,9 @@ void IndexFastScan::search( FAISS_THROW_IF_NOT(k > 0); if (metric_type == METRIC_L2) { - search_dispatch_implem(n, x, k, distances, labels, nullptr); + search_dispatch_implem(n, x, k, distances, labels, -1); } else { - search_dispatch_implem(n, x, k, distances, labels, nullptr); + search_dispatch_implem(n, x, k, distances, labels, -1); } } @@ -277,7 +251,7 @@ void IndexFastScan::search_dispatch_implem( idx_t k, float* distances, idx_t* labels, - const NormTableScaler* scaler) const { + int norm_scale) const { using Cfloat = typename std::conditional< is_max, CMax, @@ -308,15 +282,17 @@ void IndexFastScan::search_dispatch_implem( FAISS_THROW_MSG("not implemented"); } else if (implem == 2 || implem == 3 || implem == 4) { FAISS_THROW_IF_NOT(orig_codes != nullptr); - search_implem_234(n, x, k, distances, labels, scaler); + search_implem_234(n, x, k, distances, labels, norm_scale); } else if (impl >= 12 && impl <= 15) { FAISS_THROW_IF_NOT(ntotal < INT_MAX); int nt = std::min(omp_get_max_threads(), int(n)); if (nt < 2) { if (impl == 12 || impl == 13) { - search_implem_12(n, x, k, distances, labels, impl, scaler); + search_implem_12( + n, x, k, distances, labels, impl, norm_scale); } else { - search_implem_14(n, x, k, distances, labels, impl, scaler); + search_implem_14( + n, x, k, distances, labels, impl, norm_scale); } } else { // explicitly slice over threads @@ -328,10 +304,22 @@ void IndexFastScan::search_dispatch_implem( idx_t* lab_i = labels + i0 * k; if (impl == 12 || impl == 13) { search_implem_12( - i1 - i0, x + i0 * d, k, dis_i, lab_i, impl, scaler); + i1 - i0, + x + i0 * d, + k, + dis_i, + lab_i, + impl, + norm_scale); } else { search_implem_14( - i1 - i0, x + i0 * d, k, dis_i, lab_i, impl, scaler); + i1 - i0, + x + i0 * d, + k, + dis_i, + lab_i, + impl, + norm_scale); } } } @@ -347,7 +335,7 @@ void IndexFastScan::search_implem_234( idx_t k, float* distances, idx_t* labels, - const NormTableScaler* scaler) const { + int norm_scale) const { FAISS_THROW_IF_NOT(implem == 2 || implem == 3 || implem == 4); const size_t dim12 = ksub * M; @@ -369,6 +357,11 @@ void IndexFastScan::search_implem_234( } } + std::unique_ptr scaler; + if (norm_scale != -1) { + scaler.reset(new Scaler2x4bit(norm_scale)); + } + #pragma omp parallel for if (n > 1000) for (int64_t i = 0; i < n; i++) { int64_t* heap_ids = labels + i * k; @@ -384,7 +377,7 @@ void IndexFastScan::search_implem_234( k, heap_dis, heap_ids, - scaler); + scaler.get()); heap_reorder(k, heap_dis, heap_ids); @@ -407,8 +400,8 @@ void IndexFastScan::search_implem_12( float* distances, idx_t* labels, int impl, - const NormTableScaler* scaler) const { - using RH = ResultHandlerCompare; + int norm_scale) const { + using RH = PQ4CodeScanner; FAISS_THROW_IF_NOT(bbs == 32); // handle qbs2 blocking by recursive call @@ -423,7 +416,7 @@ void IndexFastScan::search_implem_12( distances + i0 * k, labels + i0 * k, impl, - scaler); + norm_scale); } return; } @@ -454,22 +447,22 @@ void IndexFastScan::search_implem_12( pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get()); FAISS_THROW_IF_NOT(LUT_nq == n); - std::unique_ptr handler( - make_knn_handler(impl, n, k, ntotal, distances, labels)); - handler->disable = bool(skip & 2); - handler->normalizers = normalizers.get(); + std::unique_ptr handler(pq4_make_flat_knn_handler( + metric_type, + impl % 2 == 1, + n, + k, + ntotal, + distances, + labels, + norm_scale, + normalizers.get(), + bool(skip & 2))); if (skip & 4) { // pass } else { - pq4_accumulate_loop_qbs( - qbs, - ntotal2, - M2, - codes.get(), - LUT.get(), - *handler.get(), - scaler); + handler->accumulate_loop_qbs(qbs, ntotal2, M2, codes.get(), LUT.get()); } if (!(skip & 8)) { handler->end(); @@ -486,8 +479,8 @@ void IndexFastScan::search_implem_14( float* distances, idx_t* labels, int impl, - const NormTableScaler* scaler) const { - using RH = ResultHandlerCompare; + int norm_scale) const { + using RH = PQ4CodeScanner; FAISS_THROW_IF_NOT(bbs % 32 == 0); int qbs2 = qbs == 0 ? 4 : qbs; @@ -503,7 +496,7 @@ void IndexFastScan::search_implem_14( distances + i0 * k, labels + i0 * k, impl, - scaler); + norm_scale); } return; } @@ -522,23 +515,22 @@ void IndexFastScan::search_implem_14( AlignedTable LUT(n * dim12); pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get()); - std::unique_ptr handler( - make_knn_handler(impl, n, k, ntotal, distances, labels)); - handler->disable = bool(skip & 2); - handler->normalizers = normalizers.get(); + std::unique_ptr handler(pq4_make_flat_knn_handler( + metric_type, + impl % 2 == 1, + n, + k, + ntotal, + distances, + labels, + norm_scale, + normalizers.get(), + bool(skip & 2))); if (skip & 4) { // pass } else { - pq4_accumulate_loop( - n, - ntotal2, - bbs, - M2, - codes.get(), - LUT.get(), - *handler.get(), - scaler); + handler->accumulate_loop(n, ntotal2, bbs, M2, codes.get(), LUT.get()); } if (!(skip & 8)) { handler->end(); @@ -551,7 +543,7 @@ template void IndexFastScan::search_dispatch_implem( idx_t k, float* distances, idx_t* labels, - const NormTableScaler* scaler) const; + int norm_scale) const; template void IndexFastScan::search_dispatch_implem( idx_t n, @@ -559,7 +551,7 @@ template void IndexFastScan::search_dispatch_implem( idx_t k, float* distances, idx_t* labels, - const NormTableScaler* scaler) const; + int norm_scale) const; void IndexFastScan::reconstruct(idx_t key, float* recons) const { std::vector code(code_size, 0); diff --git a/faiss/IndexFastScan.h b/faiss/IndexFastScan.h index a0f5c592f0..5bec076807 100644 --- a/faiss/IndexFastScan.h +++ b/faiss/IndexFastScan.h @@ -13,7 +13,6 @@ namespace faiss { struct CodePacker; -struct NormTableScaler; /** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now. * @@ -95,7 +94,7 @@ struct IndexFastScan : Index { idx_t k, float* distances, idx_t* labels, - const NormTableScaler* scaler) const; + int norm_scale) const; template void search_implem_234( @@ -104,7 +103,7 @@ struct IndexFastScan : Index { idx_t k, float* distances, idx_t* labels, - const NormTableScaler* scaler) const; + int norm_scale) const; template void search_implem_12( @@ -114,7 +113,7 @@ struct IndexFastScan : Index { float* distances, idx_t* labels, int impl, - const NormTableScaler* scaler) const; + int norm_scale) const; template void search_implem_14( @@ -124,7 +123,7 @@ struct IndexFastScan : Index { float* distances, idx_t* labels, int impl, - const NormTableScaler* scaler) const; + int norm_scale) const; void reconstruct(idx_t key, float* recons) const override; size_t remove_ids(const IDSelector& sel) override; diff --git a/faiss/IndexIVFAdditiveQuantizerFastScan.cpp b/faiss/IndexIVFAdditiveQuantizerFastScan.cpp index 93fad18636..7cf7e38aa2 100644 --- a/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +++ b/faiss/IndexIVFAdditiveQuantizerFastScan.cpp @@ -8,14 +8,10 @@ #include #include -#include - -#include #include #include -#include -#include +#include #include #include #include @@ -309,9 +305,8 @@ void IndexIVFAdditiveQuantizerFastScan::search( return; } - NormTableScaler scaler(norm_scale); IndexIVFFastScan::CoarseQuantized cq{nprobe}; - search_dispatch_implem(n, x, k, distances, labels, cq, &scaler); + search_dispatch_implem(n, x, k, distances, labels, cq, norm_scale); } /********************************************************* diff --git a/faiss/IndexIVFFastScan.cpp b/faiss/IndexIVFFastScan.cpp index f031a51bba..d21082e31b 100644 --- a/faiss/IndexIVFFastScan.cpp +++ b/faiss/IndexIVFFastScan.cpp @@ -7,21 +7,14 @@ #include -#include -#include -#include -#include - #include -#include - #include #include #include -#include -#include -#include +#include +#include +#include #include #include #include @@ -212,8 +205,13 @@ void estimators_from_tables_generic( size_t k, typename C::T* heap_dis, int64_t* heap_ids, - const NormTableScaler* scaler) { + int norm_scale) { using accu_t = typename C::T; + std::unique_ptr scaler; + if (norm_scale != -1) { + scaler.reset(new Scaler2x4bit(norm_scale)); + } + size_t nscale = scaler ? scaler->nscale : 0; for (size_t j = 0; j < ncodes; ++j) { BitstringReader bsr(codes + j * index.code_size, index.code_size); @@ -345,7 +343,7 @@ void IndexIVFFastScan::search_preassigned( FAISS_THROW_IF_NOT(k > 0); const CoarseQuantized cq = {nprobe, centroid_dis, assign}; - search_dispatch_implem(n, x, k, distances, labels, cq, nullptr, params); + search_dispatch_implem(n, x, k, distances, labels, cq, -1, params); } void IndexIVFFastScan::range_search( @@ -364,49 +362,11 @@ void IndexIVFFastScan::range_search( } const CoarseQuantized cq = {nprobe, nullptr, nullptr}; - range_search_dispatch_implem(n, x, radius, *result, cq, nullptr, params); + range_search_dispatch_implem(n, x, radius, *result, cq, -1, params); } namespace { -template -ResultHandlerCompare* make_knn_handler_fixC( - int impl, - idx_t n, - idx_t k, - float* distances, - idx_t* labels, - const IDSelector* sel) { - using HeapHC = HeapHandler; - using ReservoirHC = ReservoirHandler; - using SingleResultHC = SingleResultHandler; - - if (k == 1) { - return new SingleResultHC(n, 0, distances, labels, sel); - } else if (impl % 2 == 0) { - return new HeapHC(n, 0, k, distances, labels, sel); - } else /* if (impl % 2 == 1) */ { - return new ReservoirHC(n, 0, k, 2 * k, distances, labels, sel); - } -} - -SIMDResultHandlerToFloat* make_knn_handler( - bool is_max, - int impl, - idx_t n, - idx_t k, - float* distances, - idx_t* labels, - const IDSelector* sel) { - if (is_max) { - return make_knn_handler_fixC>( - impl, n, k, distances, labels, sel); - } else { - return make_knn_handler_fixC>( - impl, n, k, distances, labels, sel); - } -} - using CoarseQuantized = IndexIVFFastScan::CoarseQuantized; struct CoarseQuantizedWithBuffer : CoarseQuantized { @@ -490,7 +450,7 @@ void IndexIVFFastScan::search_dispatch_implem( float* distances, idx_t* labels, const CoarseQuantized& cq_in, - const NormTableScaler* scaler, + int norm_scale, const IVFSearchParameters* params) const { const idx_t nprobe = params ? params->nprobe : this->nprobe; const IDSelector* sel = (params) ? params->sel : nullptr; @@ -498,7 +458,7 @@ void IndexIVFFastScan::search_dispatch_implem( params ? params->quantizer_params : nullptr; bool is_max = !is_similarity_metric(metric_type); - using RH = SIMDResultHandlerToFloat; + using RH = PQ4CodeScanner; if (n == 0) { return; @@ -539,59 +499,73 @@ void IndexIVFFastScan::search_dispatch_implem( if (impl == 1) { if (is_max) { search_implem_1>( - n, x, k, distances, labels, cq, scaler, params); + n, x, k, distances, labels, cq, norm_scale, params); } else { search_implem_1>( - n, x, k, distances, labels, cq, scaler, params); + n, x, k, distances, labels, cq, norm_scale, params); } } else if (impl == 2) { if (is_max) { search_implem_2>( - n, x, k, distances, labels, cq, scaler, params); + n, x, k, distances, labels, cq, norm_scale, params); } else { search_implem_2>( - n, x, k, distances, labels, cq, scaler, params); + n, x, k, distances, labels, cq, norm_scale, params); } } else if (impl >= 10 && impl <= 15) { size_t ndis = 0, nlist_visited = 0; if (!multiple_threads) { - // clang-format off if (impl == 12 || impl == 13) { - std::unique_ptr handler( - make_knn_handler( - is_max, - impl, - n, - k, - distances, - labels, sel - ) - ); + std::unique_ptr handler(pq4_make_ivf_knn_handler( + is_max, + impl, + n, + k, + distances, + labels, + norm_scale, + sel)); search_implem_12( - n, x, *handler.get(), - cq, &ndis, &nlist_visited, scaler, params); + n, + x, + *handler.get(), + cq, + &ndis, + &nlist_visited, + norm_scale, + params); } else if (impl == 14 || impl == 15) { search_implem_14( - n, x, k, distances, labels, - cq, impl, scaler, params); + n, + x, + k, + distances, + labels, + cq, + impl, + norm_scale, + params); } else { - std::unique_ptr handler( - make_knn_handler( - is_max, - impl, - n, - k, - distances, + std::unique_ptr handler(pq4_make_ivf_knn_handler( + is_max, + impl, + n, + k, + distances, labels, - sel - ) - ); + norm_scale, + sel)); search_implem_10( - n, x, *handler.get(), cq, - &ndis, &nlist_visited, scaler, params); + n, + x, + *handler.get(), + cq, + &ndis, + &nlist_visited, + norm_scale, + params); } - // clang-format on } else { // explicitly slice over threads int nslice = compute_search_nslice(this, n, cq.nprobe); @@ -599,7 +573,15 @@ void IndexIVFFastScan::search_dispatch_implem( // this might require slicing if there are too // many queries (for now we keep this simple) search_implem_14( - n, x, k, distances, labels, cq, impl, scaler, params); + n, + x, + k, + distances, + labels, + cq, + impl, + norm_scale, + params); } else { #pragma omp parallel for reduction(+ : ndis, nlist_visited) for (int slice = 0; slice < nslice; slice++) { @@ -611,17 +593,24 @@ void IndexIVFFastScan::search_dispatch_implem( if (!cq_i.done()) { cq_i.quantize_slice(quantizer, x, quantizer_params); } - std::unique_ptr handler(make_knn_handler( - is_max, impl, i1 - i0, k, dis_i, lab_i, sel)); + std::unique_ptr handler(pq4_make_ivf_knn_handler( + is_max, + impl, + i1 - i0, + k, + dis_i, + lab_i, + norm_scale, + sel)); // clang-format off if (impl == 12 || impl == 13) { search_implem_12( i1 - i0, x + i0 * d, *handler.get(), - cq_i, &ndis, &nlist_visited, scaler, params); + cq_i, &ndis, &nlist_visited, norm_scale, params); } else { search_implem_10( i1 - i0, x + i0 * d, *handler.get(), - cq_i, &ndis, &nlist_visited, scaler, params); + cq_i, &ndis, &nlist_visited, norm_scale, params); } // clang-format on } @@ -641,7 +630,7 @@ void IndexIVFFastScan::range_search_dispatch_implem( float radius, RangeSearchResult& rres, const CoarseQuantized& cq_in, - const NormTableScaler* scaler, + int norm_scale, const IVFSearchParameters* params) const { // const idx_t nprobe = params ? params->nprobe : this->nprobe; const IDSelector* sel = (params) ? params->sel : nullptr; @@ -682,20 +671,26 @@ void IndexIVFFastScan::range_search_dispatch_implem( size_t ndis = 0, nlist_visited = 0; if (!multiple_threads) { // single thread - std::unique_ptr handler; - if (is_max) { - handler.reset(new RangeHandler, true>( - rres, radius, 0, sel)); - } else { - handler.reset(new RangeHandler, true>( - rres, radius, 0, sel)); - } + std::unique_ptr handler(pq4_make_ivf_range_handler( + is_max, rres, radius, norm_scale, sel)); if (impl == 12) { search_implem_12( - n, x, *handler.get(), cq, &ndis, &nlist_visited, scaler); + n, + x, + *handler.get(), + cq, + &ndis, + &nlist_visited, + norm_scale); } else if (impl == 10) { search_implem_10( - n, x, *handler.get(), cq, &ndis, &nlist_visited, scaler); + n, + x, + *handler.get(), + cq, + &ndis, + &nlist_visited, + norm_scale); } else { FAISS_THROW_FMT("Range search implem %d not implemented", impl); } @@ -714,16 +709,9 @@ void IndexIVFFastScan::range_search_dispatch_implem( if (!cq_i.done()) { cq_i.quantize_slice(quantizer, x, quantizer_params); } - std::unique_ptr handler; - if (is_max) { - handler.reset(new PartialRangeHandler< - CMax, - true>(pres, radius, 0, i0, i1, sel)); - } else { - handler.reset(new PartialRangeHandler< - CMin, - true>(pres, radius, 0, i0, i1, sel)); - } + std::unique_ptr handler( + pq4_make_ivf_partial_range_handler( + is_max, pres, radius, i0, i1, norm_scale, sel)); if (impl == 12 || impl == 13) { search_implem_12( @@ -733,7 +721,7 @@ void IndexIVFFastScan::range_search_dispatch_implem( cq_i, &ndis, &nlist_visited, - scaler, + norm_scale, params); } else { search_implem_10( @@ -743,7 +731,7 @@ void IndexIVFFastScan::range_search_dispatch_implem( cq_i, &ndis, &nlist_visited, - scaler, + norm_scale, params); } } @@ -764,7 +752,7 @@ void IndexIVFFastScan::search_implem_1( float* distances, idx_t* labels, const CoarseQuantized& cq, - const NormTableScaler* scaler, + int norm_scale, const IVFSearchParameters* params) const { FAISS_THROW_IF_NOT(orig_invlists); @@ -813,7 +801,7 @@ void IndexIVFFastScan::search_implem_1( k, heap_dis, heap_ids, - scaler); + norm_scale); nlist_visited++; ndis++; } @@ -832,7 +820,7 @@ void IndexIVFFastScan::search_implem_2( float* distances, idx_t* labels, const CoarseQuantized& cq, - const NormTableScaler* scaler, + int norm_scale, const IVFSearchParameters* params) const { FAISS_THROW_IF_NOT(orig_invlists); @@ -884,7 +872,7 @@ void IndexIVFFastScan::search_implem_2( k, heap_dis, heap_ids, - scaler); + norm_scale); nlist_visited++; ndis += ls; @@ -911,11 +899,11 @@ void IndexIVFFastScan::search_implem_2( void IndexIVFFastScan::search_implem_10( idx_t n, const float* x, - SIMDResultHandlerToFloat& handler, + PQ4CodeScanner& handler, const CoarseQuantized& cq, size_t* ndis_out, size_t* nlist_out, - const NormTableScaler* scaler, + int norm_scale, const IVFSearchParameters* params) const { size_t dim12 = ksub * M2; AlignedTable dis_tables; @@ -964,15 +952,8 @@ void IndexIVFFastScan::search_implem_10( handler.ntotal = ls; handler.id_map = ids.get(); - pq4_accumulate_loop( - 1, - roundup(ls, bbs), - bbs, - M2, - codes.get(), - LUT, - handler, - scaler); + handler.accumulate_loop( + 1, roundup(ls, bbs), bbs, M2, codes.get(), LUT); ndis++; } @@ -986,11 +967,11 @@ void IndexIVFFastScan::search_implem_10( void IndexIVFFastScan::search_implem_12( idx_t n, const float* x, - SIMDResultHandlerToFloat& handler, + PQ4CodeScanner& handler, const CoarseQuantized& cq, size_t* ndis_out, size_t* nlist_out, - const NormTableScaler* scaler, + int norm_scale, const IVFSearchParameters* params) const { if (n == 0) { // does not work well with reservoir return; @@ -1100,14 +1081,8 @@ void IndexIVFFastScan::search_implem_12( handler.q_map = q_map.data(); handler.id_map = ids.get(); - pq4_accumulate_loop_qbs( - qbs_for_list, - list_size, - M2, - codes.get(), - LUT.get(), - handler, - scaler); + handler.accumulate_loop_qbs( + qbs_for_list, list_size, M2, codes.get(), LUT.get()); // prepare for next loop i0 = i1; } @@ -1131,7 +1106,7 @@ void IndexIVFFastScan::search_implem_14( idx_t* labels, const CoarseQuantized& cq, int impl, - const NormTableScaler* scaler, + int norm_scale, const IVFSearchParameters* params) const { if (n == 0) { // does not work well with reservoir return; @@ -1241,8 +1216,15 @@ void IndexIVFFastScan::search_implem_14( std::vector local_dis(k * n); // prepare the result handlers - std::unique_ptr handler(make_knn_handler( - is_max, impl, n, k, local_dis.data(), local_idx.data(), sel)); + std::unique_ptr handler(pq4_make_ivf_knn_handler( + is_max, + impl, + n, + k, + local_dis.data(), + local_idx.data(), + norm_scale, + sel)); handler->begin(normalizers.get()); int actual_qbs2 = this->qbs2 ? this->qbs2 : 11; @@ -1301,14 +1283,8 @@ void IndexIVFFastScan::search_implem_14( handler->q_map = q_map.data(); handler->id_map = ids.get(); - pq4_accumulate_loop_qbs( - qbs_for_list, - list_size, - M2, - codes.get(), - LUT.get(), - *handler.get(), - scaler); + handler->accumulate_loop_qbs( + qbs_for_list, list_size, M2, codes.get(), LUT.get()); } // labels is in-place for HeapHC diff --git a/faiss/IndexIVFFastScan.h b/faiss/IndexIVFFastScan.h index 48d6dafa1e..bcb507e957 100644 --- a/faiss/IndexIVFFastScan.h +++ b/faiss/IndexIVFFastScan.h @@ -15,7 +15,7 @@ namespace faiss { struct NormTableScaler; -struct SIMDResultHandlerToFloat; +struct PQ4CodeScanner; struct Quantizer; /** Fast scan version of IVFPQ and IVFAQ. Works for 4-bit PQ/AQ for now. @@ -154,7 +154,7 @@ struct IndexIVFFastScan : IndexIVF { float* distances, idx_t* labels, const CoarseQuantized& cq, - const NormTableScaler* scaler, + int norm_scale, const IVFSearchParameters* params = nullptr) const; void range_search_dispatch_implem( @@ -163,7 +163,7 @@ struct IndexIVFFastScan : IndexIVF { float radius, RangeSearchResult& rres, const CoarseQuantized& cq_in, - const NormTableScaler* scaler, + int norm_scale, const IVFSearchParameters* params = nullptr) const; // impl 1 and 2 are just for verification @@ -175,7 +175,7 @@ struct IndexIVFFastScan : IndexIVF { float* distances, idx_t* labels, const CoarseQuantized& cq, - const NormTableScaler* scaler, + int norm_scale, const IVFSearchParameters* params = nullptr) const; template @@ -186,7 +186,7 @@ struct IndexIVFFastScan : IndexIVF { float* distances, idx_t* labels, const CoarseQuantized& cq, - const NormTableScaler* scaler, + int norm_scale, const IVFSearchParameters* params = nullptr) const; // implem 10 and 12 are not multithreaded internally, so @@ -194,21 +194,21 @@ struct IndexIVFFastScan : IndexIVF { void search_implem_10( idx_t n, const float* x, - SIMDResultHandlerToFloat& handler, + PQ4CodeScanner& handler, const CoarseQuantized& cq, size_t* ndis_out, size_t* nlist_out, - const NormTableScaler* scaler, + int norm_scale, const IVFSearchParameters* params = nullptr) const; void search_implem_12( idx_t n, const float* x, - SIMDResultHandlerToFloat& handler, + PQ4CodeScanner& handler, const CoarseQuantized& cq, size_t* ndis_out, size_t* nlist_out, - const NormTableScaler* scaler, + int norm_scale, const IVFSearchParameters* params = nullptr) const; // implem 14 is multithreaded internally across nprobes and queries @@ -220,7 +220,7 @@ struct IndexIVFFastScan : IndexIVF { idx_t* labels, const CoarseQuantized& cq, int impl, - const NormTableScaler* scaler, + int norm_scale, const IVFSearchParameters* params = nullptr) const; // reconstruct vectors from packed invlists diff --git a/faiss/IndexIVFPQ.cpp b/faiss/IndexIVFPQ.cpp index ce1f5ddf1a..f20c7deb92 100644 --- a/faiss/IndexIVFPQ.cpp +++ b/faiss/IndexIVFPQ.cpp @@ -32,7 +32,7 @@ #include -#include +#include namespace faiss { @@ -805,8 +805,9 @@ struct RangeSearchResults { * The scanning functions call their favorite precompute_* * function to precompute the tables they need. *****************************************************/ -template +template struct IVFPQScannerT : QueryTables { + using PQDecoder = typename PQCodeDistance::PQDecoder; const uint8_t* list_codes; const IDType* list_ids; size_t list_size; @@ -882,7 +883,7 @@ struct IVFPQScannerT : QueryTables { float distance_1 = 0; float distance_2 = 0; float distance_3 = 0; - distance_four_codes( + PQCodeDistance::distance_four_codes( pq.M, pq.nbits, sim_table, @@ -905,7 +906,7 @@ struct IVFPQScannerT : QueryTables { if (counter >= 1) { float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -914,7 +915,7 @@ struct IVFPQScannerT : QueryTables { } if (counter >= 2) { float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -923,7 +924,7 @@ struct IVFPQScannerT : QueryTables { } if (counter >= 3) { float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -1089,7 +1090,7 @@ struct IVFPQScannerT : QueryTables { float distance_1 = dis0; float distance_2 = dis0; float distance_3 = dis0; - distance_four_codes( + PQCodeDistance::distance_four_codes( pq.M, pq.nbits, sim_table, @@ -1120,7 +1121,7 @@ struct IVFPQScannerT : QueryTables { n_hamming_pass++; float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -1140,7 +1141,7 @@ struct IVFPQScannerT : QueryTables { n_hamming_pass++; float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -1185,8 +1186,8 @@ struct IVFPQScannerT : QueryTables { * * use_sel: store or ignore the IDSelector */ -template -struct IVFPQScanner : IVFPQScannerT, +template +struct IVFPQScanner : IVFPQScannerT, InvertedListScanner { int precompute_mode; const IDSelector* sel; @@ -1196,7 +1197,7 @@ struct IVFPQScanner : IVFPQScannerT, bool store_pairs, int precompute_mode, const IDSelector* sel) - : IVFPQScannerT(ivfpq, nullptr), + : IVFPQScannerT(ivfpq, nullptr), precompute_mode(precompute_mode), sel(sel) { this->store_pairs = store_pairs; @@ -1215,7 +1216,7 @@ struct IVFPQScanner : IVFPQScannerT, float distance_to_code(const uint8_t* code) const override { assert(precompute_mode == 2); float dis = this->dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( this->pq.M, this->pq.nbits, this->sim_table, code); return dis; } @@ -1279,7 +1280,9 @@ struct IVFPQScanner : IVFPQScannerT, } }; -template +/** follow 3 stages of template dispatching */ + +template InvertedListScanner* get_InvertedListScanner1( const IndexIVFPQ& index, bool store_pairs, @@ -1288,32 +1291,47 @@ InvertedListScanner* get_InvertedListScanner1( return new IVFPQScanner< METRIC_INNER_PRODUCT, CMin, - PQDecoder, + PQCodeDistance, use_sel>(index, store_pairs, 2, sel); } else if (index.metric_type == METRIC_L2) { return new IVFPQScanner< METRIC_L2, CMax, - PQDecoder, + PQCodeDistance, use_sel>(index, store_pairs, 2, sel); } return nullptr; } -template +template InvertedListScanner* get_InvertedListScanner2( const IndexIVFPQ& index, bool store_pairs, const IDSelector* sel) { if (index.pq.nbits == 8) { - return get_InvertedListScanner1( - index, store_pairs, sel); + return get_InvertedListScanner1< + PQCodeDistance, + use_sel>(index, store_pairs, sel); } else if (index.pq.nbits == 16) { - return get_InvertedListScanner1( - index, store_pairs, sel); + return get_InvertedListScanner1< + PQCodeDistance, + use_sel>(index, store_pairs, sel); + } else { + return get_InvertedListScanner1< + PQCodeDistance, + use_sel>(index, store_pairs, sel); + } +} + +template +InvertedListScanner* get_InvertedListScanner3( + const IndexIVFPQ& index, + bool store_pairs, + const IDSelector* sel) { + if (sel) { + return get_InvertedListScanner2(index, store_pairs, sel); } else { - return get_InvertedListScanner1( - index, store_pairs, sel); + return get_InvertedListScanner2(index, store_pairs, sel); } } @@ -1323,11 +1341,7 @@ InvertedListScanner* IndexIVFPQ::get_InvertedListScanner( bool store_pairs, const IDSelector* sel, const IVFSearchParameters*) const { - if (sel) { - return get_InvertedListScanner2(*this, store_pairs, sel); - } else { - return get_InvertedListScanner2(*this, store_pairs, sel); - } + DISPATCH_SIMDLevel(get_InvertedListScanner3, *this, store_pairs, sel); return nullptr; } diff --git a/faiss/IndexIVFPQFastScan.cpp b/faiss/IndexIVFPQFastScan.cpp index 95efaaaf89..f79d71d094 100644 --- a/faiss/IndexIVFPQFastScan.cpp +++ b/faiss/IndexIVFPQFastScan.cpp @@ -8,10 +8,6 @@ #include #include -#include -#include - -#include #include #include @@ -20,8 +16,8 @@ #include -#include -#include +#include +#include namespace faiss { diff --git a/faiss/IndexPQ.cpp b/faiss/IndexPQ.cpp index 8193e78b17..7bdb137c7d 100644 --- a/faiss/IndexPQ.cpp +++ b/faiss/IndexPQ.cpp @@ -20,7 +20,7 @@ #include #include -#include +#include namespace faiss { @@ -73,7 +73,7 @@ void IndexPQ::train(idx_t n, const float* x) { namespace { -template +template struct PQDistanceComputer : FlatCodesDistanceComputer { size_t d; MetricType metric; @@ -86,7 +86,7 @@ struct PQDistanceComputer : FlatCodesDistanceComputer { float distance_to_code(const uint8_t* code) final { ndis++; - float dis = distance_single_code( + float dis = PQCodeDistance::distance_single_code( pq.M, pq.nbits, precomputed_table.data(), code); return dis; } @@ -95,8 +95,10 @@ struct PQDistanceComputer : FlatCodesDistanceComputer { FAISS_THROW_IF_NOT(sdc); const float* sdci = sdc; float accu = 0; - PQDecoder codei(codes + i * code_size, pq.nbits); - PQDecoder codej(codes + j * code_size, pq.nbits); + typename PQCodeDistance::PQDecoder codei( + codes + i * code_size, pq.nbits); + typename PQCodeDistance::PQDecoder codej( + codes + j * code_size, pq.nbits); for (int l = 0; l < pq.M; l++) { accu += sdci[codei.decode() + (codej.decode() << codei.nbits)]; @@ -132,16 +134,24 @@ struct PQDistanceComputer : FlatCodesDistanceComputer { } }; +template +FlatCodesDistanceComputer* get_FlatCodesDistanceComputer1( + const IndexPQ& index) { + int nbits = index.pq.nbits; + if (nbits == 8) { + return new PQDistanceComputer>(index); + } else if (nbits == 16) { + return new PQDistanceComputer>(index); + } else { + return new PQDistanceComputer>( + index); + } +} + } // namespace FlatCodesDistanceComputer* IndexPQ::get_FlatCodesDistanceComputer() const { - if (pq.nbits == 8) { - return new PQDistanceComputer(*this); - } else if (pq.nbits == 16) { - return new PQDistanceComputer(*this); - } else { - return new PQDistanceComputer(*this); - } + DISPATCH_SIMDLevel(get_FlatCodesDistanceComputer1, *this); } /***************************************** diff --git a/faiss/IndexPQFastScan.cpp b/faiss/IndexPQFastScan.cpp index 153a881bde..b2b2ab0f3f 100644 --- a/faiss/IndexPQFastScan.cpp +++ b/faiss/IndexPQFastScan.cpp @@ -11,7 +11,7 @@ #include #include -#include +#include #include namespace faiss { diff --git a/faiss/clone_index.cpp b/faiss/clone_index.cpp index 5a1e5cfad2..60468da321 100644 --- a/faiss/clone_index.cpp +++ b/faiss/clone_index.cpp @@ -47,7 +47,7 @@ #include #include #include -#include +#include #include diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h index ce15c63510..b5544e0f84 100644 --- a/faiss/impl/ResultHandler.h +++ b/faiss/impl/ResultHandler.h @@ -79,6 +79,33 @@ struct ResultHandler { virtual ~ResultHandler() {} }; +/***************************************************************** + * Common ancestor for top-k search results. + *****************************************************************/ + +template +struct TopkBlockResultHandler : BlockResultHandler { + using T = typename C::T; + using TI = typename C::TI; + T* dis_tab; + TI* ids_tab; + + int64_t k; // number of results to keep + + TopkBlockResultHandler( + size_t nq, + T* dis_tab, + TI* ids_tab, + size_t k, + const IDSelector* sel = nullptr) + : BlockResultHandler(nq, sel), + dis_tab(dis_tab), + ids_tab(ids_tab), + k(k) {} + + ~TopkBlockResultHandler() {} +}; + /***************************************************************** * Single best result handler. * Tracks the only best result, thus avoiding storing @@ -86,25 +113,19 @@ struct ResultHandler { *****************************************************************/ template -struct Top1BlockResultHandler : BlockResultHandler { +struct Top1BlockResultHandler : TopkBlockResultHandler { using T = typename C::T; using TI = typename C::TI; using BlockResultHandler::i0; using BlockResultHandler::i1; - // contains exactly nq elements - T* dis_tab; - // contains exactly nq elements - TI* ids_tab; - Top1BlockResultHandler( size_t nq, T* dis_tab, TI* ids_tab, const IDSelector* sel = nullptr) - : BlockResultHandler(nq, sel), - dis_tab(dis_tab), - ids_tab(ids_tab) {} + : TopkBlockResultHandler(nq, dis_tab, ids_tab, 1, sel) { + } struct SingleResultHandler : ResultHandler { Top1BlockResultHandler& hr; @@ -184,28 +205,21 @@ struct Top1BlockResultHandler : BlockResultHandler { *****************************************************************/ template -struct HeapBlockResultHandler : BlockResultHandler { +struct HeapBlockResultHandler : TopkBlockResultHandler { using T = typename C::T; using TI = typename C::TI; using BlockResultHandler::i0; using BlockResultHandler::i1; - - T* heap_dis_tab; - TI* heap_ids_tab; - - int64_t k; // number of results to keep + using TopkBlockResultHandler::k; HeapBlockResultHandler( size_t nq, - T* heap_dis_tab, - TI* heap_ids_tab, + T* dis_tab, + TI* ids_tab, size_t k, const IDSelector* sel = nullptr) - : BlockResultHandler(nq, sel), - heap_dis_tab(heap_dis_tab), - heap_ids_tab(heap_ids_tab), - k(k) {} - + : TopkBlockResultHandler(nq, dis_tab, ids_tab, k, sel) { + } /****************************************************** * API for 1 result at a time (each SingleResultHandler is * called from 1 thread) @@ -224,8 +238,8 @@ struct HeapBlockResultHandler : BlockResultHandler { /// begin results for query # i void begin(size_t i) { - heap_dis = hr.heap_dis_tab + i * k; - heap_ids = hr.heap_ids_tab + i * k; + heap_dis = hr.dis_tab + i * k; + heap_ids = hr.ids_tab + i * k; heap_heapify(k, heap_dis, heap_ids); threshold = heap_dis[0]; } @@ -255,7 +269,8 @@ struct HeapBlockResultHandler : BlockResultHandler { this->i0 = i0_2; this->i1 = i1_2; for (size_t i = i0; i < i1; i++) { - heap_heapify(k, heap_dis_tab + i * k, heap_ids_tab + i * k); + heap_heapify( + k, this->dis_tab + i * this->k, this->ids_tab + i * k); } } @@ -263,8 +278,8 @@ struct HeapBlockResultHandler : BlockResultHandler { void add_results(size_t j0, size_t j1, const T* dis_tab) final { #pragma omp parallel for for (int64_t i = i0; i < i1; i++) { - T* heap_dis = heap_dis_tab + i * k; - TI* heap_ids = heap_ids_tab + i * k; + T* heap_dis = this->dis_tab + i * k; + TI* heap_ids = this->ids_tab + i * k; const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0; T thresh = heap_dis[0]; for (size_t j = j0; j < j1; j++) { @@ -281,7 +296,7 @@ struct HeapBlockResultHandler : BlockResultHandler { void end_multiple() final { // maybe parallel for for (size_t i = i0; i < i1; i++) { - heap_reorder(k, heap_dis_tab + i * k, heap_ids_tab + i * k); + heap_reorder(k, this->dis_tab + i * k, this->ids_tab + i * k); } } }; @@ -290,9 +305,9 @@ struct HeapBlockResultHandler : BlockResultHandler { * Reservoir result handler * * A reservoir is a result array of size capacity > n (number of requested - * results) all results below a threshold are stored in an arbitrary order. When - * the capacity is reached, a new threshold is chosen by partitionning the - * distance array. + * results) all results below a threshold are stored in an arbitrary order. + *When the capacity is reached, a new threshold is chosen by partitionning + *the distance array. *****************************************************************/ /// Reservoir for a single query @@ -367,28 +382,21 @@ struct ReservoirTopN : ResultHandler { }; template -struct ReservoirBlockResultHandler : BlockResultHandler { +struct ReservoirBlockResultHandler : TopkBlockResultHandler { using T = typename C::T; using TI = typename C::TI; using BlockResultHandler::i0; using BlockResultHandler::i1; - T* heap_dis_tab; - TI* heap_ids_tab; - - int64_t k; // number of results to keep size_t capacity; // capacity of the reservoirs ReservoirBlockResultHandler( size_t nq, - T* heap_dis_tab, - TI* heap_ids_tab, + T* dis_tab, + TI* ids_tab, size_t k, const IDSelector* sel = nullptr) - : BlockResultHandler(nq, sel), - heap_dis_tab(heap_dis_tab), - heap_ids_tab(heap_ids_tab), - k(k) { + : TopkBlockResultHandler(nq, dis_tab, ids_tab, k, sel) { // double then round up to multiple of 16 (for SIMD alignment) capacity = (2 * k + 15) & ~15; } @@ -423,8 +431,8 @@ struct ReservoirBlockResultHandler : BlockResultHandler { /// series of results for query qno is done void end() { - T* heap_dis = hr.heap_dis_tab + qno * hr.k; - TI* heap_ids = hr.heap_ids_tab + qno * hr.k; + T* heap_dis = hr.dis_tab + qno * hr.k; + TI* heap_ids = hr.ids_tab + qno * hr.k; this->to_result(heap_dis, heap_ids); } }; @@ -446,7 +454,7 @@ struct ReservoirBlockResultHandler : BlockResultHandler { reservoirs.clear(); for (size_t i = i0_2; i < i1_2; i++) { reservoirs.emplace_back( - k, + this->k, capacity, reservoir_dis.data() + (i - i0_2) * capacity, reservoir_ids.data() + (i - i0_2) * capacity); @@ -471,7 +479,7 @@ struct ReservoirBlockResultHandler : BlockResultHandler { // maybe parallel for for (size_t i = i0; i < i1; i++) { reservoirs[i - i0].to_result( - heap_dis_tab + i * k, heap_ids_tab + i * k); + this->dis_tab + i * this->k, this->ids_tab + i * this->k); } } }; @@ -535,7 +543,8 @@ struct RangeSearchBlockResultHandler : BlockResultHandler { // finalize the partial result pres.finalize(); } catch ([[maybe_unused]] const faiss::FaissException& e) { - // Do nothing if allocation fails in finalizing partial results. + // Do nothing if allocation fails in finalizing partial + // results. #ifndef NDEBUG std::cerr << e.what() << std::endl; #endif diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp index e05f3a1f25..f75ba1d7ec 100644 --- a/faiss/impl/ScalarQuantizer.cpp +++ b/faiss/impl/ScalarQuantizer.cpp @@ -5,1965 +5,71 @@ * LICENSE file in the root directory of this source tree. */ -// -*- c++ -*- - -#include - -#include -#include - -#include -#include - -#ifdef __SSE__ -#include -#endif - -#include -#include -#include -#include -#include -#include -#include - -namespace faiss { - -/******************************************************************* - * ScalarQuantizer implementation - * - * The main source of complexity is to support combinations of 4 - * variants without incurring runtime tests or virtual function calls: - * - * - 4 / 8 bits per code component - * - uniform / non-uniform - * - IP / L2 distance search - * - scalar / AVX distance computation - * - * The appropriate Quantizer object is returned via select_quantizer - * that hides the template mess. - ********************************************************************/ - -#if defined(__AVX512F__) && defined(__F16C__) -#define USE_AVX512_F16C -#elif defined(__AVX2__) -#ifdef __F16C__ -#define USE_F16C -#else -#warning \ - "Cannot enable AVX optimizations in scalar quantizer if -mf16c is not set as well" -#endif -#endif - -#if defined(__aarch64__) -#if defined(__GNUC__) && __GNUC__ < 8 -#warning \ - "Cannot enable NEON optimizations in scalar quantizer if the compiler is GCC<8" -#else -#define USE_NEON -#endif -#endif - -namespace { - -typedef ScalarQuantizer::QuantizerType QuantizerType; -typedef ScalarQuantizer::RangeStat RangeStat; -using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer; - -/******************************************************************* - * Codec: converts between values in [0, 1] and an index in a code - * array. The "i" parameter is the vector component index (not byte - * index). - */ - -struct Codec8bit { - static FAISS_ALWAYS_INLINE void encode_component( - float x, - uint8_t* code, - int i) { - code[i] = (int)(255 * x); - } - - static FAISS_ALWAYS_INLINE float decode_component( - const uint8_t* code, - int i) { - return (code[i] + 0.5f) / 255.0f; - } - -#if defined(__AVX512F__) - static FAISS_ALWAYS_INLINE __m512 - decode_16_components(const uint8_t* code, int i) { - const __m128i c16 = _mm_loadu_si128((__m128i*)(code + i)); - const __m512i i32 = _mm512_cvtepu8_epi32(c16); - const __m512 f16 = _mm512_cvtepi32_ps(i32); - const __m512 half_one_255 = _mm512_set1_ps(0.5f / 255.f); - const __m512 one_255 = _mm512_set1_ps(1.f / 255.f); - return _mm512_fmadd_ps(f16, one_255, half_one_255); - } -#elif defined(__AVX2__) - static FAISS_ALWAYS_INLINE __m256 - decode_8_components(const uint8_t* code, int i) { - const uint64_t c8 = *(uint64_t*)(code + i); - - const __m128i i8 = _mm_set1_epi64x(c8); - const __m256i i32 = _mm256_cvtepu8_epi32(i8); - const __m256 f8 = _mm256_cvtepi32_ps(i32); - const __m256 half_one_255 = _mm256_set1_ps(0.5f / 255.f); - const __m256 one_255 = _mm256_set1_ps(1.f / 255.f); - return _mm256_fmadd_ps(f8, one_255, half_one_255); - } -#endif - -#ifdef USE_NEON - static FAISS_ALWAYS_INLINE float32x4x2_t - decode_8_components(const uint8_t* code, int i) { - float32_t result[8] = {}; - for (size_t j = 0; j < 8; j++) { - result[j] = decode_component(code, i + j); - } - float32x4_t res1 = vld1q_f32(result); - float32x4_t res2 = vld1q_f32(result + 4); - return {res1, res2}; - } -#endif -}; - -struct Codec4bit { - static FAISS_ALWAYS_INLINE void encode_component( - float x, - uint8_t* code, - int i) { - code[i / 2] |= (int)(x * 15.0) << ((i & 1) << 2); - } - - static FAISS_ALWAYS_INLINE float decode_component( - const uint8_t* code, - int i) { - return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f; - } - -#if defined(__AVX512F__) - static FAISS_ALWAYS_INLINE __m512 - decode_16_components(const uint8_t* code, int i) { - uint64_t c8 = *(uint64_t*)(code + (i >> 1)); - uint64_t mask = 0x0f0f0f0f0f0f0f0f; - uint64_t c8ev = c8 & mask; - uint64_t c8od = (c8 >> 4) & mask; - - __m128i c16 = - _mm_unpacklo_epi8(_mm_set1_epi64x(c8ev), _mm_set1_epi64x(c8od)); - __m256i c8lo = _mm256_cvtepu8_epi32(c16); - __m256i c8hi = _mm256_cvtepu8_epi32(_mm_srli_si128(c16, 8)); - __m512i i16 = _mm512_castsi256_si512(c8lo); - i16 = _mm512_inserti32x8(i16, c8hi, 1); - __m512 f16 = _mm512_cvtepi32_ps(i16); - const __m512 half_one_255 = _mm512_set1_ps(0.5f / 15.f); - const __m512 one_255 = _mm512_set1_ps(1.f / 15.f); - return _mm512_fmadd_ps(f16, one_255, half_one_255); - } -#elif defined(__AVX2__) - static FAISS_ALWAYS_INLINE __m256 - decode_8_components(const uint8_t* code, int i) { - uint32_t c4 = *(uint32_t*)(code + (i >> 1)); - uint32_t mask = 0x0f0f0f0f; - uint32_t c4ev = c4 & mask; - uint32_t c4od = (c4 >> 4) & mask; - - // the 8 lower bytes of c8 contain the values - __m128i c8 = - _mm_unpacklo_epi8(_mm_set1_epi32(c4ev), _mm_set1_epi32(c4od)); - __m128i c4lo = _mm_cvtepu8_epi32(c8); - __m128i c4hi = _mm_cvtepu8_epi32(_mm_srli_si128(c8, 4)); - __m256i i8 = _mm256_castsi128_si256(c4lo); - i8 = _mm256_insertf128_si256(i8, c4hi, 1); - __m256 f8 = _mm256_cvtepi32_ps(i8); - __m256 half = _mm256_set1_ps(0.5f); - f8 = _mm256_add_ps(f8, half); - __m256 one_255 = _mm256_set1_ps(1.f / 15.f); - return _mm256_mul_ps(f8, one_255); - } -#endif - -#ifdef USE_NEON - static FAISS_ALWAYS_INLINE float32x4x2_t - decode_8_components(const uint8_t* code, int i) { - float32_t result[8] = {}; - for (size_t j = 0; j < 8; j++) { - result[j] = decode_component(code, i + j); - } - float32x4_t res1 = vld1q_f32(result); - float32x4_t res2 = vld1q_f32(result + 4); - return {res1, res2}; - } -#endif -}; - -struct Codec6bit { - static FAISS_ALWAYS_INLINE void encode_component( - float x, - uint8_t* code, - int i) { - int bits = (int)(x * 63.0); - code += (i >> 2) * 3; - switch (i & 3) { - case 0: - code[0] |= bits; - break; - case 1: - code[0] |= bits << 6; - code[1] |= bits >> 2; - break; - case 2: - code[1] |= bits << 4; - code[2] |= bits >> 4; - break; - case 3: - code[2] |= bits << 2; - break; - } - } - - static FAISS_ALWAYS_INLINE float decode_component( - const uint8_t* code, - int i) { - uint8_t bits; - code += (i >> 2) * 3; - switch (i & 3) { - case 0: - bits = code[0] & 0x3f; - break; - case 1: - bits = code[0] >> 6; - bits |= (code[1] & 0xf) << 2; - break; - case 2: - bits = code[1] >> 4; - bits |= (code[2] & 3) << 4; - break; - case 3: - bits = code[2] >> 2; - break; - } - return (bits + 0.5f) / 63.0f; - } - -#if defined(__AVX512F__) - - static FAISS_ALWAYS_INLINE __m512 - decode_16_components(const uint8_t* code, int i) { - // pure AVX512 implementation (not necessarily the fastest). - // see: - // https://github.com/zilliztech/knowhere/blob/main/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h - - // clang-format off - - // 16 components, 16x6 bit=12 bytes - const __m128i bit_6v = - _mm_maskz_loadu_epi8(0b0000111111111111, code + (i >> 2) * 3); - const __m256i bit_6v_256 = _mm256_broadcast_i32x4(bit_6v); - - // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F - // 00 01 02 03 - const __m256i shuffle_mask = _mm256_setr_epi16( - 0xFF00, 0x0100, 0x0201, 0xFF02, - 0xFF03, 0x0403, 0x0504, 0xFF05, - 0xFF06, 0x0706, 0x0807, 0xFF08, - 0xFF09, 0x0A09, 0x0B0A, 0xFF0B); - const __m256i shuffled = _mm256_shuffle_epi8(bit_6v_256, shuffle_mask); - - // 0: xxxxxxxx xx543210 - // 1: xxxx5432 10xxxxxx - // 2: xxxxxx54 3210xxxx - // 3: xxxxxxxx 543210xx - const __m256i shift_right_v = _mm256_setr_epi16( - 0x0U, 0x6U, 0x4U, 0x2U, - 0x0U, 0x6U, 0x4U, 0x2U, - 0x0U, 0x6U, 0x4U, 0x2U, - 0x0U, 0x6U, 0x4U, 0x2U); - __m256i shuffled_shifted = _mm256_srlv_epi16(shuffled, shift_right_v); - - // remove unneeded bits - shuffled_shifted = - _mm256_and_si256(shuffled_shifted, _mm256_set1_epi16(0x003F)); - - // scale - const __m512 f8 = - _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(shuffled_shifted)); - const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f); - const __m512 one_255 = _mm512_set1_ps(1.f / 63.f); - return _mm512_fmadd_ps(f8, one_255, half_one_255); - - // clang-format on - } - -#elif defined(__AVX2__) - - /* Load 6 bytes that represent 8 6-bit values, return them as a - * 8*32 bit vector register */ - static FAISS_ALWAYS_INLINE __m256i load6(const uint16_t* code16) { - const __m128i perm = _mm_set_epi8( - -1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0); - const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0); - - // load 6 bytes - __m128i c1 = - _mm_set_epi16(0, 0, 0, 0, 0, code16[2], code16[1], code16[0]); - - // put in 8 * 32 bits - __m128i c2 = _mm_shuffle_epi8(c1, perm); - __m256i c3 = _mm256_cvtepi16_epi32(c2); - - // shift and mask out useless bits - __m256i c4 = _mm256_srlv_epi32(c3, shifts); - __m256i c5 = _mm256_and_si256(_mm256_set1_epi32(63), c4); - return c5; - } - - static FAISS_ALWAYS_INLINE __m256 - decode_8_components(const uint8_t* code, int i) { - // // Faster code for Intel CPUs or AMD Zen3+, just keeping it here - // // for the reference, maybe, it becomes used oned day. - // const uint16_t* data16 = (const uint16_t*)(code + (i >> 2) * 3); - // const uint32_t* data32 = (const uint32_t*)data16; - // const uint64_t val = *data32 + ((uint64_t)data16[2] << 32); - // const uint64_t vext = _pdep_u64(val, 0x3F3F3F3F3F3F3F3FULL); - // const __m128i i8 = _mm_set1_epi64x(vext); - // const __m256i i32 = _mm256_cvtepi8_epi32(i8); - // const __m256 f8 = _mm256_cvtepi32_ps(i32); - // const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f); - // const __m256 one_255 = _mm256_set1_ps(1.f / 63.f); - // return _mm256_fmadd_ps(f8, one_255, half_one_255); - - __m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3)); - __m256 f8 = _mm256_cvtepi32_ps(i8); - // this could also be done with bit manipulations but it is - // not obviously faster - const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f); - const __m256 one_255 = _mm256_set1_ps(1.f / 63.f); - return _mm256_fmadd_ps(f8, one_255, half_one_255); - } - -#endif - -#ifdef USE_NEON - static FAISS_ALWAYS_INLINE float32x4x2_t - decode_8_components(const uint8_t* code, int i) { - float32_t result[8] = {}; - for (size_t j = 0; j < 8; j++) { - result[j] = decode_component(code, i + j); - } - float32x4_t res1 = vld1q_f32(result); - float32x4_t res2 = vld1q_f32(result + 4); - return {res1, res2}; - } -#endif -}; - -/******************************************************************* - * Quantizer: normalizes scalar vector components, then passes them - * through a codec - *******************************************************************/ - -enum class QuantizerTemplateScaling { UNIFORM = 0, NON_UNIFORM = 1 }; - -template -struct QuantizerTemplate {}; - -template -struct QuantizerTemplate - : ScalarQuantizer::SQuantizer { - const size_t d; - const float vmin, vdiff; - - QuantizerTemplate(size_t d, const std::vector& trained) - : d(d), vmin(trained[0]), vdiff(trained[1]) {} - - void encode_vector(const float* x, uint8_t* code) const final { - for (size_t i = 0; i < d; i++) { - float xi = 0; - if (vdiff != 0) { - xi = (x[i] - vmin) / vdiff; - if (xi < 0) { - xi = 0; - } - if (xi > 1.0) { - xi = 1.0; - } - } - Codec::encode_component(xi, code, i); - } - } - - void decode_vector(const uint8_t* code, float* x) const final { - for (size_t i = 0; i < d; i++) { - float xi = Codec::decode_component(code, i); - x[i] = vmin + xi * vdiff; - } - } - - FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) - const { - float xi = Codec::decode_component(code, i); - return vmin + xi * vdiff; - } -}; - -#if defined(__AVX512F__) - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate( - d, - trained) {} - - FAISS_ALWAYS_INLINE __m512 - reconstruct_16_components(const uint8_t* code, int i) const { - __m512 xi = Codec::decode_16_components(code, i); - return _mm512_fmadd_ps( - xi, _mm512_set1_ps(this->vdiff), _mm512_set1_ps(this->vmin)); - } -}; - -#elif defined(__AVX2__) - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate( - d, - trained) {} - - FAISS_ALWAYS_INLINE __m256 - reconstruct_8_components(const uint8_t* code, int i) const { - __m256 xi = Codec::decode_8_components(code, i); - return _mm256_fmadd_ps( - xi, _mm256_set1_ps(this->vdiff), _mm256_set1_ps(this->vmin)); - } -}; - -#endif - -#ifdef USE_NEON - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate( - d, - trained) {} - - FAISS_ALWAYS_INLINE float32x4x2_t - reconstruct_8_components(const uint8_t* code, int i) const { - float32x4x2_t xi = Codec::decode_8_components(code, i); - return {vfmaq_f32( - vdupq_n_f32(this->vmin), - xi.val[0], - vdupq_n_f32(this->vdiff)), - vfmaq_f32( - vdupq_n_f32(this->vmin), - xi.val[1], - vdupq_n_f32(this->vdiff))}; - } -}; - -#endif - -template -struct QuantizerTemplate - : ScalarQuantizer::SQuantizer { - const size_t d; - const float *vmin, *vdiff; - - QuantizerTemplate(size_t d, const std::vector& trained) - : d(d), vmin(trained.data()), vdiff(trained.data() + d) {} - - void encode_vector(const float* x, uint8_t* code) const final { - for (size_t i = 0; i < d; i++) { - float xi = 0; - if (vdiff[i] != 0) { - xi = (x[i] - vmin[i]) / vdiff[i]; - if (xi < 0) { - xi = 0; - } - if (xi > 1.0) { - xi = 1.0; - } - } - Codec::encode_component(xi, code, i); - } - } - - void decode_vector(const uint8_t* code, float* x) const final { - for (size_t i = 0; i < d; i++) { - float xi = Codec::decode_component(code, i); - x[i] = vmin[i] + xi * vdiff[i]; - } - } - - FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) - const { - float xi = Codec::decode_component(code, i); - return vmin[i] + xi * vdiff[i]; - } -}; - -#if defined(__AVX512F__) - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate< - Codec, - QuantizerTemplateScaling::NON_UNIFORM, - 1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m512 - reconstruct_16_components(const uint8_t* code, int i) const { - __m512 xi = Codec::decode_16_components(code, i); - return _mm512_fmadd_ps( - xi, - _mm512_loadu_ps(this->vdiff + i), - _mm512_loadu_ps(this->vmin + i)); - } -}; - -#elif defined(__AVX2__) - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate< - Codec, - QuantizerTemplateScaling::NON_UNIFORM, - 1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m256 - reconstruct_8_components(const uint8_t* code, int i) const { - __m256 xi = Codec::decode_8_components(code, i); - return _mm256_fmadd_ps( - xi, - _mm256_loadu_ps(this->vdiff + i), - _mm256_loadu_ps(this->vmin + i)); - } -}; - -#endif - -#ifdef USE_NEON - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate< - Codec, - QuantizerTemplateScaling::NON_UNIFORM, - 1>(d, trained) {} - - FAISS_ALWAYS_INLINE float32x4x2_t - reconstruct_8_components(const uint8_t* code, int i) const { - float32x4x2_t xi = Codec::decode_8_components(code, i); - - float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i); - float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i); - - return {vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]), - vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1])}; - } -}; - -#endif - -/******************************************************************* - * FP16 quantizer - *******************************************************************/ - -template -struct QuantizerFP16 {}; - -template <> -struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer { - const size_t d; - - QuantizerFP16(size_t d, const std::vector& /* unused */) : d(d) {} - - void encode_vector(const float* x, uint8_t* code) const final { - for (size_t i = 0; i < d; i++) { - ((uint16_t*)code)[i] = encode_fp16(x[i]); - } - } - - void decode_vector(const uint8_t* code, float* x) const final { - for (size_t i = 0; i < d; i++) { - x[i] = decode_fp16(((uint16_t*)code)[i]); - } - } - - FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) - const { - return decode_fp16(((uint16_t*)code)[i]); - } -}; - -#if defined(USE_AVX512_F16C) - -template <> -struct QuantizerFP16<16> : QuantizerFP16<1> { - QuantizerFP16(size_t d, const std::vector& trained) - : QuantizerFP16<1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m512 - reconstruct_16_components(const uint8_t* code, int i) const { - __m256i codei = _mm256_loadu_si256((const __m256i*)(code + 2 * i)); - return _mm512_cvtph_ps(codei); - } -}; - -#endif - -#if defined(USE_F16C) - -template <> -struct QuantizerFP16<8> : QuantizerFP16<1> { - QuantizerFP16(size_t d, const std::vector& trained) - : QuantizerFP16<1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m256 - reconstruct_8_components(const uint8_t* code, int i) const { - __m128i codei = _mm_loadu_si128((const __m128i*)(code + 2 * i)); - return _mm256_cvtph_ps(codei); - } -}; - -#endif - -#ifdef USE_NEON - -template <> -struct QuantizerFP16<8> : QuantizerFP16<1> { - QuantizerFP16(size_t d, const std::vector& trained) - : QuantizerFP16<1>(d, trained) {} - - FAISS_ALWAYS_INLINE float32x4x2_t - reconstruct_8_components(const uint8_t* code, int i) const { - uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i)); - return {vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])), - vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1]))}; - } -}; -#endif - -/******************************************************************* - * BF16 quantizer - *******************************************************************/ - -template -struct QuantizerBF16 {}; - -template <> -struct QuantizerBF16<1> : ScalarQuantizer::SQuantizer { - const size_t d; - - QuantizerBF16(size_t d, const std::vector& /* unused */) : d(d) {} - - void encode_vector(const float* x, uint8_t* code) const final { - for (size_t i = 0; i < d; i++) { - ((uint16_t*)code)[i] = encode_bf16(x[i]); - } - } - - void decode_vector(const uint8_t* code, float* x) const final { - for (size_t i = 0; i < d; i++) { - x[i] = decode_bf16(((uint16_t*)code)[i]); - } - } - - FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) - const { - return decode_bf16(((uint16_t*)code)[i]); - } -}; - -#if defined(__AVX512F__) - -template <> -struct QuantizerBF16<16> : QuantizerBF16<1> { - QuantizerBF16(size_t d, const std::vector& trained) - : QuantizerBF16<1>(d, trained) {} - FAISS_ALWAYS_INLINE __m512 - reconstruct_16_components(const uint8_t* code, int i) const { - __m256i code_256i = _mm256_loadu_si256((const __m256i*)(code + 2 * i)); - __m512i code_512i = _mm512_cvtepu16_epi32(code_256i); - code_512i = _mm512_slli_epi32(code_512i, 16); - return _mm512_castsi512_ps(code_512i); - } -}; - -#elif defined(__AVX2__) - -template <> -struct QuantizerBF16<8> : QuantizerBF16<1> { - QuantizerBF16(size_t d, const std::vector& trained) - : QuantizerBF16<1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m256 - reconstruct_8_components(const uint8_t* code, int i) const { - __m128i code_128i = _mm_loadu_si128((const __m128i*)(code + 2 * i)); - __m256i code_256i = _mm256_cvtepu16_epi32(code_128i); - code_256i = _mm256_slli_epi32(code_256i, 16); - return _mm256_castsi256_ps(code_256i); - } -}; - -#endif - -#ifdef USE_NEON - -template <> -struct QuantizerBF16<8> : QuantizerBF16<1> { - QuantizerBF16(size_t d, const std::vector& trained) - : QuantizerBF16<1>(d, trained) {} - - FAISS_ALWAYS_INLINE float32x4x2_t - reconstruct_8_components(const uint8_t* code, int i) const { - uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i)); - return {vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(codei.val[0]), 16)), - vreinterpretq_f32_u32( - vshlq_n_u32(vmovl_u16(codei.val[1]), 16))}; - } -}; -#endif - -/******************************************************************* - * 8bit_direct quantizer - *******************************************************************/ - -template -struct Quantizer8bitDirect {}; - -template <> -struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer { - const size_t d; - - Quantizer8bitDirect(size_t d, const std::vector& /* unused */) - : d(d) {} - - void encode_vector(const float* x, uint8_t* code) const final { - for (size_t i = 0; i < d; i++) { - code[i] = (uint8_t)x[i]; - } - } - - void decode_vector(const uint8_t* code, float* x) const final { - for (size_t i = 0; i < d; i++) { - x[i] = code[i]; - } - } - - FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) - const { - return code[i]; - } -}; - -#if defined(__AVX512F__) - -template <> -struct Quantizer8bitDirect<16> : Quantizer8bitDirect<1> { - Quantizer8bitDirect(size_t d, const std::vector& trained) - : Quantizer8bitDirect<1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m512 - reconstruct_16_components(const uint8_t* code, int i) const { - __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8 - __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32 - return _mm512_cvtepi32_ps(y16); // 16 * float32 - } -}; - -#elif defined(__AVX2__) - -template <> -struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> { - Quantizer8bitDirect(size_t d, const std::vector& trained) - : Quantizer8bitDirect<1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m256 - reconstruct_8_components(const uint8_t* code, int i) const { - __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8 - __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32 - return _mm256_cvtepi32_ps(y8); // 8 * float32 - } -}; - -#endif - -#ifdef USE_NEON - -template <> -struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> { - Quantizer8bitDirect(size_t d, const std::vector& trained) - : Quantizer8bitDirect<1>(d, trained) {} - - FAISS_ALWAYS_INLINE float32x4x2_t - reconstruct_8_components(const uint8_t* code, int i) const { - uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i)); - uint16x8_t y8 = vmovl_u8(x8); - uint16x4_t y8_0 = vget_low_u16(y8); - uint16x4_t y8_1 = vget_high_u16(y8); - - // convert uint16 -> uint32 -> fp32 - return {vcvtq_f32_u32(vmovl_u16(y8_0)), vcvtq_f32_u32(vmovl_u16(y8_1))}; - } -}; - -#endif - -/******************************************************************* - * 8bit_direct_signed quantizer - *******************************************************************/ - -template -struct Quantizer8bitDirectSigned {}; - -template <> -struct Quantizer8bitDirectSigned<1> : ScalarQuantizer::SQuantizer { - const size_t d; - - Quantizer8bitDirectSigned(size_t d, const std::vector& /* unused */) - : d(d) {} - - void encode_vector(const float* x, uint8_t* code) const final { - for (size_t i = 0; i < d; i++) { - code[i] = (uint8_t)(x[i] + 128); - } - } - - void decode_vector(const uint8_t* code, float* x) const final { - for (size_t i = 0; i < d; i++) { - x[i] = code[i] - 128; - } - } - - FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) - const { - return code[i] - 128; - } -}; - -#if defined(__AVX512F__) - -template <> -struct Quantizer8bitDirectSigned<16> : Quantizer8bitDirectSigned<1> { - Quantizer8bitDirectSigned(size_t d, const std::vector& trained) - : Quantizer8bitDirectSigned<1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m512 - reconstruct_16_components(const uint8_t* code, int i) const { - __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8 - __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32 - __m512i c16 = _mm512_set1_epi32(128); - __m512i z16 = _mm512_sub_epi32(y16, c16); // subtract 128 from all lanes - return _mm512_cvtepi32_ps(z16); // 16 * float32 - } -}; - -#elif defined(__AVX2__) - -template <> -struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> { - Quantizer8bitDirectSigned(size_t d, const std::vector& trained) - : Quantizer8bitDirectSigned<1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m256 - reconstruct_8_components(const uint8_t* code, int i) const { - __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8 - __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32 - __m256i c8 = _mm256_set1_epi32(128); - __m256i z8 = _mm256_sub_epi32(y8, c8); // subtract 128 from all lanes - return _mm256_cvtepi32_ps(z8); // 8 * float32 - } -}; - -#endif - -#ifdef USE_NEON - -template <> -struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> { - Quantizer8bitDirectSigned(size_t d, const std::vector& trained) - : Quantizer8bitDirectSigned<1>(d, trained) {} - - FAISS_ALWAYS_INLINE float32x4x2_t - reconstruct_8_components(const uint8_t* code, int i) const { - uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i)); - uint16x8_t y8 = vmovl_u8(x8); // convert uint8 -> uint16 - uint16x4_t y8_0 = vget_low_u16(y8); - uint16x4_t y8_1 = vget_high_u16(y8); - - float32x4_t z8_0 = vcvtq_f32_u32( - vmovl_u16(y8_0)); // convert uint16 -> uint32 -> fp32 - float32x4_t z8_1 = vcvtq_f32_u32(vmovl_u16(y8_1)); - - // subtract 128 to convert into signed numbers - return {vsubq_f32(z8_0, vmovq_n_f32(128.0)), - vsubq_f32(z8_1, vmovq_n_f32(128.0))}; - } -}; - -#endif - -template -ScalarQuantizer::SQuantizer* select_quantizer_1( - QuantizerType qtype, - size_t d, - const std::vector& trained) { - switch (qtype) { - case ScalarQuantizer::QT_8bit: - return new QuantizerTemplate< - Codec8bit, - QuantizerTemplateScaling::NON_UNIFORM, - SIMDWIDTH>(d, trained); - case ScalarQuantizer::QT_6bit: - return new QuantizerTemplate< - Codec6bit, - QuantizerTemplateScaling::NON_UNIFORM, - SIMDWIDTH>(d, trained); - case ScalarQuantizer::QT_4bit: - return new QuantizerTemplate< - Codec4bit, - QuantizerTemplateScaling::NON_UNIFORM, - SIMDWIDTH>(d, trained); - case ScalarQuantizer::QT_8bit_uniform: - return new QuantizerTemplate< - Codec8bit, - QuantizerTemplateScaling::UNIFORM, - SIMDWIDTH>(d, trained); - case ScalarQuantizer::QT_4bit_uniform: - return new QuantizerTemplate< - Codec4bit, - QuantizerTemplateScaling::UNIFORM, - SIMDWIDTH>(d, trained); - case ScalarQuantizer::QT_fp16: - return new QuantizerFP16(d, trained); - case ScalarQuantizer::QT_bf16: - return new QuantizerBF16(d, trained); - case ScalarQuantizer::QT_8bit_direct: - return new Quantizer8bitDirect(d, trained); - case ScalarQuantizer::QT_8bit_direct_signed: - return new Quantizer8bitDirectSigned(d, trained); - } - FAISS_THROW_MSG("unknown qtype"); -} - -/******************************************************************* - * Quantizer range training - */ - -static float sqr(float x) { - return x * x; -} - -void train_Uniform( - RangeStat rs, - float rs_arg, - idx_t n, - int k, - const float* x, - std::vector& trained) { - trained.resize(2); - float& vmin = trained[0]; - float& vmax = trained[1]; - - if (rs == ScalarQuantizer::RS_minmax) { - vmin = HUGE_VAL; - vmax = -HUGE_VAL; - for (size_t i = 0; i < n; i++) { - if (x[i] < vmin) - vmin = x[i]; - if (x[i] > vmax) - vmax = x[i]; - } - float vexp = (vmax - vmin) * rs_arg; - vmin -= vexp; - vmax += vexp; - } else if (rs == ScalarQuantizer::RS_meanstd) { - double sum = 0, sum2 = 0; - for (size_t i = 0; i < n; i++) { - sum += x[i]; - sum2 += x[i] * x[i]; - } - float mean = sum / n; - float var = sum2 / n - mean * mean; - float std = var <= 0 ? 1.0 : sqrt(var); - - vmin = mean - std * rs_arg; - vmax = mean + std * rs_arg; - } else if (rs == ScalarQuantizer::RS_quantiles) { - std::vector x_copy(n); - memcpy(x_copy.data(), x, n * sizeof(*x)); - // TODO just do a quickselect - std::sort(x_copy.begin(), x_copy.end()); - int o = int(rs_arg * n); - if (o < 0) - o = 0; - if (o > n - o) - o = n / 2; - vmin = x_copy[o]; - vmax = x_copy[n - 1 - o]; - - } else if (rs == ScalarQuantizer::RS_optim) { - float a, b; - float sx = 0; - { - vmin = HUGE_VAL, vmax = -HUGE_VAL; - for (size_t i = 0; i < n; i++) { - if (x[i] < vmin) - vmin = x[i]; - if (x[i] > vmax) - vmax = x[i]; - sx += x[i]; - } - b = vmin; - a = (vmax - vmin) / (k - 1); - } - int verbose = false; - int niter = 2000; - float last_err = -1; - int iter_last_err = 0; - for (int it = 0; it < niter; it++) { - float sn = 0, sn2 = 0, sxn = 0, err1 = 0; - - for (idx_t i = 0; i < n; i++) { - float xi = x[i]; - float ni = floor((xi - b) / a + 0.5); - if (ni < 0) - ni = 0; - if (ni >= k) - ni = k - 1; - err1 += sqr(xi - (ni * a + b)); - sn += ni; - sn2 += ni * ni; - sxn += ni * xi; - } - - if (err1 == last_err) { - iter_last_err++; - if (iter_last_err == 16) - break; - } else { - last_err = err1; - iter_last_err = 0; - } - - float det = sqr(sn) - sn2 * n; - - b = (sn * sxn - sn2 * sx) / det; - a = (sn * sx - n * sxn) / det; - if (verbose) { - printf("it %d, err1=%g \r", it, err1); - fflush(stdout); - } - } - if (verbose) - printf("\n"); - - vmin = b; - vmax = b + a * (k - 1); - - } else { - FAISS_THROW_MSG("Invalid qtype"); - } - vmax -= vmin; -} - -void train_NonUniform( - RangeStat rs, - float rs_arg, - idx_t n, - int d, - int k, - const float* x, - std::vector& trained) { - trained.resize(2 * d); - float* vmin = trained.data(); - float* vmax = trained.data() + d; - if (rs == ScalarQuantizer::RS_minmax) { - memcpy(vmin, x, sizeof(*x) * d); - memcpy(vmax, x, sizeof(*x) * d); - for (size_t i = 1; i < n; i++) { - const float* xi = x + i * d; - for (size_t j = 0; j < d; j++) { - if (xi[j] < vmin[j]) - vmin[j] = xi[j]; - if (xi[j] > vmax[j]) - vmax[j] = xi[j]; - } - } - float* vdiff = vmax; - for (size_t j = 0; j < d; j++) { - float vexp = (vmax[j] - vmin[j]) * rs_arg; - vmin[j] -= vexp; - vmax[j] += vexp; - vdiff[j] = vmax[j] - vmin[j]; - } - } else { - // transpose - std::vector xt(n * d); - for (size_t i = 1; i < n; i++) { - const float* xi = x + i * d; - for (size_t j = 0; j < d; j++) { - xt[j * n + i] = xi[j]; - } - } - std::vector trained_d(2); -#pragma omp parallel for - for (int j = 0; j < d; j++) { - train_Uniform(rs, rs_arg, n, k, xt.data() + j * n, trained_d); - vmin[j] = trained_d[0]; - vmax[j] = trained_d[1]; - } - } -} - -/******************************************************************* - * Similarity: gets vector components and computes a similarity wrt. a - * query vector stored in the object. The data fields just encapsulate - * an accumulator. - */ - -template -struct SimilarityL2 {}; - -template <> -struct SimilarityL2<1> { - static constexpr int simdwidth = 1; - static constexpr MetricType metric_type = METRIC_L2; - - const float *y, *yi; - - explicit SimilarityL2(const float* y) : y(y) {} - - /******* scalar accumulator *******/ - - float accu; - - FAISS_ALWAYS_INLINE void begin() { - accu = 0; - yi = y; - } - - FAISS_ALWAYS_INLINE void add_component(float x) { - float tmp = *yi++ - x; - accu += tmp * tmp; - } - - FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) { - float tmp = x1 - x2; - accu += tmp * tmp; - } - - FAISS_ALWAYS_INLINE float result() { - return accu; - } -}; - -#if defined(__AVX512F__) - -template <> -struct SimilarityL2<16> { - static constexpr int simdwidth = 16; - static constexpr MetricType metric_type = METRIC_L2; - - const float *y, *yi; - - explicit SimilarityL2(const float* y) : y(y) {} - __m512 accu16; - - FAISS_ALWAYS_INLINE void begin_16() { - accu16 = _mm512_setzero_ps(); - yi = y; - } - - FAISS_ALWAYS_INLINE void add_16_components(__m512 x) { - __m512 yiv = _mm512_loadu_ps(yi); - yi += 16; - __m512 tmp = _mm512_sub_ps(yiv, x); - accu16 = _mm512_fmadd_ps(tmp, tmp, accu16); - } - - FAISS_ALWAYS_INLINE void add_16_components_2(__m512 x, __m512 y_2) { - __m512 tmp = _mm512_sub_ps(y_2, x); - accu16 = _mm512_fmadd_ps(tmp, tmp, accu16); - } - - FAISS_ALWAYS_INLINE float result_16() { - // performs better than dividing into _mm256 and adding - return _mm512_reduce_add_ps(accu16); - } -}; - -#elif defined(__AVX2__) - -template <> -struct SimilarityL2<8> { - static constexpr int simdwidth = 8; - static constexpr MetricType metric_type = METRIC_L2; - - const float *y, *yi; - - explicit SimilarityL2(const float* y) : y(y) {} - __m256 accu8; - - FAISS_ALWAYS_INLINE void begin_8() { - accu8 = _mm256_setzero_ps(); - yi = y; - } - - FAISS_ALWAYS_INLINE void add_8_components(__m256 x) { - __m256 yiv = _mm256_loadu_ps(yi); - yi += 8; - __m256 tmp = _mm256_sub_ps(yiv, x); - accu8 = _mm256_fmadd_ps(tmp, tmp, accu8); - } - - FAISS_ALWAYS_INLINE void add_8_components_2(__m256 x, __m256 y_2) { - __m256 tmp = _mm256_sub_ps(y_2, x); - accu8 = _mm256_fmadd_ps(tmp, tmp, accu8); - } - - FAISS_ALWAYS_INLINE float result_8() { - const __m128 sum = _mm_add_ps( - _mm256_castps256_ps128(accu8), _mm256_extractf128_ps(accu8, 1)); - const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)); - const __m128 v1 = _mm_add_ps(sum, v0); - __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); - const __m128 v3 = _mm_add_ps(v1, v2); - return _mm_cvtss_f32(v3); - } -}; - -#endif - -#ifdef USE_NEON -template <> -struct SimilarityL2<8> { - static constexpr int simdwidth = 8; - static constexpr MetricType metric_type = METRIC_L2; - - const float *y, *yi; - explicit SimilarityL2(const float* y) : y(y) {} - float32x4x2_t accu8; - - FAISS_ALWAYS_INLINE void begin_8() { - accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; - yi = y; - } - - FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) { - float32x4x2_t yiv = vld1q_f32_x2(yi); - yi += 8; - - float32x4_t sub0 = vsubq_f32(yiv.val[0], x.val[0]); - float32x4_t sub1 = vsubq_f32(yiv.val[1], x.val[1]); - - float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0); - float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1); - - accu8 = {accu8_0, accu8_1}; - } - - FAISS_ALWAYS_INLINE void add_8_components_2( - float32x4x2_t x, - float32x4x2_t y) { - float32x4_t sub0 = vsubq_f32(y.val[0], x.val[0]); - float32x4_t sub1 = vsubq_f32(y.val[1], x.val[1]); - - float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0); - float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1); - - accu8 = {accu8_0, accu8_1}; - } - - FAISS_ALWAYS_INLINE float result_8() { - float32x4_t sum_0 = vpaddq_f32(accu8.val[0], accu8.val[0]); - float32x4_t sum_1 = vpaddq_f32(accu8.val[1], accu8.val[1]); - - float32x4_t sum2_0 = vpaddq_f32(sum_0, sum_0); - float32x4_t sum2_1 = vpaddq_f32(sum_1, sum_1); - return vgetq_lane_f32(sum2_0, 0) + vgetq_lane_f32(sum2_1, 0); - } -}; -#endif - -template -struct SimilarityIP {}; - -template <> -struct SimilarityIP<1> { - static constexpr int simdwidth = 1; - static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; - const float *y, *yi; - - float accu; - - explicit SimilarityIP(const float* y) : y(y) {} - - FAISS_ALWAYS_INLINE void begin() { - accu = 0; - yi = y; - } - - FAISS_ALWAYS_INLINE void add_component(float x) { - accu += *yi++ * x; - } - - FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) { - accu += x1 * x2; - } - - FAISS_ALWAYS_INLINE float result() { - return accu; - } -}; - -#if defined(__AVX512F__) - -template <> -struct SimilarityIP<16> { - static constexpr int simdwidth = 16; - static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; - - const float *y, *yi; - - float accu; - - explicit SimilarityIP(const float* y) : y(y) {} - - __m512 accu16; - - FAISS_ALWAYS_INLINE void begin_16() { - accu16 = _mm512_setzero_ps(); - yi = y; - } - - FAISS_ALWAYS_INLINE void add_16_components(__m512 x) { - __m512 yiv = _mm512_loadu_ps(yi); - yi += 16; - accu16 = _mm512_fmadd_ps(yiv, x, accu16); - } - - FAISS_ALWAYS_INLINE void add_16_components_2(__m512 x1, __m512 x2) { - accu16 = _mm512_fmadd_ps(x1, x2, accu16); - } - - FAISS_ALWAYS_INLINE float result_16() { - // performs better than dividing into _mm256 and adding - return _mm512_reduce_add_ps(accu16); - } -}; - -#elif defined(__AVX2__) - -template <> -struct SimilarityIP<8> { - static constexpr int simdwidth = 8; - static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; - - const float *y, *yi; - - float accu; - - explicit SimilarityIP(const float* y) : y(y) {} - - __m256 accu8; - - FAISS_ALWAYS_INLINE void begin_8() { - accu8 = _mm256_setzero_ps(); - yi = y; - } - - FAISS_ALWAYS_INLINE void add_8_components(__m256 x) { - __m256 yiv = _mm256_loadu_ps(yi); - yi += 8; - accu8 = _mm256_fmadd_ps(yiv, x, accu8); - } - - FAISS_ALWAYS_INLINE void add_8_components_2(__m256 x1, __m256 x2) { - accu8 = _mm256_fmadd_ps(x1, x2, accu8); - } - - FAISS_ALWAYS_INLINE float result_8() { - const __m128 sum = _mm_add_ps( - _mm256_castps256_ps128(accu8), _mm256_extractf128_ps(accu8, 1)); - const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)); - const __m128 v1 = _mm_add_ps(sum, v0); - __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); - const __m128 v3 = _mm_add_ps(v1, v2); - return _mm_cvtss_f32(v3); - } -}; -#endif - -#ifdef USE_NEON - -template <> -struct SimilarityIP<8> { - static constexpr int simdwidth = 8; - static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; +#include - const float *y, *yi; +#include - explicit SimilarityIP(const float* y) : y(y) {} - float32x4x2_t accu8; +#include +#include - FAISS_ALWAYS_INLINE void begin_8() { - accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; - yi = y; - } +#include - FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) { - float32x4x2_t yiv = vld1q_f32_x2(yi); - yi += 8; +#include - float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]); - float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]); - accu8 = {accu8_0, accu8_1}; - } +#include +#include +#include +#include +#include - FAISS_ALWAYS_INLINE void add_8_components_2( - float32x4x2_t x1, - float32x4x2_t x2) { - float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]); - float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]); - accu8 = {accu8_0, accu8_1}; - } +/******************************************************************* + * ScalarQuantizer implementation + * + * The main source of complexity is to support combinations of 4 + * variants without incurring runtime tests or virtual function calls: + * + * - 4 / 6 / 8 bits per code component + * - uniform / non-uniform + * - IP / L2 distance search + * - scalar / SIMD distance computation + * + * The appropriate Quantizer object is returned via select_quantizer + * that hides the template mess. + ********************************************************************/ - FAISS_ALWAYS_INLINE float result_8() { - float32x4x2_t sum = { - vpaddq_f32(accu8.val[0], accu8.val[0]), - vpaddq_f32(accu8.val[1], accu8.val[1])}; +/******************************************************************* + * Codec: converts between values in [0, 1] and an index in a code + * array. The "i" parameter is the vector component index (not byte + * index). + */ - float32x4x2_t sum2 = { - vpaddq_f32(sum.val[0], sum.val[0]), - vpaddq_f32(sum.val[1], sum.val[1])}; - return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0); - } -}; -#endif +#include /******************************************************************* - * DistanceComputer: combines a similarity and a quantizer to do - * code-to-vector or code-to-code comparisons + * Quantizer: normalizes scalar vector components, then passes them + * through a codec *******************************************************************/ -template -struct DCTemplate : SQDistanceComputer {}; - -template -struct DCTemplate : SQDistanceComputer { - using Sim = Similarity; - - Quantizer quant; - - DCTemplate(size_t d, const std::vector& trained) - : quant(d, trained) {} - - float compute_distance(const float* x, const uint8_t* code) const { - Similarity sim(x); - sim.begin(); - for (size_t i = 0; i < quant.d; i++) { - float xi = quant.reconstruct_component(code, i); - sim.add_component(xi); - } - return sim.result(); - } - - float compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - Similarity sim(nullptr); - sim.begin(); - for (size_t i = 0; i < quant.d; i++) { - float x1 = quant.reconstruct_component(code1, i); - float x2 = quant.reconstruct_component(code2, i); - sim.add_component_2(x1, x2); - } - return sim.result(); - } - - void set_query(const float* x) final { - q = x; - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_distance(q, code); - } -}; - -#if defined(USE_AVX512_F16C) - -template -struct DCTemplate - : SQDistanceComputer { // Update to handle 16 lanes - using Sim = Similarity; - - Quantizer quant; - - DCTemplate(size_t d, const std::vector& trained) - : quant(d, trained) {} - - float compute_distance(const float* x, const uint8_t* code) const { - Similarity sim(x); - sim.begin_16(); - for (size_t i = 0; i < quant.d; i += 16) { - __m512 xi = quant.reconstruct_16_components(code, i); - sim.add_16_components(xi); - } - return sim.result_16(); - } - - float compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - Similarity sim(nullptr); - sim.begin_16(); - for (size_t i = 0; i < quant.d; i += 16) { - __m512 x1 = quant.reconstruct_16_components(code1, i); - __m512 x2 = quant.reconstruct_16_components(code2, i); - sim.add_16_components_2(x1, x2); - } - return sim.result_16(); - } - - void set_query(const float* x) final { - q = x; - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_distance(q, code); - } -}; - -#elif defined(USE_F16C) - -template -struct DCTemplate : SQDistanceComputer { - using Sim = Similarity; - - Quantizer quant; - - DCTemplate(size_t d, const std::vector& trained) - : quant(d, trained) {} - - float compute_distance(const float* x, const uint8_t* code) const { - Similarity sim(x); - sim.begin_8(); - for (size_t i = 0; i < quant.d; i += 8) { - __m256 xi = quant.reconstruct_8_components(code, i); - sim.add_8_components(xi); - } - return sim.result_8(); - } - - float compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - Similarity sim(nullptr); - sim.begin_8(); - for (size_t i = 0; i < quant.d; i += 8) { - __m256 x1 = quant.reconstruct_8_components(code1, i); - __m256 x2 = quant.reconstruct_8_components(code2, i); - sim.add_8_components_2(x1, x2); - } - return sim.result_8(); - } - - void set_query(const float* x) final { - q = x; - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_distance(q, code); - } -}; - -#endif - -#ifdef USE_NEON - -template -struct DCTemplate : SQDistanceComputer { - using Sim = Similarity; - - Quantizer quant; - - DCTemplate(size_t d, const std::vector& trained) - : quant(d, trained) {} - float compute_distance(const float* x, const uint8_t* code) const { - Similarity sim(x); - sim.begin_8(); - for (size_t i = 0; i < quant.d; i += 8) { - float32x4x2_t xi = quant.reconstruct_8_components(code, i); - sim.add_8_components(xi); - } - return sim.result_8(); - } - - float compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - Similarity sim(nullptr); - sim.begin_8(); - for (size_t i = 0; i < quant.d; i += 8) { - float32x4x2_t x1 = quant.reconstruct_8_components(code1, i); - float32x4x2_t x2 = quant.reconstruct_8_components(code2, i); - sim.add_8_components_2(x1, x2); - } - return sim.result_8(); - } - - void set_query(const float* x) final { - q = x; - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_distance(q, code); - } -}; -#endif +#include /******************************************************************* - * DistanceComputerByte: computes distances in the integer domain + * Similarity: gets vector components and computes a similarity wrt. a + * query vector stored in the object. The data fields just encapsulate + * an accumulator. + * DistanceComputer: combines a similarity and a quantizer to do + * code-to-vector or code-to-code comparisons *******************************************************************/ -template -struct DistanceComputerByte : SQDistanceComputer {}; - -template -struct DistanceComputerByte : SQDistanceComputer { - using Sim = Similarity; - - int d; - std::vector tmp; - - DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} - - int compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - int accu = 0; - for (int i = 0; i < d; i++) { - if (Sim::metric_type == METRIC_INNER_PRODUCT) { - accu += int(code1[i]) * code2[i]; - } else { - int diff = int(code1[i]) - code2[i]; - accu += diff * diff; - } - } - return accu; - } - - void set_query(const float* x) final { - for (int i = 0; i < d; i++) { - tmp[i] = int(x[i]); - } - } - - int compute_distance(const float* x, const uint8_t* code) { - set_query(x); - return compute_code_distance(tmp.data(), code); - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_code_distance(tmp.data(), code); - } -}; - -#if defined(__AVX512F__) - -template -struct DistanceComputerByte : SQDistanceComputer { - using Sim = Similarity; - - int d; - std::vector tmp; - - DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} - - int compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - __m512i accu = _mm512_setzero_si512(); - for (int i = 0; i < d; i += 32) { // Process 32 bytes at a time - __m512i c1 = _mm512_cvtepu8_epi16( - _mm256_loadu_si256((__m256i*)(code1 + i))); - __m512i c2 = _mm512_cvtepu8_epi16( - _mm256_loadu_si256((__m256i*)(code2 + i))); - __m512i prod32; - if (Sim::metric_type == METRIC_INNER_PRODUCT) { - prod32 = _mm512_madd_epi16(c1, c2); - } else { - __m512i diff = _mm512_sub_epi16(c1, c2); - prod32 = _mm512_madd_epi16(diff, diff); - } - accu = _mm512_add_epi32(accu, prod32); - } - // Horizontally add elements of accu - return _mm512_reduce_add_epi32(accu); - } - - void set_query(const float* x) final { - for (int i = 0; i < d; i++) { - tmp[i] = int(x[i]); - } - } - - int compute_distance(const float* x, const uint8_t* code) { - set_query(x); - return compute_code_distance(tmp.data(), code); - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_code_distance(tmp.data(), code); - } -}; - -#elif defined(__AVX2__) - -template -struct DistanceComputerByte : SQDistanceComputer { - using Sim = Similarity; - - int d; - std::vector tmp; - - DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} - - int compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - // __m256i accu = _mm256_setzero_ps (); - __m256i accu = _mm256_setzero_si256(); - for (int i = 0; i < d; i += 16) { - // load 16 bytes, convert to 16 uint16_t - __m256i c1 = _mm256_cvtepu8_epi16( - _mm_loadu_si128((__m128i*)(code1 + i))); - __m256i c2 = _mm256_cvtepu8_epi16( - _mm_loadu_si128((__m128i*)(code2 + i))); - __m256i prod32; - if (Sim::metric_type == METRIC_INNER_PRODUCT) { - prod32 = _mm256_madd_epi16(c1, c2); - } else { - __m256i diff = _mm256_sub_epi16(c1, c2); - prod32 = _mm256_madd_epi16(diff, diff); - } - accu = _mm256_add_epi32(accu, prod32); - } - __m128i sum = _mm256_extractf128_si256(accu, 0); - sum = _mm_add_epi32(sum, _mm256_extractf128_si256(accu, 1)); - sum = _mm_hadd_epi32(sum, sum); - sum = _mm_hadd_epi32(sum, sum); - return _mm_cvtsi128_si32(sum); - } - - void set_query(const float* x) final { - /* - for (int i = 0; i < d; i += 8) { - __m256 xi = _mm256_loadu_ps (x + i); - __m256i ci = _mm256_cvtps_epi32(xi); - */ - for (int i = 0; i < d; i++) { - tmp[i] = int(x[i]); - } - } - - int compute_distance(const float* x, const uint8_t* code) { - set_query(x); - return compute_code_distance(tmp.data(), code); - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_code_distance(tmp.data(), code); - } -}; - -#endif - -#ifdef USE_NEON - -template -struct DistanceComputerByte : SQDistanceComputer { - using Sim = Similarity; - - int d; - std::vector tmp; - - DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} - - int compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - int accu = 0; - for (int i = 0; i < d; i++) { - if (Sim::metric_type == METRIC_INNER_PRODUCT) { - accu += int(code1[i]) * code2[i]; - } else { - int diff = int(code1[i]) - code2[i]; - accu += diff * diff; - } - } - return accu; - } - - void set_query(const float* x) final { - for (int i = 0; i < d; i++) { - tmp[i] = int(x[i]); - } - } - - int compute_distance(const float* x, const uint8_t* code) { - set_query(x); - return compute_code_distance(tmp.data(), code); - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_code_distance(tmp.data(), code); - } -}; - -#endif +#include /******************************************************************* - * select_distance_computer: runtime selection of template - * specialization + * InvertedListScanner: scans series of codes and keeps the best ones *******************************************************************/ -template -SQDistanceComputer* select_distance_computer( - QuantizerType qtype, - size_t d, - const std::vector& trained) { - constexpr int SIMDWIDTH = Sim::simdwidth; - switch (qtype) { - case ScalarQuantizer::QT_8bit_uniform: - return new DCTemplate< - QuantizerTemplate< - Codec8bit, - QuantizerTemplateScaling::UNIFORM, - SIMDWIDTH>, - Sim, - SIMDWIDTH>(d, trained); - - case ScalarQuantizer::QT_4bit_uniform: - return new DCTemplate< - QuantizerTemplate< - Codec4bit, - QuantizerTemplateScaling::UNIFORM, - SIMDWIDTH>, - Sim, - SIMDWIDTH>(d, trained); - - case ScalarQuantizer::QT_8bit: - return new DCTemplate< - QuantizerTemplate< - Codec8bit, - QuantizerTemplateScaling::NON_UNIFORM, - SIMDWIDTH>, - Sim, - SIMDWIDTH>(d, trained); - - case ScalarQuantizer::QT_6bit: - return new DCTemplate< - QuantizerTemplate< - Codec6bit, - QuantizerTemplateScaling::NON_UNIFORM, - SIMDWIDTH>, - Sim, - SIMDWIDTH>(d, trained); - - case ScalarQuantizer::QT_4bit: - return new DCTemplate< - QuantizerTemplate< - Codec4bit, - QuantizerTemplateScaling::NON_UNIFORM, - SIMDWIDTH>, - Sim, - SIMDWIDTH>(d, trained); - - case ScalarQuantizer::QT_fp16: - return new DCTemplate, Sim, SIMDWIDTH>( - d, trained); - - case ScalarQuantizer::QT_bf16: - return new DCTemplate, Sim, SIMDWIDTH>( - d, trained); - - case ScalarQuantizer::QT_8bit_direct: -#if defined(__AVX512F__) - if (d % 32 == 0) { - return new DistanceComputerByte(d, trained); - } else -#elif defined(__AVX2__) - if (d % 16 == 0) { - return new DistanceComputerByte(d, trained); - } else -#endif - { - return new DCTemplate< - Quantizer8bitDirect, - Sim, - SIMDWIDTH>(d, trained); - } - case ScalarQuantizer::QT_8bit_direct_signed: - return new DCTemplate< - Quantizer8bitDirectSigned, - Sim, - SIMDWIDTH>(d, trained); - } - FAISS_THROW_MSG("unknown qtype"); - return nullptr; -} +#include -} // anonymous namespace +namespace faiss { +using namespace scalar_quantizer; /******************************************************************* * ScalarQuantizer implementation @@ -2046,18 +152,19 @@ void ScalarQuantizer::train(size_t n, const float* x) { } ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const { -#if defined(USE_AVX512_F16C) - if (d % 16 == 0) { - return select_quantizer_1<16>(qtype, d, trained); + // here we can't just dispatch because the SIMD code works only on certain + // vector sizes +#ifdef COMPILE_SIMD_AVX512F + if (d % 16 == 0 && SIMDConfig::level >= SIMDLevel::AVX512F) { + return select_quantizer_1(qtype, d, trained); } else -#elif defined(USE_F16C) || defined(USE_NEON) - if (d % 8 == 0) { - return select_quantizer_1<8>(qtype, d, trained); +#endif +#ifdef COMPILE_SIMD_AVX2 + if (d % 8 == 0 && SIMDConfig::level >= SIMDLevel::AVX2) { + return select_quantizer_1(qtype, d, trained); } else #endif - { - return select_quantizer_1<1>(qtype, d, trained); - } + return select_quantizer_1(qtype, d, trained); } void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n) @@ -2080,33 +187,20 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const { SQDistanceComputer* ScalarQuantizer::get_distance_computer( MetricType metric) const { - FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT); -#if defined(USE_AVX512_F16C) - if (d % 16 == 0) { - if (metric == METRIC_L2) { - return select_distance_computer>( - qtype, d, trained); - } else { - return select_distance_computer>( - qtype, d, trained); - } +#ifdef COMPILE_SIMD_AVX512F + if (d % 16 == 0 && SIMDConfig::level >= SIMDLevel::AVX512F) { + return select_distance_computer_1( + metric, qtype, d, trained); } else -#elif defined(USE_F16C) || defined(USE_NEON) - if (d % 8 == 0) { - if (metric == METRIC_L2) { - return select_distance_computer>(qtype, d, trained); - } else { - return select_distance_computer>(qtype, d, trained); - } +#endif +#ifdef COMPILE_SIMD_AVX2 + if (d % 8 == 0 && SIMDConfig::level >= SIMDLevel::AVX2) { + return select_distance_computer_1( + metric, qtype, d, trained); } else #endif - { - if (metric == METRIC_L2) { - return select_distance_computer>(qtype, d, trained); - } else { - return select_distance_computer>(qtype, d, trained); - } - } + return select_distance_computer_1( + metric, qtype, d, trained); } /******************************************************************* @@ -2116,366 +210,26 @@ SQDistanceComputer* ScalarQuantizer::get_distance_computer( * IndexScalarQuantizer as well. ********************************************************************/ -namespace { - -template -struct IVFSQScannerIP : InvertedListScanner { - DCClass dc; - bool by_residual; - - float accu0; /// added to all distances - - IVFSQScannerIP( - int d, - const std::vector& trained, - size_t code_size, - bool store_pairs, - const IDSelector* sel, - bool by_residual) - : dc(d, trained), by_residual(by_residual), accu0(0) { - this->store_pairs = store_pairs; - this->sel = sel; - this->code_size = code_size; - this->keep_max = true; - } - - void set_query(const float* query) override { - dc.set_query(query); - } - - void set_list(idx_t list_no, float coarse_dis) override { - this->list_no = list_no; - accu0 = by_residual ? coarse_dis : 0; - } - - float distance_to_code(const uint8_t* code) const final { - return accu0 + dc.query_to_code(code); - } - - size_t scan_codes( - size_t list_size, - const uint8_t* codes, - const idx_t* ids, - float* simi, - idx_t* idxi, - size_t k) const override { - size_t nup = 0; - - for (size_t j = 0; j < list_size; j++, codes += code_size) { - if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) { - continue; - } - - float accu = accu0 + dc.query_to_code(codes); - - if (accu > simi[0]) { - int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; - minheap_replace_top(k, simi, idxi, accu, id); - nup++; - } - } - return nup; - } - - void scan_codes_range( - size_t list_size, - const uint8_t* codes, - const idx_t* ids, - float radius, - RangeQueryResult& res) const override { - for (size_t j = 0; j < list_size; j++, codes += code_size) { - if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) { - continue; - } - - float accu = accu0 + dc.query_to_code(codes); - if (accu > radius) { - int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; - res.add(accu, id); - } - } - } -}; - -/* use_sel = 0: don't check selector - * = 1: check on ids[j] - * = 2: check in j directly (normally ids is nullptr and store_pairs) - */ -template -struct IVFSQScannerL2 : InvertedListScanner { - DCClass dc; - - bool by_residual; - const Index* quantizer; - const float* x; /// current query - - std::vector tmp; - - IVFSQScannerL2( - int d, - const std::vector& trained, - size_t code_size, - const Index* quantizer, - bool store_pairs, - const IDSelector* sel, - bool by_residual) - : dc(d, trained), - by_residual(by_residual), - quantizer(quantizer), - x(nullptr), - tmp(d) { - this->store_pairs = store_pairs; - this->sel = sel; - this->code_size = code_size; - } - - void set_query(const float* query) override { - x = query; - if (!quantizer) { - dc.set_query(query); - } - } - - void set_list(idx_t list_no, float /*coarse_dis*/) override { - this->list_no = list_no; - if (by_residual) { - // shift of x_in wrt centroid - quantizer->compute_residual(x, tmp.data(), list_no); - dc.set_query(tmp.data()); - } else { - dc.set_query(x); - } - } - - float distance_to_code(const uint8_t* code) const final { - return dc.query_to_code(code); - } - - size_t scan_codes( - size_t list_size, - const uint8_t* codes, - const idx_t* ids, - float* simi, - idx_t* idxi, - size_t k) const override { - size_t nup = 0; - for (size_t j = 0; j < list_size; j++, codes += code_size) { - if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) { - continue; - } - - float dis = dc.query_to_code(codes); - - if (dis < simi[0]) { - int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; - maxheap_replace_top(k, simi, idxi, dis, id); - nup++; - } - } - return nup; - } - - void scan_codes_range( - size_t list_size, - const uint8_t* codes, - const idx_t* ids, - float radius, - RangeQueryResult& res) const override { - for (size_t j = 0; j < list_size; j++, codes += code_size) { - if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) { - continue; - } - - float dis = dc.query_to_code(codes); - if (dis < radius) { - int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; - res.add(dis, id); - } - } - } -}; - -template -InvertedListScanner* sel3_InvertedListScanner( - const ScalarQuantizer* sq, - const Index* quantizer, - bool store_pairs, - const IDSelector* sel, - bool r) { - if (DCClass::Sim::metric_type == METRIC_L2) { - return new IVFSQScannerL2( - sq->d, - sq->trained, - sq->code_size, - quantizer, - store_pairs, - sel, - r); - } else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) { - return new IVFSQScannerIP( - sq->d, sq->trained, sq->code_size, store_pairs, sel, r); - } else { - FAISS_THROW_MSG("unsupported metric type"); - } -} - -template -InvertedListScanner* sel2_InvertedListScanner( - const ScalarQuantizer* sq, - const Index* quantizer, - bool store_pairs, - const IDSelector* sel, - bool r) { - if (sel) { - if (store_pairs) { - return sel3_InvertedListScanner( - sq, quantizer, store_pairs, sel, r); - } else { - return sel3_InvertedListScanner( - sq, quantizer, store_pairs, sel, r); - } - } else { - return sel3_InvertedListScanner( - sq, quantizer, store_pairs, sel, r); - } -} - -template -InvertedListScanner* sel12_InvertedListScanner( - const ScalarQuantizer* sq, - const Index* quantizer, - bool store_pairs, - const IDSelector* sel, - bool r) { - constexpr int SIMDWIDTH = Similarity::simdwidth; - using QuantizerClass = QuantizerTemplate; - using DCClass = DCTemplate; - return sel2_InvertedListScanner( - sq, quantizer, store_pairs, sel, r); -} - -template -InvertedListScanner* sel1_InvertedListScanner( - const ScalarQuantizer* sq, - const Index* quantizer, - bool store_pairs, - const IDSelector* sel, - bool r) { - constexpr int SIMDWIDTH = Similarity::simdwidth; - switch (sq->qtype) { - case ScalarQuantizer::QT_8bit_uniform: - return sel12_InvertedListScanner< - Similarity, - Codec8bit, - QuantizerTemplateScaling::UNIFORM>( - sq, quantizer, store_pairs, sel, r); - case ScalarQuantizer::QT_4bit_uniform: - return sel12_InvertedListScanner< - Similarity, - Codec4bit, - QuantizerTemplateScaling::UNIFORM>( - sq, quantizer, store_pairs, sel, r); - case ScalarQuantizer::QT_8bit: - return sel12_InvertedListScanner< - Similarity, - Codec8bit, - QuantizerTemplateScaling::NON_UNIFORM>( - sq, quantizer, store_pairs, sel, r); - case ScalarQuantizer::QT_4bit: - return sel12_InvertedListScanner< - Similarity, - Codec4bit, - QuantizerTemplateScaling::NON_UNIFORM>( - sq, quantizer, store_pairs, sel, r); - case ScalarQuantizer::QT_6bit: - return sel12_InvertedListScanner< - Similarity, - Codec6bit, - QuantizerTemplateScaling::NON_UNIFORM>( - sq, quantizer, store_pairs, sel, r); - case ScalarQuantizer::QT_fp16: - return sel2_InvertedListScanner, - Similarity, - SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); - case ScalarQuantizer::QT_bf16: - return sel2_InvertedListScanner, - Similarity, - SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); - case ScalarQuantizer::QT_8bit_direct: -#if defined(__AVX512F__) - if (sq->d % 32 == 0) { - return sel2_InvertedListScanner< - DistanceComputerByte>( - sq, quantizer, store_pairs, sel, r); - } else -#elif defined(__AVX2__) - if (sq->d % 16 == 0) { - return sel2_InvertedListScanner< - DistanceComputerByte>( - sq, quantizer, store_pairs, sel, r); - } else -#endif - { - return sel2_InvertedListScanner, - Similarity, - SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); - } - case ScalarQuantizer::QT_8bit_direct_signed: - return sel2_InvertedListScanner, - Similarity, - SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); - } - - FAISS_THROW_MSG("unknown qtype"); - return nullptr; -} - -template -InvertedListScanner* sel0_InvertedListScanner( - MetricType mt, - const ScalarQuantizer* sq, - const Index* quantizer, - bool store_pairs, - const IDSelector* sel, - bool by_residual) { - if (mt == METRIC_L2) { - return sel1_InvertedListScanner>( - sq, quantizer, store_pairs, sel, by_residual); - } else if (mt == METRIC_INNER_PRODUCT) { - return sel1_InvertedListScanner>( - sq, quantizer, store_pairs, sel, by_residual); - } else { - FAISS_THROW_MSG("unsupported metric type"); - } -} - -} // anonymous namespace - InvertedListScanner* ScalarQuantizer::select_InvertedListScanner( MetricType mt, const Index* quantizer, bool store_pairs, const IDSelector* sel, bool by_residual) const { -#if defined(USE_AVX512_F16C) - if (d % 16 == 0) { - return sel0_InvertedListScanner<16>( +#ifdef COMPILE_SIMD_AVX512F + if (d % 16 == 0 && SIMDConfig::level >= SIMDLevel::AVX512F) { + return sel0_InvertedListScanner( mt, this, quantizer, store_pairs, sel, by_residual); } else -#elif defined(USE_F16C) || defined(USE_NEON) - if (d % 8 == 0) { - return sel0_InvertedListScanner<8>( +#endif +#ifdef COMPILE_SIMD_AVX2 + if (d % 8 == 0 && SIMDConfig::level >= SIMDLevel::AVX2) { + return sel0_InvertedListScanner( mt, this, quantizer, store_pairs, sel, by_residual); } else #endif - { - return sel0_InvertedListScanner<1>( + return sel0_InvertedListScanner( mt, this, quantizer, store_pairs, sel, by_residual); - } } } // namespace faiss diff --git a/faiss/impl/ScalarQuantizer.h b/faiss/impl/ScalarQuantizer.h index c1f4f98f63..279938443a 100644 --- a/faiss/impl/ScalarQuantizer.h +++ b/faiss/impl/ScalarQuantizer.h @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. */ -// -*- c++ -*- - #pragma once #include diff --git a/faiss/impl/code_distance/code_distance-avx2.h b/faiss/impl/code_distance/code_distance-avx2.h deleted file mode 100644 index 53380b6e46..0000000000 --- a/faiss/impl/code_distance/code_distance-avx2.h +++ /dev/null @@ -1,534 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#ifdef __AVX2__ - -#include - -#include - -#include -#include - -// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=78782 -#if defined(__GNUC__) && __GNUC__ < 9 -#define _mm_loadu_si64(x) (_mm_loadl_epi64((__m128i_u*)x)) -#endif - -namespace { - -inline float horizontal_sum(const __m128 v) { - const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); - const __m128 v1 = _mm_add_ps(v, v0); - __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); - const __m128 v3 = _mm_add_ps(v1, v2); - return _mm_cvtss_f32(v3); -} - -// Computes a horizontal sum over an __m256 register -inline float horizontal_sum(const __m256 v) { - const __m128 v0 = - _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); - return horizontal_sum(v0); -} - -// processes a single code for M=4, ksub=256, nbits=8 -float inline distance_single_code_avx2_pqdecoder8_m4( - // precomputed distances, layout (4, 256) - const float* sim_table, - const uint8_t* code) { - float result = 0; - - const float* tab = sim_table; - constexpr size_t ksub = 1 << 8; - - const __m128i vksub = _mm_set1_epi32(ksub); - __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); - offsets_0 = _mm_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m128 partialSum; - - // load 4 uint8 values - const __m128i mm1 = _mm_cvtsi32_si128(*((const int32_t*)code)); - { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m128i idx1 = _mm_cvtepu8_epi32(mm1); - - // add offsets - const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m128 collected = - _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSum = collected; - } - - // horizontal sum for partialSum - result = horizontal_sum(partialSum); - return result; -} - -// processes a single code for M=8, ksub=256, nbits=8 -float inline distance_single_code_avx2_pqdecoder8_m8( - // precomputed distances, layout (8, 256) - const float* sim_table, - const uint8_t* code) { - float result = 0; - - const float* tab = sim_table; - constexpr size_t ksub = 1 << 8; - - const __m256i vksub = _mm256_set1_epi32(ksub); - __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m256 partialSum; - - // load 8 uint8 values - const __m128i mm1 = _mm_loadu_si64((const __m128i_u*)code); - { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); - - // add offsets - const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = - _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSum = collected; - } - - // horizontal sum for partialSum - result = horizontal_sum(partialSum); - return result; -} - -// processes four codes for M=4, ksub=256, nbits=8 -inline void distance_four_codes_avx2_pqdecoder8_m4( - // precomputed distances, layout (4, 256) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - constexpr intptr_t N = 4; - - const float* tab = sim_table; - constexpr size_t ksub = 1 << 8; - - // process 8 values - const __m128i vksub = _mm_set1_epi32(ksub); - __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); - offsets_0 = _mm_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m128 partialSums[N]; - - // load 4 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_cvtsi32_si128(*((const int32_t*)code0)); - mm1[1] = _mm_cvtsi32_si128(*((const int32_t*)code1)); - mm1[2] = _mm_cvtsi32_si128(*((const int32_t*)code2)); - mm1[3] = _mm_cvtsi32_si128(*((const int32_t*)code3)); - - for (intptr_t j = 0; j < N; j++) { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m128i idx1 = _mm_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); - - // gather 4 values, similar to 4 operations of tab[idx] - __m128 collected = - _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSums[j] = collected; - } - - // horizontal sum for partialSum - result0 = horizontal_sum(partialSums[0]); - result1 = horizontal_sum(partialSums[1]); - result2 = horizontal_sum(partialSums[2]); - result3 = horizontal_sum(partialSums[3]); -} - -// processes four codes for M=8, ksub=256, nbits=8 -inline void distance_four_codes_avx2_pqdecoder8_m8( - // precomputed distances, layout (8, 256) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - constexpr intptr_t N = 4; - - const float* tab = sim_table; - constexpr size_t ksub = 1 << 8; - - // process 8 values - const __m256i vksub = _mm256_set1_epi32(ksub); - __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m256 partialSums[N]; - - // load 8 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_loadu_si64((const __m128i_u*)code0); - mm1[1] = _mm_loadu_si64((const __m128i_u*)code1); - mm1[2] = _mm_loadu_si64((const __m128i_u*)code2); - mm1[3] = _mm_loadu_si64((const __m128i_u*)code3); - - for (intptr_t j = 0; j < N; j++) { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = - _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSums[j] = collected; - } - - // horizontal sum for partialSum - result0 = horizontal_sum(partialSums[0]); - result1 = horizontal_sum(partialSums[1]); - result2 = horizontal_sum(partialSums[2]); - result3 = horizontal_sum(partialSums[3]); -} - -} // namespace - -namespace faiss { - -template -typename std::enable_if::value, float>:: - type inline distance_single_code_avx2( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - const uint8_t* code) { - // default implementation - return distance_single_code_generic(M, nbits, sim_table, code); -} - -template -typename std::enable_if::value, float>:: - type inline distance_single_code_avx2( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - const uint8_t* code) { - if (M == 4) { - return distance_single_code_avx2_pqdecoder8_m4(sim_table, code); - } - if (M == 8) { - return distance_single_code_avx2_pqdecoder8_m8(sim_table, code); - } - - float result = 0; - constexpr size_t ksub = 1 << 8; - - size_t m = 0; - const size_t pqM16 = M / 16; - - const float* tab = sim_table; - - if (pqM16 > 0) { - // process 16 values per loop - - const __m256i vksub = _mm256_set1_epi32(ksub); - __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m256 partialSum = _mm256_setzero_ps(); - - // loop - for (m = 0; m < pqM16 * 16; m += 16) { - // load 16 uint8 values - const __m128i mm1 = _mm_loadu_si128((const __m128i_u*)(code + m)); - { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); - - // add offsets - const __m256i indices_to_read_from = - _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = _mm256_i32gather_ps( - tab, indices_to_read_from, sizeof(float)); - tab += ksub * 8; - - // collect partial sums - partialSum = _mm256_add_ps(partialSum, collected); - } - - // move high 8 uint8 to low ones - const __m128i mm2 = _mm_unpackhi_epi64(mm1, _mm_setzero_si128()); - { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); - - // add offsets - const __m256i indices_to_read_from = - _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = _mm256_i32gather_ps( - tab, indices_to_read_from, sizeof(float)); - tab += ksub * 8; - - // collect partial sums - partialSum = _mm256_add_ps(partialSum, collected); - } - } - - // horizontal sum for partialSum - result += horizontal_sum(partialSum); - } - - // - if (m < M) { - // process leftovers - PQDecoder8 decoder(code + m, nbits); - - for (; m < M; m++) { - result += tab[decoder.decode()]; - tab += ksub; - } - } - - return result; -} - -template -typename std::enable_if::value, void>:: - type - distance_four_codes_avx2( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_generic( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} - -// Combines 4 operations of distance_single_code() -template -typename std::enable_if::value, void>::type -distance_four_codes_avx2( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - if (M == 4) { - distance_four_codes_avx2_pqdecoder8_m4( - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); - return; - } - if (M == 8) { - distance_four_codes_avx2_pqdecoder8_m8( - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); - return; - } - - result0 = 0; - result1 = 0; - result2 = 0; - result3 = 0; - constexpr size_t ksub = 1 << 8; - - size_t m = 0; - const size_t pqM16 = M / 16; - - constexpr intptr_t N = 4; - - const float* tab = sim_table; - - if (pqM16 > 0) { - // process 16 values per loop - const __m256i vksub = _mm256_set1_epi32(ksub); - __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m256 partialSums[N]; - for (intptr_t j = 0; j < N; j++) { - partialSums[j] = _mm256_setzero_ps(); - } - - // loop - for (m = 0; m < pqM16 * 16; m += 16) { - // load 16 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); - mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); - mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); - mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); - - // process first 8 codes - for (intptr_t j = 0; j < N; j++) { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m256i indices_to_read_from = - _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = _mm256_i32gather_ps( - tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSums[j] = _mm256_add_ps(partialSums[j], collected); - } - tab += ksub * 8; - - // process next 8 codes - for (intptr_t j = 0; j < N; j++) { - // move high 8 uint8 to low ones - const __m128i mm2 = - _mm_unpackhi_epi64(mm1[j], _mm_setzero_si128()); - - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); - - // add offsets - const __m256i indices_to_read_from = - _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = _mm256_i32gather_ps( - tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSums[j] = _mm256_add_ps(partialSums[j], collected); - } - - tab += ksub * 8; - } - - // horizontal sum for partialSum - result0 += horizontal_sum(partialSums[0]); - result1 += horizontal_sum(partialSums[1]); - result2 += horizontal_sum(partialSums[2]); - result3 += horizontal_sum(partialSums[3]); - } - - // - if (m < M) { - // process leftovers - PQDecoder8 decoder0(code0 + m, nbits); - PQDecoder8 decoder1(code1 + m, nbits); - PQDecoder8 decoder2(code2 + m, nbits); - PQDecoder8 decoder3(code3 + m, nbits); - for (; m < M; m++) { - result0 += tab[decoder0.decode()]; - result1 += tab[decoder1.decode()]; - result2 += tab[decoder2.decode()]; - result3 += tab[decoder3.decode()]; - tab += ksub; - } - } -} - -} // namespace faiss - -#endif diff --git a/faiss/impl/code_distance/code_distance-avx512.h b/faiss/impl/code_distance/code_distance-avx512.h deleted file mode 100644 index d05c41c19c..0000000000 --- a/faiss/impl/code_distance/code_distance-avx512.h +++ /dev/null @@ -1,248 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#ifdef __AVX512F__ - -#include - -#include - -#include -#include - -namespace faiss { - -// According to experiments, the AVX-512 version may be SLOWER than -// the AVX2 version, which is somewhat unexpected. -// This version is not used for now, but it may be used later. -// -// TODO: test for AMD CPUs. - -template -typename std::enable_if::value, float>:: - type inline distance_single_code_avx512( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - const uint8_t* code) { - // default implementation - return distance_single_code_generic(M, nbits, sim_table, code); -} - -template -typename std::enable_if::value, float>:: - type inline distance_single_code_avx512( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - const uint8_t* code0) { - float result0 = 0; - constexpr size_t ksub = 1 << 8; - - size_t m = 0; - const size_t pqM16 = M / 16; - - constexpr intptr_t N = 1; - - const float* tab = sim_table; - - if (pqM16 > 0) { - // process 16 values per loop - const __m512i vksub = _mm512_set1_epi32(ksub); - __m512i offsets_0 = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m512 partialSums[N]; - for (intptr_t j = 0; j < N; j++) { - partialSums[j] = _mm512_setzero_ps(); - } - - // loop - for (m = 0; m < pqM16 * 16; m += 16) { - // load 16 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); - - // process first 8 codes - for (intptr_t j = 0; j < N; j++) { - const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m512i indices_to_read_from = - _mm512_add_epi32(idx1, offsets_0); - - // gather 16 values, similar to 16 operations of tab[idx] - __m512 collected = _mm512_i32gather_ps( - indices_to_read_from, tab, sizeof(float)); - - // collect partial sums - partialSums[j] = _mm512_add_ps(partialSums[j], collected); - } - tab += ksub * 16; - } - - // horizontal sum for partialSum - result0 += _mm512_reduce_add_ps(partialSums[0]); - } - - // - if (m < M) { - // process leftovers - PQDecoder8 decoder0(code0 + m, nbits); - for (; m < M; m++) { - result0 += tab[decoder0.decode()]; - tab += ksub; - } - } - - return result0; -} - -template -typename std::enable_if::value, void>:: - type - distance_four_codes_avx512( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_generic( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} - -// Combines 4 operations of distance_single_code() -template -typename std::enable_if::value, void>::type -distance_four_codes_avx512( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - result0 = 0; - result1 = 0; - result2 = 0; - result3 = 0; - constexpr size_t ksub = 1 << 8; - - size_t m = 0; - const size_t pqM16 = M / 16; - - constexpr intptr_t N = 4; - - const float* tab = sim_table; - - if (pqM16 > 0) { - // process 16 values per loop - const __m512i vksub = _mm512_set1_epi32(ksub); - __m512i offsets_0 = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m512 partialSums[N]; - for (intptr_t j = 0; j < N; j++) { - partialSums[j] = _mm512_setzero_ps(); - } - - // loop - for (m = 0; m < pqM16 * 16; m += 16) { - // load 16 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); - mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); - mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); - mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); - - // process first 8 codes - for (intptr_t j = 0; j < N; j++) { - const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m512i indices_to_read_from = - _mm512_add_epi32(idx1, offsets_0); - - // gather 16 values, similar to 16 operations of tab[idx] - __m512 collected = _mm512_i32gather_ps( - indices_to_read_from, tab, sizeof(float)); - - // collect partial sums - partialSums[j] = _mm512_add_ps(partialSums[j], collected); - } - tab += ksub * 16; - } - - // horizontal sum for partialSum - result0 += _mm512_reduce_add_ps(partialSums[0]); - result1 += _mm512_reduce_add_ps(partialSums[1]); - result2 += _mm512_reduce_add_ps(partialSums[2]); - result3 += _mm512_reduce_add_ps(partialSums[3]); - } - - // - if (m < M) { - // process leftovers - PQDecoder8 decoder0(code0 + m, nbits); - PQDecoder8 decoder1(code1 + m, nbits); - PQDecoder8 decoder2(code2 + m, nbits); - PQDecoder8 decoder3(code3 + m, nbits); - for (; m < M; m++) { - result0 += tab[decoder0.decode()]; - result1 += tab[decoder1.decode()]; - result2 += tab[decoder2.decode()]; - result3 += tab[decoder3.decode()]; - tab += ksub; - } - } -} - -} // namespace faiss - -#endif diff --git a/faiss/impl/code_distance/code_distance-generic.h b/faiss/impl/code_distance/code_distance-generic.h deleted file mode 100644 index c02551c415..0000000000 --- a/faiss/impl/code_distance/code_distance-generic.h +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include -#include - -namespace faiss { - -/// Returns the distance to a single code. -template -inline float distance_single_code_generic( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // the code - const uint8_t* code) { - PQDecoderT decoder(code, nbits); - const size_t ksub = 1 << nbits; - - const float* tab = sim_table; - float result = 0; - - for (size_t m = 0; m < M; m++) { - result += tab[decoder.decode()]; - tab += ksub; - } - - return result; -} - -/// Combines 4 operations of distance_single_code() -/// General-purpose version. -template -inline void distance_four_codes_generic( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - PQDecoderT decoder0(code0, nbits); - PQDecoderT decoder1(code1, nbits); - PQDecoderT decoder2(code2, nbits); - PQDecoderT decoder3(code3, nbits); - const size_t ksub = 1 << nbits; - - const float* tab = sim_table; - result0 = 0; - result1 = 0; - result2 = 0; - result3 = 0; - - for (size_t m = 0; m < M; m++) { - result0 += tab[decoder0.decode()]; - result1 += tab[decoder1.decode()]; - result2 += tab[decoder2.decode()]; - result3 += tab[decoder3.decode()]; - tab += ksub; - } -} - -} // namespace faiss diff --git a/faiss/impl/code_distance/code_distance.h b/faiss/impl/code_distance/code_distance.h deleted file mode 100644 index 8f29abda97..0000000000 --- a/faiss/impl/code_distance/code_distance.h +++ /dev/null @@ -1,186 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include - -// This directory contains functions to compute a distance -// from a given PQ code to a query vector, given that the -// distances to a query vector for pq.M codebooks are precomputed. -// -// The code was originally the part of IndexIVFPQ.cpp. -// The baseline implementation can be found in -// code_distance-generic.h, distance_single_code_generic(). - -// The reason for this somewhat unusual structure is that -// custom implementations may need to fall off to generic -// implementation in certain cases. So, say, avx2 header file -// needs to reference the generic header file. This is -// why the names of the functions for custom implementations -// have this _generic or _avx2 suffix. - -#ifdef __AVX2__ - -#include - -namespace faiss { - -template -inline float distance_single_code( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // the code - const uint8_t* code) { - return distance_single_code_avx2(M, nbits, sim_table, code); -} - -template -inline void distance_four_codes( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_avx2( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} - -} // namespace faiss - -#elif defined(__ARM_FEATURE_SVE) - -#include - -namespace faiss { - -template -inline float distance_single_code( - // the product quantizer - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // the code - const uint8_t* code) { - return distance_single_code_sve(M, nbits, sim_table, code); -} - -template -inline void distance_four_codes( - // the product quantizer - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_sve( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} - -} // namespace faiss - -#else - -#include - -namespace faiss { - -template -inline float distance_single_code( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // the code - const uint8_t* code) { - return distance_single_code_generic(M, nbits, sim_table, code); -} - -template -inline void distance_four_codes( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_generic( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} - -} // namespace faiss - -#endif diff --git a/faiss/impl/LookupTableScaler.h b/faiss/impl/pq_4bit/LookupTableScaler.h similarity index 95% rename from faiss/impl/LookupTableScaler.h rename to faiss/impl/pq_4bit/LookupTableScaler.h index 015bcdee5a..922d058653 100644 --- a/faiss/impl/LookupTableScaler.h +++ b/faiss/impl/pq_4bit/LookupTableScaler.h @@ -23,6 +23,8 @@ namespace faiss { struct DummyScaler { static constexpr int nscale = 0; + explicit DummyScaler(int x = -1) {} + inline simd32uint8 lookup(const simd32uint8&, const simd32uint8&) const { FAISS_THROW_MSG("DummyScaler::lookup should not be called."); return simd32uint8(0); @@ -64,12 +66,12 @@ struct DummyScaler { /// consumes 2x4 bits to encode a norm as a scalar additive quantizer /// the norm is scaled because its range if larger than other components -struct NormTableScaler { +struct Scaler2x4bit { static constexpr int nscale = 2; int scale_int; simd16uint16 scale_simd; - explicit NormTableScaler(int scale) : scale_int(scale), scale_simd(scale) {} + explicit Scaler2x4bit(int scale) : scale_int(scale), scale_simd(scale) {} inline simd32uint8 lookup(const simd32uint8& lut, const simd32uint8& c) const { diff --git a/faiss/impl/pq_4bit/decompose_qbs.h b/faiss/impl/pq_4bit/decompose_qbs.h new file mode 100644 index 0000000000..9d35ad17a6 --- /dev/null +++ b/faiss/impl/pq_4bit/decompose_qbs.h @@ -0,0 +1,175 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// decompose q set of queries into fixed-size blocks. This code is common +// between 256 and 512-bit SIMD +// This is not standalone code, it is intended to be included in the kernels-X.h +// files. + +// handle at most 4 blocks of queries +template +void accumulate_q_4step( + size_t ntotal2, + int nsq, + const uint8_t* codes, + const uint8_t* LUT0, + ResultHandler& res, + const Scaler& scaler) { + constexpr SIMDLevel SL = ResultHandler::SL; + constexpr int Q1 = QBS & 15; + constexpr int Q2 = (QBS >> 4) & 15; + constexpr int Q3 = (QBS >> 8) & 15; + constexpr int Q4 = (QBS >> 12) & 15; + constexpr int SQ = Q1 + Q2 + Q3 + Q4; + + for (size_t j0 = 0; j0 < ntotal2; j0 += 32) { + FixedStorageHandler res2; + const uint8_t* LUT = LUT0; + kernel_accumulate_block(nsq, codes, LUT, res2, scaler); + LUT += Q1 * nsq * 16; + if (Q2 > 0) { + res2.set_block_origin(Q1, 0); + kernel_accumulate_block(nsq, codes, LUT, res2, scaler); + LUT += Q2 * nsq * 16; + } + if (Q3 > 0) { + res2.set_block_origin(Q1 + Q2, 0); + kernel_accumulate_block(nsq, codes, LUT, res2, scaler); + LUT += Q3 * nsq * 16; + } + if (Q4 > 0) { + res2.set_block_origin(Q1 + Q2 + Q3, 0); + kernel_accumulate_block(nsq, codes, LUT, res2, scaler); + } + res.set_block_origin(0, j0); + res2.to_other_handler(res); + codes += 32 * nsq / 2; + } +} + +template +void kernel_accumulate_block_loop( + size_t ntotal2, + int nsq, + const uint8_t* codes, + const uint8_t* LUT, + ResultHandler& res, + const Scaler& scaler) { + for (size_t j0 = 0; j0 < ntotal2; j0 += 32) { + res.set_block_origin(0, j0); + kernel_accumulate_block( + nsq, codes + j0 * nsq / 2, LUT, res, scaler); + } +} + +// non-template version of accumulate kernel -- dispatches dynamically +template +void accumulate( + int nq, + size_t ntotal2, + int nsq, + const uint8_t* codes, + const uint8_t* LUT, + ResultHandler& res, + const Scaler& scaler) { + assert(nsq % 2 == 0); + assert(is_aligned_pointer(codes)); + assert(is_aligned_pointer(LUT)); + +#define DISPATCH(NQ) \ + case NQ: \ + kernel_accumulate_block_loop( \ + ntotal2, nsq, codes, LUT, res, scaler); \ + return + + switch (nq) { + DISPATCH(1); + DISPATCH(2); + DISPATCH(3); + DISPATCH(4); + } + FAISS_THROW_FMT("accumulate nq=%d not instantiated", nq); + +#undef DISPATCH +} + +template +void pq4_accumulate_loop_qbs_fixed_scaler( + int qbs, + size_t ntotal2, + int nsq, + const uint8_t* codes, + const uint8_t* LUT0, + ResultHandler& res, + const Scaler& scaler) { + assert(nsq % 2 == 0); + assert(is_aligned_pointer(codes)); + assert(is_aligned_pointer(LUT0)); + + // try out optimized versions + switch (qbs) { +#define DISPATCH(QBS) \ + case QBS: \ + accumulate_q_4step(ntotal2, nsq, codes, LUT0, res, scaler); \ + return; + DISPATCH(0x3333); // 12 + + DISPATCH(0x2333); // 11 + DISPATCH(0x2233); // 10 + DISPATCH(0x333); // 9 + DISPATCH(0x2223); // 9 + DISPATCH(0x233); // 8 + DISPATCH(0x1223); // 8 + DISPATCH(0x223); // 7 + DISPATCH(0x34); // 7 + DISPATCH(0x133); // 7 + DISPATCH(0x6); // 6 + DISPATCH(0x33); // 6 + DISPATCH(0x123); // 6 + DISPATCH(0x222); // 6 + DISPATCH(0x23); // 5 + DISPATCH(0x5); // 5 + DISPATCH(0x13); // 4 + DISPATCH(0x22); // 4 + DISPATCH(0x4); // 4 + DISPATCH(0x3); // 3 + DISPATCH(0x21); // 3 + DISPATCH(0x2); // 2 + DISPATCH(0x1); // 1 +#undef DISPATCH + } + + // default implementation where qbs is not known at compile time + + for (size_t j0 = 0; j0 < ntotal2; j0 += 32) { + const uint8_t* LUT = LUT0; + int qi = qbs; + int i0 = 0; + while (qi) { + int nq = qi & 15; + qi >>= 4; + res.set_block_origin(i0, j0); +#define DISPATCH(NQ) \ + case NQ: \ + kernel_accumulate_block( \ + nsq, codes, LUT, res, scaler); \ + break + switch (nq) { + DISPATCH(1); + DISPATCH(2); + DISPATCH(3); + DISPATCH(4); +#undef DISPATCH + default: + FAISS_THROW_FMT("accumulate nq=%d not instantiated", nq); + } + i0 += nq; + LUT += nq * nsq * 16; + } + codes += 32 * nsq / 2; + } +} diff --git a/faiss/impl/pq_4bit/dispatching.h b/faiss/impl/pq_4bit/dispatching.h new file mode 100644 index 0000000000..5423362339 --- /dev/null +++ b/faiss/impl/pq_4bit/dispatching.h @@ -0,0 +1,166 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +/** This header contains functions that dispatch the runtime parameters to + * compile-time template parameters */ + +#include +#include +#include +#include + +#include + +namespace faiss { + +/** Mix-in class that manages both an SIMD result hander and offers the actual + * scanning routines. */ +template +struct ScannerMixIn : ResultHandler { + Scaler scaler; + + // args are forwarded to the ResutlHandler constructor + template + ScannerMixIn(int norm_scale, ConstructorTypes... args) + : ResultHandler(args...), scaler(norm_scale) {} + + void accumulate_loop( + int nq, + size_t nb, + int bbs, + int nsq, + const uint8_t* codes, + const uint8_t* LUT) override { + constexpr SIMDLevel SL = ResultHandler::SL; + pq4_accumulate_loop_fixed_scaler( + nq, nb, bbs, nsq, codes, LUT, *this, scaler); + } + + void accumulate_loop_qbs( + int qbs, + size_t nb, + int nsq, + const uint8_t* codes, + const uint8_t* LUT) override { + pq4_accumulate_loop_qbs_fixed_scaler( + qbs, nb, nsq, codes, LUT, *this, scaler); + } +}; + +// instantiate the ResultHandler part of the PQ4Scanner. The type of handler is +// determined by the function parameters (so make_handler_2 is overloaded +// several times) + +template +PQ4CodeScanner* make_handler_2( + int norm_scale, + bool use_reservoir, + size_t nq, + size_t ntotal, + size_t k, + float* dis, + int64_t* ids, + const IDSelector* sel) { + if (k == 1) { + return new ScannerMixIn< + SingleResultHandler, + Scaler>(norm_scale, nq, ntotal, dis, ids, sel); + } else if (use_reservoir) { + return new ScannerMixIn, Scaler>( + norm_scale, nq, ntotal, k, 2 * k, dis, ids, sel); + } else { + return new ScannerMixIn, Scaler>( + norm_scale, nq, ntotal, k, dis, ids, sel); + } +} + +template +PQ4CodeScanner* make_handler_2( + int norm_scale, + RangeSearchResult* rres, + float radius, + size_t ntotal, + const IDSelector* sel) { + return new ScannerMixIn, Scaler>( + norm_scale, rres, radius, ntotal, sel); +} + +template +PQ4CodeScanner* make_handler_2( + int norm_scale, + RangeSearchPartialResult* pres, + float radius, + size_t ntotal, + size_t q0, + size_t q1, + const IDSelector* sel) { + return new ScannerMixIn, Scaler>( + norm_scale, pres, radius, ntotal, q0, q1, sel); +} + +// this function dispatches runtime -> template parameters. It is generic for +// the different instances of make_handler_2. Be careful not to pass +// structs by references here becasue they will be copied by value not by ref +// (better use pointers). + +template +PQ4CodeScanner* make_pq4_scanner_1(bool is_max, int norm_scale, Types... args) { + if (is_max) { + using C = CMax; + if (norm_scale == -1) { + return make_handler_2( + norm_scale, args...); + } else { + return make_handler_2( + norm_scale, args...); + } + } else { + using C = CMin; + if (norm_scale == -1) { + return make_handler_2( + norm_scale, args...); + } else { + return make_handler_2( + norm_scale, args...); + } + } +} + +// make_pq4_scanner should not be instantiated automatically (even if the +// function is defined just above), because here is where the different SIMD +// versions become separate. + +// Because it is tedious to repleat the parameters all the time, define a few +// macros. this does not pollute the namespace too much because this is an +// internal header. +#define KNN_ARGS_LIST \ + bool is_max, int norm_scale, bool use_reservoir, idx_t nq, idx_t ntotal, \ + idx_t k, float *dis, idx_t *ids, const IDSelector *sel +#define KNN_ARGS_LIST_2 \ + is_max, norm_scale, use_reservoir, nq, ntotal, k, dis, ids, sel + +template +PQ4CodeScanner* make_pq4_scanner(KNN_ARGS_LIST); + +#define RRES_ARGS_LIST \ + bool is_max, int norm_scale, RangeSearchResult *rres, float radius, \ + idx_t ntotal, const IDSelector *sel +#define RRES_ARGS_LIST_2 is_max, norm_scale, rres, radius, ntotal, sel + +template +PQ4CodeScanner* make_pq4_scanner(RRES_ARGS_LIST); + +#define PRES_ARGS_LIST \ + bool is_max, int norm_scale, RangeSearchPartialResult *pres, float radius, \ + idx_t ntotal, idx_t i0, idx_t i1, const IDSelector *sel +#define PRES_ARGS_LIST_2 is_max, norm_scale, pres, radius, ntotal, i0, i1, sel + +template +PQ4CodeScanner* make_pq4_scanner(PRES_ARGS_LIST); + +} // namespace faiss diff --git a/faiss/impl/pq_4bit/impl-avx2.cpp b/faiss/impl/pq_4bit/impl-avx2.cpp new file mode 100644 index 0000000000..186b7e49a8 --- /dev/null +++ b/faiss/impl/pq_4bit/impl-avx2.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifdef __x86_64__ +#ifndef __AVX2__ +#error "this should be compiled with AVX2" +#endif +#endif + +#include +#include +#include +#include + +namespace faiss { + +template <> +PQ4CodeScanner* make_pq4_scanner(KNN_ARGS_LIST) { + return make_pq4_scanner_1(KNN_ARGS_LIST_2); +} + +template <> +PQ4CodeScanner* make_pq4_scanner(KNN_ARGS_LIST) { + return make_pq4_scanner_1(KNN_ARGS_LIST_2); +} + +template <> +PQ4CodeScanner* make_pq4_scanner(RRES_ARGS_LIST) { + return make_pq4_scanner_1(RRES_ARGS_LIST_2); +} + +template <> +PQ4CodeScanner* make_pq4_scanner(PRES_ARGS_LIST) { + return make_pq4_scanner_1(PRES_ARGS_LIST_2); +} + +} // namespace faiss diff --git a/faiss/impl/pq_4bit/impl-avx512.cpp b/faiss/impl/pq_4bit/impl-avx512.cpp new file mode 100644 index 0000000000..badb30d1fb --- /dev/null +++ b/faiss/impl/pq_4bit/impl-avx512.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#ifdef __x86_64__ +#ifndef __AVX512F__ +#error "this should be compiled with -mavx512f" +#endif +#endif + +namespace faiss { + +template <> +PQ4CodeScanner* make_pq4_scanner(KNN_ARGS_LIST) { + return make_pq4_scanner_1(KNN_ARGS_LIST_2); +} + +template <> +PQ4CodeScanner* make_pq4_scanner(KNN_ARGS_LIST) { + return make_pq4_scanner_1(KNN_ARGS_LIST_2); +} + +template <> +PQ4CodeScanner* make_pq4_scanner(RRES_ARGS_LIST) { + return make_pq4_scanner_1(RRES_ARGS_LIST_2); +} + +template <> +PQ4CodeScanner* make_pq4_scanner(PRES_ARGS_LIST) { + return make_pq4_scanner_1(PRES_ARGS_LIST_2); +} + +} // namespace faiss diff --git a/faiss/impl/pq4_fast_scan_search_1.cpp b/faiss/impl/pq_4bit/kernels_simd256.h similarity index 57% rename from faiss/impl/pq4_fast_scan_search_1.cpp rename to faiss/impl/pq_4bit/kernels_simd256.h index 5c7d797142..85a03ba15d 100644 --- a/faiss/impl/pq4_fast_scan_search_1.cpp +++ b/faiss/impl/pq_4bit/kernels_simd256.h @@ -5,30 +5,30 @@ * LICENSE file in the root directory of this source tree. */ -#include +#pragma once + +#include #include -#include -#include +#include +#include namespace faiss { using namespace simd_result_handlers; /*************************************************************** - * accumulation functions + * accumulation functions -- for bbs not necessarily 32 ***************************************************************/ -namespace { - /* * The computation kernel * It accumulates results for NQ queries and BB * 32 database elements - * writes results in a ResultHandler + * writes results in a ResultHandler. */ template -void kernel_accumulate_block( +void kernel_accumulate_block_bb( int nsq, const uint8_t* codes, const uint8_t* LUT, @@ -116,8 +116,8 @@ void kernel_accumulate_block( } } -template -void accumulate_fixed_blocks( +template +void accumulate_fixed_blocks_bb( size_t nb, int nsq, const uint8_t* codes, @@ -126,15 +126,15 @@ void accumulate_fixed_blocks( const Scaler& scaler) { constexpr int bbs = 32 * BB; for (size_t j0 = 0; j0 < nb; j0 += bbs) { - FixedStorageHandler res2; - kernel_accumulate_block(nsq, codes, LUT, res2, scaler); + FixedStorageHandler res2; + kernel_accumulate_block_bb(nsq, codes, LUT, res2, scaler); res.set_block_origin(0, j0); res2.to_other_handler(res); codes += bbs * nsq / 2; } } -template +template void pq4_accumulate_loop_fixed_scaler( int nq, size_t nb, @@ -149,9 +149,10 @@ void pq4_accumulate_loop_fixed_scaler( FAISS_THROW_IF_NOT(bbs % 32 == 0); FAISS_THROW_IF_NOT(nb % bbs == 0); -#define DISPATCH(NQ, BB) \ - case NQ * 1000 + BB: \ - accumulate_fixed_blocks(nb, nsq, codes, LUT, res, scaler); \ +#define DISPATCH(NQ, BB) \ + case NQ * 1000 + BB: \ + accumulate_fixed_blocks_bb( \ + nb, nsq, codes, LUT, res, scaler); \ break switch (nq * 1000 + bbs / 32) { @@ -170,55 +171,90 @@ void pq4_accumulate_loop_fixed_scaler( #undef DISPATCH } -template -void pq4_accumulate_loop_fixed_handler( - int nq, - size_t nb, - int bbs, +/*************************************************************** + * accumulation functions -- simplified for bbs=32 + ***************************************************************/ + +template +void kernel_accumulate_block( int nsq, const uint8_t* codes, const uint8_t* LUT, ResultHandler& res, - const NormTableScaler* scaler) { - if (scaler) { - pq4_accumulate_loop_fixed_scaler( - nq, nb, bbs, nsq, codes, LUT, res, *scaler); - } else { - DummyScaler dscaler; - pq4_accumulate_loop_fixed_scaler( - nq, nb, bbs, nsq, codes, LUT, res, dscaler); + const Scaler& scaler) { + // dummy alloc to keep the windows compiler happy + constexpr int NQA = NQ > 0 ? NQ : 1; + // distance accumulators + // layout: accu[q][b]: distance accumulator for vectors 8*b..8*b+7 + simd16uint16 accu[NQA][4]; + + for (int q = 0; q < NQ; q++) { + for (int b = 0; b < 4; b++) { + accu[q][b].clear(); + } } -} -struct Run_pq4_accumulate_loop { - template - void f(ResultHandler& res, - int nq, - size_t nb, - int bbs, - int nsq, - const uint8_t* codes, - const uint8_t* LUT, - const NormTableScaler* scaler) { - pq4_accumulate_loop_fixed_handler( - nq, nb, bbs, nsq, codes, LUT, res, scaler); + // _mm_prefetch(codes + 768, 0); + for (int sq = 0; sq < nsq - scaler.nscale; sq += 2) { + // prefetch + simd32uint8 c(codes); + codes += 32; + + simd32uint8 mask(0xf); + // shift op does not exist for int8... + simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask; + simd32uint8 clo = c & mask; + + for (int q = 0; q < NQ; q++) { + // load LUTs for 2 quantizers + simd32uint8 lut(LUT); + LUT += 32; + + simd32uint8 res0 = lut.lookup_2_lanes(clo); + simd32uint8 res1 = lut.lookup_2_lanes(chi); + + accu[q][0] += simd16uint16(res0); + accu[q][1] += simd16uint16(res0) >> 8; + + accu[q][2] += simd16uint16(res1); + accu[q][3] += simd16uint16(res1) >> 8; + } } -}; -} // anonymous namespace + for (int sq = 0; sq < scaler.nscale; sq += 2) { + // prefetch + simd32uint8 c(codes); + codes += 32; -void pq4_accumulate_loop( - int nq, - size_t nb, - int bbs, - int nsq, - const uint8_t* codes, - const uint8_t* LUT, - SIMDResultHandler& res, - const NormTableScaler* scaler) { - Run_pq4_accumulate_loop consumer; - dispatch_SIMDResultHandler( - res, consumer, nq, nb, bbs, nsq, codes, LUT, scaler); + simd32uint8 mask(0xf); + // shift op does not exist for int8... + simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask; + simd32uint8 clo = c & mask; + + for (int q = 0; q < NQ; q++) { + // load LUTs for 2 quantizers + simd32uint8 lut(LUT); + LUT += 32; + + simd32uint8 res0 = scaler.lookup(lut, clo); + accu[q][0] += scaler.scale_lo(res0); // handle vectors 0..7 + accu[q][1] += scaler.scale_hi(res0); // handle vectors 8..15 + + simd32uint8 res1 = scaler.lookup(lut, chi); + accu[q][2] += scaler.scale_lo(res1); // handle vectors 16..23 + accu[q][3] += scaler.scale_hi(res1); // handle vectors 24..31 + } + } + + for (int q = 0; q < NQ; q++) { + accu[q][0] -= accu[q][1] << 8; + simd16uint16 dis0 = combine2x2(accu[q][0], accu[q][1]); + accu[q][2] -= accu[q][3] << 8; + simd16uint16 dis1 = combine2x2(accu[q][2], accu[q][3]); + res.handle(q, 0, dis0, dis1); + } } +#include + } // namespace faiss diff --git a/faiss/impl/pq4_fast_scan_search_qbs.cpp b/faiss/impl/pq_4bit/kernels_simd512.h similarity index 63% rename from faiss/impl/pq4_fast_scan_search_qbs.cpp rename to faiss/impl/pq_4bit/kernels_simd512.h index a9efe13fc1..5adac5e2e1 100644 --- a/faiss/impl/pq4_fast_scan_search_qbs.cpp +++ b/faiss/impl/pq_4bit/kernels_simd512.h @@ -4,116 +4,177 @@ * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ +#pragma once -#include +#include #include -#include -#include -#include +#include +#include namespace faiss { -// declared in simd_result_handlers.h -bool simd_result_handlers_accept_virtual = true; - using namespace simd_result_handlers; -/************************************************************ - * Accumulation functions - ************************************************************/ - -namespace { +/*************************************************************** + * accumulation functions -- for bbs not necessarily 32 + * These functions are not specialized for 512 bit SIMD so the code is just + * copied from kernels_simd256.h + ***************************************************************/ /* * The computation kernel - * It accumulates results for NQ queries and 2 * 16 database elements - * writes results in a ResultHandler + * It accumulates results for NQ queries and BB * 32 database elements + * writes results in a ResultHandler. */ -#ifndef __AVX512F__ - -template -void kernel_accumulate_block( +template +void kernel_accumulate_block_bb( int nsq, const uint8_t* codes, const uint8_t* LUT, ResultHandler& res, const Scaler& scaler) { - // dummy alloc to keep the windows compiler happy - constexpr int NQA = NQ > 0 ? NQ : 1; // distance accumulators - // layout: accu[q][b]: distance accumulator for vectors 8*b..8*b+7 - simd16uint16 accu[NQA][4]; + simd16uint16 accu[NQ][BB][4]; for (int q = 0; q < NQ; q++) { - for (int b = 0; b < 4; b++) { - accu[q][b].clear(); + for (int b = 0; b < BB; b++) { + accu[q][b][0].clear(); + accu[q][b][1].clear(); + accu[q][b][2].clear(); + accu[q][b][3].clear(); } } - // _mm_prefetch(codes + 768, 0); for (int sq = 0; sq < nsq - scaler.nscale; sq += 2) { - // prefetch - simd32uint8 c(codes); - codes += 32; - - simd32uint8 mask(0xf); - // shift op does not exist for int8... - simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask; - simd32uint8 clo = c & mask; - + simd32uint8 lut_cache[NQ]; for (int q = 0; q < NQ; q++) { - // load LUTs for 2 quantizers - simd32uint8 lut(LUT); + lut_cache[q] = simd32uint8(LUT); LUT += 32; + } - simd32uint8 res0 = lut.lookup_2_lanes(clo); - simd32uint8 res1 = lut.lookup_2_lanes(chi); + for (int b = 0; b < BB; b++) { + simd32uint8 c = simd32uint8(codes); + codes += 32; + simd32uint8 mask(15); + simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask; + simd32uint8 clo = c & mask; + + for (int q = 0; q < NQ; q++) { + simd32uint8 lut = lut_cache[q]; + simd32uint8 res0 = lut.lookup_2_lanes(clo); + simd32uint8 res1 = lut.lookup_2_lanes(chi); - accu[q][0] += simd16uint16(res0); - accu[q][1] += simd16uint16(res0) >> 8; + accu[q][b][0] += simd16uint16(res0); + accu[q][b][1] += simd16uint16(res0) >> 8; - accu[q][2] += simd16uint16(res1); - accu[q][3] += simd16uint16(res1) >> 8; + accu[q][b][2] += simd16uint16(res1); + accu[q][b][3] += simd16uint16(res1) >> 8; + } } } for (int sq = 0; sq < scaler.nscale; sq += 2) { - // prefetch - simd32uint8 c(codes); - codes += 32; - - simd32uint8 mask(0xf); - // shift op does not exist for int8... - simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask; - simd32uint8 clo = c & mask; - + simd32uint8 lut_cache[NQ]; for (int q = 0; q < NQ; q++) { - // load LUTs for 2 quantizers - simd32uint8 lut(LUT); + lut_cache[q] = simd32uint8(LUT); LUT += 32; + } - simd32uint8 res0 = scaler.lookup(lut, clo); - accu[q][0] += scaler.scale_lo(res0); // handle vectors 0..7 - accu[q][1] += scaler.scale_hi(res0); // handle vectors 8..15 + for (int b = 0; b < BB; b++) { + simd32uint8 c = simd32uint8(codes); + codes += 32; + simd32uint8 mask(15); + simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask; + simd32uint8 clo = c & mask; - simd32uint8 res1 = scaler.lookup(lut, chi); - accu[q][2] += scaler.scale_lo(res1); // handle vectors 16..23 - accu[q][3] += scaler.scale_hi(res1); // handle vectors 24..31 + for (int q = 0; q < NQ; q++) { + simd32uint8 lut = lut_cache[q]; + + simd32uint8 res0 = scaler.lookup(lut, clo); + accu[q][b][0] += scaler.scale_lo(res0); // handle vectors 0..7 + accu[q][b][1] += scaler.scale_hi(res0); // handle vectors 8..15 + + simd32uint8 res1 = scaler.lookup(lut, chi); + accu[q][b][2] += scaler.scale_lo(res1); // handle vectors 16..23 + accu[q][b][3] += + scaler.scale_hi(res1); // handle vectors 24..31 + } } } for (int q = 0; q < NQ; q++) { - accu[q][0] -= accu[q][1] << 8; - simd16uint16 dis0 = combine2x2(accu[q][0], accu[q][1]); - accu[q][2] -= accu[q][3] << 8; - simd16uint16 dis1 = combine2x2(accu[q][2], accu[q][3]); - res.handle(q, 0, dis0, dis1); + for (int b = 0; b < BB; b++) { + accu[q][b][0] -= accu[q][b][1] << 8; + simd16uint16 dis0 = combine2x2(accu[q][b][0], accu[q][b][1]); + + accu[q][b][2] -= accu[q][b][3] << 8; + simd16uint16 dis1 = combine2x2(accu[q][b][2], accu[q][b][3]); + + res.handle(q, b, dis0, dis1); + } + } +} + +template +void accumulate_fixed_blocks_bb( + size_t nb, + int nsq, + const uint8_t* codes, + const uint8_t* LUT, + ResultHandler& res, + const Scaler& scaler) { + constexpr int bbs = 32 * BB; + for (size_t j0 = 0; j0 < nb; j0 += bbs) { + FixedStorageHandler res2; + kernel_accumulate_block_bb(nsq, codes, LUT, res2, scaler); + res.set_block_origin(0, j0); + res2.to_other_handler(res); + codes += bbs * nsq / 2; } } -#else +template +void pq4_accumulate_loop_fixed_scaler( + int nq, + size_t nb, + int bbs, + int nsq, + const uint8_t* codes, + const uint8_t* LUT, + ResultHandler& res, + const Scaler& scaler) { + FAISS_THROW_IF_NOT(is_aligned_pointer(codes)); + FAISS_THROW_IF_NOT(is_aligned_pointer(LUT)); + FAISS_THROW_IF_NOT(bbs % 32 == 0); + FAISS_THROW_IF_NOT(nb % bbs == 0); + +#define DISPATCH(NQ, BB) \ + case NQ * 1000 + BB: \ + accumulate_fixed_blocks_bb( \ + nb, nsq, codes, LUT, res, scaler); \ + break + + switch (nq * 1000 + bbs / 32) { + DISPATCH(1, 1); + DISPATCH(1, 2); + DISPATCH(1, 3); + DISPATCH(1, 4); + DISPATCH(1, 5); + DISPATCH(2, 1); + DISPATCH(2, 2); + DISPATCH(3, 1); + DISPATCH(4, 1); + default: + FAISS_THROW_FMT("nq=%d bbs=%d not instantiated", nq, bbs); + } +#undef DISPATCH +} + +/*************************************************************** + * accumulation functions -- simplified for bbs=32 + ***************************************************************/ // a special version for NQ=1. // Despite the function being large in the text form, it compiles to a very @@ -556,248 +617,6 @@ void kernel_accumulate_block( } } -#endif - -// handle at most 4 blocks of queries -template -void accumulate_q_4step( - size_t ntotal2, - int nsq, - const uint8_t* codes, - const uint8_t* LUT0, - ResultHandler& res, - const Scaler& scaler) { - constexpr int Q1 = QBS & 15; - constexpr int Q2 = (QBS >> 4) & 15; - constexpr int Q3 = (QBS >> 8) & 15; - constexpr int Q4 = (QBS >> 12) & 15; - constexpr int SQ = Q1 + Q2 + Q3 + Q4; - - for (size_t j0 = 0; j0 < ntotal2; j0 += 32) { - FixedStorageHandler res2; - const uint8_t* LUT = LUT0; - kernel_accumulate_block(nsq, codes, LUT, res2, scaler); - LUT += Q1 * nsq * 16; - if (Q2 > 0) { - res2.set_block_origin(Q1, 0); - kernel_accumulate_block(nsq, codes, LUT, res2, scaler); - LUT += Q2 * nsq * 16; - } - if (Q3 > 0) { - res2.set_block_origin(Q1 + Q2, 0); - kernel_accumulate_block(nsq, codes, LUT, res2, scaler); - LUT += Q3 * nsq * 16; - } - if (Q4 > 0) { - res2.set_block_origin(Q1 + Q2 + Q3, 0); - kernel_accumulate_block(nsq, codes, LUT, res2, scaler); - } - res.set_block_origin(0, j0); - res2.to_other_handler(res); - codes += 32 * nsq / 2; - } -} - -template -void kernel_accumulate_block_loop( - size_t ntotal2, - int nsq, - const uint8_t* codes, - const uint8_t* LUT, - ResultHandler& res, - const Scaler& scaler) { - for (size_t j0 = 0; j0 < ntotal2; j0 += 32) { - res.set_block_origin(0, j0); - kernel_accumulate_block( - nsq, codes + j0 * nsq / 2, LUT, res, scaler); - } -} - -// non-template version of accumulate kernel -- dispatches dynamically -template -void accumulate( - int nq, - size_t ntotal2, - int nsq, - const uint8_t* codes, - const uint8_t* LUT, - ResultHandler& res, - const Scaler& scaler) { - assert(nsq % 2 == 0); - assert(is_aligned_pointer(codes)); - assert(is_aligned_pointer(LUT)); - -#define DISPATCH(NQ) \ - case NQ: \ - kernel_accumulate_block_loop( \ - ntotal2, nsq, codes, LUT, res, scaler); \ - return - - switch (nq) { - DISPATCH(1); - DISPATCH(2); - DISPATCH(3); - DISPATCH(4); - } - FAISS_THROW_FMT("accumulate nq=%d not instantiated", nq); - -#undef DISPATCH -} - -template -void pq4_accumulate_loop_qbs_fixed_scaler( - int qbs, - size_t ntotal2, - int nsq, - const uint8_t* codes, - const uint8_t* LUT0, - ResultHandler& res, - const Scaler& scaler) { - assert(nsq % 2 == 0); - assert(is_aligned_pointer(codes)); - assert(is_aligned_pointer(LUT0)); - - // try out optimized versions - switch (qbs) { -#define DISPATCH(QBS) \ - case QBS: \ - accumulate_q_4step(ntotal2, nsq, codes, LUT0, res, scaler); \ - return; - DISPATCH(0x3333); // 12 - DISPATCH(0x2333); // 11 - DISPATCH(0x2233); // 10 - DISPATCH(0x333); // 9 - DISPATCH(0x2223); // 9 - DISPATCH(0x233); // 8 - DISPATCH(0x1223); // 8 - DISPATCH(0x223); // 7 - DISPATCH(0x34); // 7 - DISPATCH(0x133); // 7 - DISPATCH(0x6); // 6 - DISPATCH(0x33); // 6 - DISPATCH(0x123); // 6 - DISPATCH(0x222); // 6 - DISPATCH(0x23); // 5 - DISPATCH(0x5); // 5 - DISPATCH(0x13); // 4 - DISPATCH(0x22); // 4 - DISPATCH(0x4); // 4 - DISPATCH(0x3); // 3 - DISPATCH(0x21); // 3 - DISPATCH(0x2); // 2 - DISPATCH(0x1); // 1 -#undef DISPATCH - } - - // default implementation where qbs is not known at compile time - - for (size_t j0 = 0; j0 < ntotal2; j0 += 32) { - const uint8_t* LUT = LUT0; - int qi = qbs; - int i0 = 0; - while (qi) { - int nq = qi & 15; - qi >>= 4; - res.set_block_origin(i0, j0); -#define DISPATCH(NQ) \ - case NQ: \ - kernel_accumulate_block( \ - nsq, codes, LUT, res, scaler); \ - break - switch (nq) { - DISPATCH(1); - DISPATCH(2); - DISPATCH(3); - DISPATCH(4); -#undef DISPATCH - default: - FAISS_THROW_FMT("accumulate nq=%d not instantiated", nq); - } - i0 += nq; - LUT += nq * nsq * 16; - } - codes += 32 * nsq / 2; - } -} - -struct Run_pq4_accumulate_loop_qbs { - template - void f(ResultHandler& res, - int qbs, - size_t nb, - int nsq, - const uint8_t* codes, - const uint8_t* LUT, - const NormTableScaler* scaler) { - if (scaler) { - pq4_accumulate_loop_qbs_fixed_scaler( - qbs, nb, nsq, codes, LUT, res, *scaler); - } else { - DummyScaler dummy; - pq4_accumulate_loop_qbs_fixed_scaler( - qbs, nb, nsq, codes, LUT, res, dummy); - } - } -}; - -} // namespace - -void pq4_accumulate_loop_qbs( - int qbs, - size_t nb, - int nsq, - const uint8_t* codes, - const uint8_t* LUT, - SIMDResultHandler& res, - const NormTableScaler* scaler) { - Run_pq4_accumulate_loop_qbs consumer; - dispatch_SIMDResultHandler(res, consumer, qbs, nb, nsq, codes, LUT, scaler); -} - -/*************************************************************** - * Packing functions - ***************************************************************/ - -int pq4_qbs_to_nq(int qbs) { - int i0 = 0; - int qi = qbs; - while (qi) { - int nq = qi & 15; - qi >>= 4; - i0 += nq; - } - return i0; -} - -void accumulate_to_mem( - int nq, - size_t ntotal2, - int nsq, - const uint8_t* codes, - const uint8_t* LUT, - uint16_t* accu) { - FAISS_THROW_IF_NOT(ntotal2 % 32 == 0); - StoreResultHandler handler(accu, ntotal2); - DummyScaler scaler; - accumulate(nq, ntotal2, nsq, codes, LUT, handler, scaler); -} - -int pq4_preferred_qbs(int n) { - // from timmings in P141901742, P141902828 - static int map[12] = { - 0, 1, 2, 3, 0x13, 0x23, 0x33, 0x223, 0x233, 0x333, 0x2233, 0x2333}; - if (n <= 11) { - return map[n]; - } else if (n <= 24) { - // override qbs: all first stages with 3 steps - // then 1 stage with the rest - int nbit = 4 * (n / 3); // nbits with only 3s - int qbs = 0x33333333 & ((1 << nbit) - 1); - qbs |= (n % 3) << nbit; - return qbs; - } else { - FAISS_THROW_FMT("number of queries %d too large", n); - } -} +#include } // namespace faiss diff --git a/faiss/impl/pq4_fast_scan.cpp b/faiss/impl/pq_4bit/pq4_fast_scan.cpp similarity index 62% rename from faiss/impl/pq4_fast_scan.cpp rename to faiss/impl/pq_4bit/pq4_fast_scan.cpp index 5d7e2a4efd..718b725e2e 100644 --- a/faiss/impl/pq4_fast_scan.cpp +++ b/faiss/impl/pq_4bit/pq4_fast_scan.cpp @@ -7,10 +7,18 @@ #include #include -#include -#include +#include +#include -#include +#ifdef __x86_64__ +#ifdef __AVX2__ +#error "this should not be compiled with AVX2" +#endif +#endif + +#include + +#include namespace faiss { @@ -320,4 +328,182 @@ int pq4_pack_LUT_qbs_q_map( return i0; } +/*************************************************************** + * Packing functions + ***************************************************************/ + +int pq4_qbs_to_nq(int qbs) { + int i0 = 0; + int qi = qbs; + while (qi) { + int nq = qi & 15; + qi >>= 4; + i0 += nq; + } + return i0; +} + +void accumulate_to_mem( + int nq, + size_t ntotal2, + int nsq, + const uint8_t* codes, + const uint8_t* LUT, + uint16_t* accu) { + FAISS_THROW_IF_NOT(ntotal2 % 32 == 0); + StoreResultHandler handler(accu, ntotal2); + DummyScaler scaler; + accumulate(nq, ntotal2, nsq, codes, LUT, handler, scaler); +} + +int pq4_preferred_qbs(int n) { + // from timmings in P141901742, P141902828 + static int map[12] = { + 0, 1, 2, 3, 0x13, 0x23, 0x33, 0x223, 0x233, 0x333, 0x2233, 0x2333}; + if (n <= 11) { + return map[n]; + } else if (n <= 24) { + // override qbs: all first stages with 3 steps + // then 1 stage with the rest + int nbit = 4 * (n / 3); // nbits with only 3s + int qbs = 0x33333333 & ((1 << nbit) - 1); + qbs |= (n % 3) << nbit; + return qbs; + } else { + FAISS_THROW_FMT("number of queries %d too large", n); + } +} + +/**************************** Dispatching */ + +template <> +PQ4CodeScanner* make_pq4_scanner(KNN_ARGS_LIST) { + return make_pq4_scanner_1(KNN_ARGS_LIST_2); +} + +template <> +PQ4CodeScanner* make_pq4_scanner(KNN_ARGS_LIST) { + return make_pq4_scanner_1(KNN_ARGS_LIST_2); +} + +template <> +PQ4CodeScanner* make_pq4_scanner(RRES_ARGS_LIST) { + return make_pq4_scanner_1(RRES_ARGS_LIST_2); +} + +template <> +PQ4CodeScanner* make_pq4_scanner(PRES_ARGS_LIST) { + return make_pq4_scanner_1(PRES_ARGS_LIST_2); +} + +template +PQ4CodeScanner* make_knn_scanner( + bool is_max, + int ns, + bool ur, + idx_t nq, + idx_t ntotal, + idx_t k, + float* dis, + idx_t* ids, + const IDSelector* sel) { +#ifdef COMPILE_SIMD_AVX512F + if (SIMDConfig::level == SIMDLevel::AVX512F) { + return make_pq4_scanner( + is_max, ns, ur, nq, ntotal, k, dis, ids, nullptr); + } else +#endif +#ifdef COMPILE_SIMD_AVX2 + if (SIMDConfig::level == SIMDLevel::AVX2) { + return make_pq4_scanner( + is_max, ns, ur, nq, ntotal, k, dis, ids, nullptr); + } else +#endif + { + return make_pq4_scanner( + is_max, ns, ur, nq, ntotal, k, dis, ids, nullptr); + } +} + +PQ4CodeScanner* pq4_make_flat_knn_handler( + bool is_max, + bool ur, + idx_t nq, + idx_t k, + idx_t ntotal, + float* dis, + idx_t* ids, + int ns, + const float* normalizers, + bool disable) { + PQ4CodeScanner* res = make_knn_scanner( + is_max, ns, ur, nq, ntotal, k, dis, ids, nullptr); + res->disable = disable; + res->normalizers = normalizers; + return res; +} + +PQ4CodeScanner* pq4_make_ivf_knn_handler( + bool is_max, + bool use_reservoir, + idx_t nq, + idx_t k, + float* dis, + idx_t* ids, + int norm_scale, + const IDSelector* sel) { + return make_knn_scanner( + is_max, norm_scale, use_reservoir, nq, 0, k, dis, ids, sel); +} + +PQ4CodeScanner* pq4_make_ivf_range_handler( + bool is_max, + RangeSearchResult& rres, + float radius, + int norm_scale, + const IDSelector* sel) { +#ifdef COMPILE_SIMD_AVX512F + if (SIMDConfig::level == SIMDLevel::AVX512F) { + return make_pq4_scanner( + is_max, norm_scale, &rres, radius, 0, sel); + } else +#endif +#ifdef COMPILE_SIMD_AVX2 + if (SIMDConfig::level == SIMDLevel::AVX2) { + return make_pq4_scanner( + is_max, norm_scale, &rres, radius, 0, sel); + } else +#endif + { + return make_pq4_scanner( + is_max, norm_scale, &rres, radius, 0, sel); + } +} + +PQ4CodeScanner* pq4_make_ivf_partial_range_handler( + bool is_max, + RangeSearchPartialResult& pres, + float radius, + idx_t i0, + idx_t i1, + int norm_scale, + const IDSelector* sel) { +#ifdef COMPILE_SIMD_AVX512F + if (SIMDConfig::level == SIMDLevel::AVX512F) { + return make_pq4_scanner( + is_max, norm_scale, &pres, radius, 0, i0, i1, sel); + } else +#endif +#ifdef COMPILE_SIMD_AVX2 + if (SIMDConfig::level == SIMDLevel::AVX2) { + return make_pq4_scanner( + is_max, norm_scale, &pres, radius, 0, i0, i1, sel); + } else +#endif + { + return make_pq4_scanner( + is_max, norm_scale, &pres, radius, 0, i0, i1, sel); + } +} + } // namespace faiss diff --git a/faiss/impl/pq4_fast_scan.h b/faiss/impl/pq_4bit/pq4_fast_scan.h similarity index 67% rename from faiss/impl/pq4_fast_scan.h rename to faiss/impl/pq_4bit/pq4_fast_scan.h index ccb084e7a5..9192f43fb4 100644 --- a/faiss/impl/pq4_fast_scan.h +++ b/faiss/impl/pq_4bit/pq4_fast_scan.h @@ -11,6 +11,7 @@ #include #include +#include /** PQ4 SIMD packing and accumulation functions * @@ -24,8 +25,49 @@ namespace faiss { -struct NormTableScaler; -struct SIMDResultHandler; +/* Result handler that will return float resutls eventually */ +struct PQ4CodeScanner { + size_t nq; // number of queries + size_t ntotal; // ignore excess elements after ntotal + + bool disable = false; // for benchmarking + int norm_scale = -1; // do the codes include 2x4 bits of scale? + + /// these fields are used for the IVF variants (with_id_map=true) + const idx_t* id_map = nullptr; // map offset in invlist to vector id + const int* q_map = nullptr; // map q to global query + const uint16_t* dbias = + nullptr; // table of biases to add to each query (for IVF L2 search) + const float* normalizers = nullptr; // size 2 * nq, to convert to float + + PQ4CodeScanner(size_t nq, size_t ntotal) : nq(nq), ntotal(ntotal) {} + + virtual void accumulate_loop( + int nq, + size_t nb, + int bbs, + int nsq, + const uint8_t* codes, + const uint8_t* LUT) = 0; + + virtual void accumulate_loop_qbs( + int qbs, + size_t nb, + int nsq, + const uint8_t* codes, + const uint8_t* LUT) = 0; + + virtual void begin(const float* norms) { + normalizers = norms; + } + + // called at end of search to convert int16 distances to float, before + // normalizers are deallocated + virtual void end() { + normalizers = nullptr; + } + virtual ~PQ4CodeScanner() {} +}; /** Pack codes for consumption by the SIMD kernels. * The unused bytes are set to 0. @@ -110,26 +152,6 @@ struct CodePackerPQ4 : CodePacker { */ void pq4_pack_LUT(int nq, int nsq, const uint8_t* src, uint8_t* dest); -/** Loop over database elements and accumulate results into result handler - * - * @param nq number of queries - * @param nb number of database elements - * @param bbs size of database blocks (multiple of 32) - * @param nsq number of sub-quantizers (muliple of 2) - * @param codes packed codes array - * @param LUT packed look-up table - * @param scaler scaler to scale the encoded norm - */ -void pq4_accumulate_loop( - int nq, - size_t nb, - int bbs, - int nsq, - const uint8_t* codes, - const uint8_t* LUT, - SIMDResultHandler& res, - const NormTableScaler* scaler); - /* qbs versions, supported only for bbs=32. * * The kernel function runs the kernel for *several* query blocks @@ -170,25 +192,6 @@ int pq4_pack_LUT_qbs_q_map( const int* q_map, uint8_t* dest); -/** Run accumulation loop. - * - * @param qbs 4-bit encoded number of queries - * @param nb number of database codes (mutliple of bbs) - * @param nsq number of sub-quantizers - * @param codes encoded database vectors (packed) - * @param LUT look-up table (packed) - * @param res call-back for the resutls - * @param scaler scaler to scale the encoded norm - */ -void pq4_accumulate_loop_qbs( - int qbs, - size_t nb, - int nsq, - const uint8_t* codes, - const uint8_t* LUT, - SIMDResultHandler& res, - const NormTableScaler* scaler = nullptr); - /** Wrapper of pq4_accumulate_loop_qbs using simple StoreResultHandler * and DummyScaler * @@ -207,4 +210,48 @@ void accumulate_to_mem( const uint8_t* LUT, uint16_t* accu); +PQ4CodeScanner* pq4_make_flat_knn_handler( + bool is_max, + bool use_reservoir, + idx_t nq, + idx_t k, + idx_t ntotal, + float* distances, + idx_t* labels, + int norm_scale, + const float* normalizers = nullptr, + bool disable = false); + +struct IDSelector; + +PQ4CodeScanner* pq4_make_ivf_knn_handler( + bool is_max, + bool use_reservoir, + idx_t nq, + idx_t k, + float* distances, + idx_t* labels, + int norm_scale, + const IDSelector* sel); + +struct RangeSearchResult; + +PQ4CodeScanner* pq4_make_ivf_range_handler( + bool is_max, + RangeSearchResult& rres, + float radius, + int norm_scale, + const IDSelector* sel); + +struct RangeSearchPartialResult; + +PQ4CodeScanner* pq4_make_ivf_partial_range_handler( + bool is_max, + RangeSearchPartialResult& pres, + float radius, + idx_t i0, + idx_t i1, + int norm_scale, + const IDSelector* sel); + } // namespace faiss diff --git a/faiss/impl/simd_result_handlers.h b/faiss/impl/pq_4bit/simd_result_handlers.h similarity index 72% rename from faiss/impl/simd_result_handlers.h rename to faiss/impl/pq_4bit/simd_result_handlers.h index baa640d865..0ec0e4a930 100644 --- a/faiss/impl/simd_result_handlers.h +++ b/faiss/impl/pq_4bit/simd_result_handlers.h @@ -12,80 +12,48 @@ #include #include +#include #include #include #include #include #include +#include + #include #include -/** This file contains callbacks for kernels that compute distances. +/** This file contains a class that collects top-k or range search results from + * computed distances in simd. + * + * Note that the simd16int16 and friends are not expicilty templated on + * SIMDLevel, but code compiled with different simd optimization flags will be + * incompatible (and crash at runtime). */ namespace faiss { -struct SIMDResultHandler { - // used to dispatch templates - bool is_CMax = false; - uint8_t sizeof_ids = 0; - bool with_fields = false; - - /** called when 32 distances are computed and provided in two - * simd16uint16. (q, b) indicate which entry it is in the block. */ - virtual void handle( - size_t q, - size_t b, - simd16uint16 d0, - simd16uint16 d1) = 0; - - /// set the sub-matrix that is being computed - virtual void set_block_origin(size_t i0, size_t j0) = 0; - - virtual ~SIMDResultHandler() {} -}; - -/* Result handler that will return float resutls eventually */ -struct SIMDResultHandlerToFloat : SIMDResultHandler { - size_t nq; // number of queries - size_t ntotal; // ignore excess elements after ntotal - - /// these fields are used mainly for the IVF variants (with_id_map=true) - const idx_t* id_map = nullptr; // map offset in invlist to vector id - const int* q_map = nullptr; // map q to global query - const uint16_t* dbias = - nullptr; // table of biases to add to each query (for IVF L2 search) - const float* normalizers = nullptr; // size 2 * nq, to convert - - SIMDResultHandlerToFloat(size_t nq, size_t ntotal) - : nq(nq), ntotal(ntotal) {} - - virtual void begin(const float* norms) { - normalizers = norms; - } - - // called at end of search to convert int16 distances to float, before - // normalizers are deallocated - virtual void end() { - normalizers = nullptr; - } -}; - -FAISS_API extern bool simd_result_handlers_accept_virtual; +/*********************** From here on we need to know the SIMDLevel at compile + * time */ namespace simd_result_handlers { /** Dummy structure that just computes a chqecksum on results * (to avoid the computation to be optimized away) */ -struct DummyResultHandler : SIMDResultHandler { +template +struct DummyResultHandler { + static constexpr SIMDLevel SL = SL_IN; + size_t cs = 0; - void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final { + DummyResultHandler() {} + + void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) { cs += q * 123 + b * 789 + d0.get_scalar_0() + d1.get_scalar_0(); } - void set_block_origin(size_t, size_t) final {} + void set_block_origin(size_t, size_t) {} ~DummyResultHandler() {} }; @@ -94,7 +62,9 @@ struct DummyResultHandler : SIMDResultHandler { * * j0 is the current upper-left block of the matrix */ -struct StoreResultHandler : SIMDResultHandler { +template +struct StoreResultHandler { + static constexpr SIMDLevel SL = SL_IN; uint16_t* data; size_t ld; // total number of columns size_t i0 = 0; @@ -102,30 +72,32 @@ struct StoreResultHandler : SIMDResultHandler { StoreResultHandler(uint16_t* data, size_t ld) : data(data), ld(ld) {} - void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final { + void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) { size_t ofs = (q + i0) * ld + j0 + b * 32; d0.store(data + ofs); d1.store(data + ofs + 16); } - void set_block_origin(size_t i0_in, size_t j0_in) final { + void set_block_origin(size_t i0_in, size_t j0_in) { this->i0 = i0_in; this->j0 = j0_in; } }; /** stores results in fixed-size matrix. */ -template -struct FixedStorageHandler : SIMDResultHandler { +template +struct FixedStorageHandler { simd16uint16 dis[NQ][BB]; int i0 = 0; - void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final { + FixedStorageHandler() {} + + void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) { dis[q + i0][2 * b] = d0; dis[q + i0][2 * b + 1] = d1; } - void set_block_origin(size_t i0_in, size_t j0_in) final { + void set_block_origin(size_t i0_in, size_t j0_in) { this->i0 = i0_in; assert(j0_in == 0); } @@ -143,11 +115,11 @@ struct FixedStorageHandler : SIMDResultHandler { }; /** Result handler that compares distances to check if they need to be kept */ -template -struct ResultHandlerCompare : SIMDResultHandlerToFloat { - using TI = typename C::TI; +template +struct ResultHandlerCompare : PQ4CodeScanner { + static constexpr SIMDLevel SL = SL_IN; - bool disable = false; + using TI = typename C::TI; int64_t i0 = 0; // query origin int64_t j0 = 0; // db origin @@ -155,13 +127,9 @@ struct ResultHandlerCompare : SIMDResultHandlerToFloat { const IDSelector* sel; ResultHandlerCompare(size_t nq, size_t ntotal, const IDSelector* sel_in) - : SIMDResultHandlerToFloat(nq, ntotal), sel{sel_in} { - this->is_CMax = C::is_max; - this->sizeof_ids = sizeof(typename C::TI); - this->with_fields = with_id_map; - } + : PQ4CodeScanner(nq, ntotal), sel{sel_in} {} - void set_block_origin(size_t i0_in, size_t j0_in) final { + void set_block_origin(size_t i0_in, size_t j0_in) { this->i0 = i0_in; this->j0 = j0_in; } @@ -225,11 +193,11 @@ struct ResultHandlerCompare : SIMDResultHandlerToFloat { }; /** Special version for k=1 */ -template -struct SingleResultHandler : ResultHandlerCompare { +template +struct SingleResultHandler : ResultHandlerCompare { using T = typename C::T; using TI = typename C::TI; - using RHC = ResultHandlerCompare; + using RHC = ResultHandlerCompare; using RHC::normalizers; std::vector idis; @@ -249,7 +217,7 @@ struct SingleResultHandler : ResultHandlerCompare { } } - void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final { + void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) { if (this->disable) { return; } @@ -307,11 +275,12 @@ struct SingleResultHandler : ResultHandlerCompare { }; /** Structure that collects results in a min- or max-heap */ -template -struct HeapHandler : ResultHandlerCompare { +template +struct HeapHandler : ResultHandlerCompare { + static const SIMDLevel SL = SL_IN; using T = typename C::T; using TI = typename C::TI; - using RHC = ResultHandlerCompare; + using RHC = ResultHandlerCompare; using RHC::normalizers; std::vector idis; @@ -337,7 +306,7 @@ struct HeapHandler : ResultHandlerCompare { heap_heapify(k * nq, idis.data(), iids.data()); } - void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final { + void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) { if (this->disable) { return; } @@ -416,11 +385,12 @@ struct HeapHandler : ResultHandlerCompare { * reached. Then a partition sort is used to update the threshold. */ /** Handler built from several ReservoirTopN (one per query) */ -template -struct ReservoirHandler : ResultHandlerCompare { +template +struct ReservoirHandler : ResultHandlerCompare { + static const SIMDLevel SL = SL_IN; using T = typename C::T; using TI = typename C::TI; - using RHC = ResultHandlerCompare; + using RHC = ResultHandlerCompare; using RHC::normalizers; size_t capacity; // rounded up to multiple of 16 @@ -457,7 +427,7 @@ struct ReservoirHandler : ResultHandlerCompare { } } - void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final { + void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) { if (this->disable) { return; } @@ -539,11 +509,11 @@ struct ReservoirHandler : ResultHandlerCompare { * have to be scaled using the scaler. */ -template -struct RangeHandler : ResultHandlerCompare { +template +struct RangeHandler : ResultHandlerCompare { using T = typename C::T; using TI = typename C::TI; - using RHC = ResultHandlerCompare; + using RHC = ResultHandlerCompare; using RHC::normalizers; using RHC::nq; @@ -563,11 +533,11 @@ struct RangeHandler : ResultHandlerCompare { std::vector triplets; RangeHandler( - RangeSearchResult& rres, + RangeSearchResult* rres, float radius, size_t ntotal, const IDSelector* sel_in) - : RHC(rres.nq, ntotal, sel_in), rres(rres), radius(radius) { + : RHC(rres->nq, ntotal, sel_in), rres(*rres), radius(radius) { thresholds.resize(nq); n_per_query.resize(nq + 1); } @@ -580,7 +550,7 @@ struct RangeHandler : ResultHandlerCompare { } } - void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final { + void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) { if (this->disable) { return; } @@ -645,25 +615,24 @@ struct RangeHandler : ResultHandlerCompare { #ifndef SWIG // handler for a subset of queries -template -struct PartialRangeHandler : RangeHandler { +template +struct PartialRangeHandler : RangeHandler { using T = typename C::T; using TI = typename C::TI; - using RHC = RangeHandler; + using RHC = RangeHandler; using RHC::normalizers; using RHC::nq, RHC::q0, RHC::triplets, RHC::n_per_query; RangeSearchPartialResult& pres; PartialRangeHandler( - RangeSearchPartialResult& pres, + RangeSearchPartialResult* pres, float radius, size_t ntotal, size_t q0, size_t q1, const IDSelector* sel_in) - : RangeHandler(*pres.res, radius, ntotal, sel_in), - pres(pres) { + : RHC(pres->res, radius, ntotal, sel_in), pres(*pres) { nq = q1 - q0; this->q0 = q0; } @@ -706,82 +675,6 @@ struct PartialRangeHandler : RangeHandler { #endif -/******************************************************************************** - * Dynamic dispatching function. The consumer should have a templatized method f - * that will be replaced with the actual SIMDResultHandler that is determined - * dynamically. - */ - -template -void dispatch_SIMDResultHandler_fixedCW( - SIMDResultHandler& res, - Consumer& consumer, - Types... args) { - if (auto resh = dynamic_cast*>(&res)) { - consumer.template f>(*resh, args...); - } else if (auto resh_2 = dynamic_cast*>(&res)) { - consumer.template f>(*resh_2, args...); - } else if (auto resh_2 = dynamic_cast*>(&res)) { - consumer.template f>(*resh_2, args...); - } else { // generic handler -- will not be inlined - FAISS_THROW_IF_NOT_FMT( - simd_result_handlers_accept_virtual, - "Running vitrual handler for %s", - typeid(res).name()); - consumer.template f(res, args...); - } -} - -template -void dispatch_SIMDResultHandler_fixedC( - SIMDResultHandler& res, - Consumer& consumer, - Types... args) { - if (res.with_fields) { - dispatch_SIMDResultHandler_fixedCW(res, consumer, args...); - } else { - dispatch_SIMDResultHandler_fixedCW(res, consumer, args...); - } -} - -template -void dispatch_SIMDResultHandler( - SIMDResultHandler& res, - Consumer& consumer, - Types... args) { - if (res.sizeof_ids == 0) { - if (auto resh = dynamic_cast(&res)) { - consumer.template f(*resh, args...); - } else if (auto resh_2 = dynamic_cast(&res)) { - consumer.template f(*resh_2, args...); - } else { // generic path - FAISS_THROW_IF_NOT_FMT( - simd_result_handlers_accept_virtual, - "Running vitrual handler for %s", - typeid(res).name()); - consumer.template f(res, args...); - } - } else if (res.sizeof_ids == sizeof(int)) { - if (res.is_CMax) { - dispatch_SIMDResultHandler_fixedC>( - res, consumer, args...); - } else { - dispatch_SIMDResultHandler_fixedC>( - res, consumer, args...); - } - } else if (res.sizeof_ids == sizeof(int64_t)) { - if (res.is_CMax) { - dispatch_SIMDResultHandler_fixedC>( - res, consumer, args...); - } else { - dispatch_SIMDResultHandler_fixedC>( - res, consumer, args...); - } - } else { - FAISS_THROW_FMT("Unknown id size %d", res.sizeof_ids); - } -} - } // namespace simd_result_handlers } // namespace faiss diff --git a/faiss/impl/pq_code_distance/code_distance-avx2.cpp b/faiss/impl/pq_code_distance/code_distance-avx2.cpp new file mode 100644 index 0000000000..8fb80b822e --- /dev/null +++ b/faiss/impl/pq_code_distance/code_distance-avx2.cpp @@ -0,0 +1,486 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=78782 +#if defined(__GNUC__) && __GNUC__ < 9 +#define _mm_loadu_si64(x) (_mm_loadl_epi64((__m128i_u*)x)) +#endif + +namespace { + +inline float horizontal_sum(const __m128 v) { + const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); + const __m128 v1 = _mm_add_ps(v, v0); + __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); + const __m128 v3 = _mm_add_ps(v1, v2); + return _mm_cvtss_f32(v3); +} + +// Computes a horizontal sum over an __m256 register +inline float horizontal_sum(const __m256 v) { + const __m128 v0 = + _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); + return horizontal_sum(v0); +} + +// processes a single code for M=4, ksub=256, nbits=8 +float inline distance_single_code_avx2_pqdecoder8_m4( + // precomputed distances, layout (4, 256) + const float* sim_table, + const uint8_t* code) { + float result = 0; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + const __m128i vksub = _mm_set1_epi32(ksub); + __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); + offsets_0 = _mm_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m128 partialSum; + + // load 4 uint8 values + const __m128i mm1 = _mm_cvtsi32_si128(*((const int32_t*)code)); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m128i idx1 = _mm_cvtepu8_epi32(mm1); + + // add offsets + const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m128 collected = + _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSum = collected; + } + + // horizontal sum for partialSum + result = horizontal_sum(partialSum); + return result; +} + +// processes a single code for M=8, ksub=256, nbits=8 +float inline distance_single_code_avx2_pqdecoder8_m8( + // precomputed distances, layout (8, 256) + const float* sim_table, + const uint8_t* code) { + float result = 0; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + const __m256i vksub = _mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSum; + + // load 8 uint8 values + const __m128i mm1 = _mm_loadu_si64((const __m128i_u*)code); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); + + // add offsets + const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = + _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSum = collected; + } + + // horizontal sum for partialSum + result = horizontal_sum(partialSum); + return result; +} + +// processes four codes for M=4, ksub=256, nbits=8 +inline void distance_four_codes_avx2_pqdecoder8_m4( + // precomputed distances, layout (4, 256) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + constexpr intptr_t N = 4; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + // process 8 values + const __m128i vksub = _mm_set1_epi32(ksub); + __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); + offsets_0 = _mm_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m128 partialSums[N]; + + // load 4 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_cvtsi32_si128(*((const int32_t*)code0)); + mm1[1] = _mm_cvtsi32_si128(*((const int32_t*)code1)); + mm1[2] = _mm_cvtsi32_si128(*((const int32_t*)code2)); + mm1[3] = _mm_cvtsi32_si128(*((const int32_t*)code3)); + + for (intptr_t j = 0; j < N; j++) { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m128i idx1 = _mm_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); + + // gather 4 values, similar to 4 operations of tab[idx] + __m128 collected = + _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = collected; + } + + // horizontal sum for partialSum + result0 = horizontal_sum(partialSums[0]); + result1 = horizontal_sum(partialSums[1]); + result2 = horizontal_sum(partialSums[2]); + result3 = horizontal_sum(partialSums[3]); +} + +// processes four codes for M=8, ksub=256, nbits=8 +inline void distance_four_codes_avx2_pqdecoder8_m8( + // precomputed distances, layout (8, 256) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + constexpr intptr_t N = 4; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + // process 8 values + const __m256i vksub = _mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSums[N]; + + // load 8 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si64((const __m128i_u*)code0); + mm1[1] = _mm_loadu_si64((const __m128i_u*)code1); + mm1[2] = _mm_loadu_si64((const __m128i_u*)code2); + mm1[3] = _mm_loadu_si64((const __m128i_u*)code3); + + for (intptr_t j = 0; j < N; j++) { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = + _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = collected; + } + + // horizontal sum for partialSum + result0 = horizontal_sum(partialSums[0]); + result1 = horizontal_sum(partialSums[1]); + result2 = horizontal_sum(partialSums[2]); + result3 = horizontal_sum(partialSums[3]); +} + +} // namespace + +namespace faiss { + +template <> +struct PQCodeDistance { + float distance_single_code( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + const uint8_t* code) { + if (M == 4) { + return distance_single_code_avx2_pqdecoder8_m4(sim_table, code); + } + if (M == 8) { + return distance_single_code_avx2_pqdecoder8_m8(sim_table, code); + } + + float result = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + + const __m256i vksub = _mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSum = _mm256_setzero_ps(); + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + const __m128i mm1 = + _mm_loadu_si128((const __m128i_u*)(code + m)); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + tab += ksub * 8; + + // collect partial sums + partialSum = _mm256_add_ps(partialSum, collected); + } + + // move high 8 uint8 to low ones + const __m128i mm2 = + _mm_unpackhi_epi64(mm1, _mm_setzero_si128()); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + tab += ksub * 8; + + // collect partial sums + partialSum = _mm256_add_ps(partialSum, collected); + } + } + + // horizontal sum for partialSum + result += horizontal_sum(partialSum); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 decoder(code + m, nbits); + + for (; m < M; m++) { + result += tab[decoder.decode()]; + tab += ksub; + } + } + + return result; + } + + // Combines 4 operations of distance_single_code() + void distance_four_codes( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + if (M == 4) { + distance_four_codes_avx2_pqdecoder8_m4( + sim_table, + code0, + code1, + code2, + code3, + result0, + result1, + result2, + result3); + return; + } + if (M == 8) { + distance_four_codes_avx2_pqdecoder8_m8( + sim_table, + code0, + code1, + code2, + code3, + result0, + result1, + result2, + result3); + return; + } + + result0 = 0; + result1 = 0; + result2 = 0; + result3 = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + constexpr intptr_t N = 4; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + const __m256i vksub = _mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSums[N]; + for (intptr_t j = 0; j < N; j++) { + partialSums[j] = _mm256_setzero_ps(); + } + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); + mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); + mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); + mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); + + // process first 8 codes + for (intptr_t j = 0; j < N; j++) { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm256_add_ps(partialSums[j], collected); + } + tab += ksub * 8; + + // process next 8 codes + for (intptr_t j = 0; j < N; j++) { + // move high 8 uint8 to low ones + const __m128i mm2 = + _mm_unpackhi_epi64(mm1[j], _mm_setzero_si128()); + + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm256_add_ps(partialSums[j], collected); + } + + tab += ksub * 8; + } + + // horizontal sum for partialSum + result0 += horizontal_sum(partialSums[0]); + result1 += horizontal_sum(partialSums[1]); + result2 += horizontal_sum(partialSums[2]); + result3 += horizontal_sum(partialSums[3]); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 decoder0(code0 + m, nbits); + PQDecoder8 decoder1(code1 + m, nbits); + PQDecoder8 decoder2(code2 + m, nbits); + PQDecoder8 decoder3(code3 + m, nbits); + for (; m < M; m++) { + result0 += tab[decoder0.decode()]; + result1 += tab[decoder1.decode()]; + result2 += tab[decoder2.decode()]; + result3 += tab[decoder3.decode()]; + tab += ksub; + } + } + } +}; + +// explicit template instanciations +// template struct PQCodeDistance; + +// these two will automatically use the generic implementation +template struct PQCodeDistance; +template struct PQCodeDistance; + +} // namespace faiss diff --git a/faiss/impl/pq_code_distance/code_distance-avx512.cpp b/faiss/impl/pq_code_distance/code_distance-avx512.cpp new file mode 100644 index 0000000000..5032ca6090 --- /dev/null +++ b/faiss/impl/pq_code_distance/code_distance-avx512.cpp @@ -0,0 +1,199 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include + +// According to experiments, the AVX-512 version may be SLOWER than +// the AVX2 version, which is somewhat unexpected. +// This version is not used for now, but it may be used later. +// +// TODO: test for AMD CPUs. + +namespace faiss { + +template <> +struct PQCodeDistance { + float distance_single_code( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + const uint8_t* code0) { + float result0 = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + constexpr intptr_t N = 1; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + const __m512i vksub = _mm512_set1_epi32(ksub); + __m512i offsets_0 = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m512 partialSums[N]; + for (intptr_t j = 0; j < N; j++) { + partialSums[j] = _mm512_setzero_ps(); + } + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); + + // process first 8 codes + for (intptr_t j = 0; j < N; j++) { + const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m512i indices_to_read_from = + _mm512_add_epi32(idx1, offsets_0); + + // gather 16 values, similar to 16 operations of tab[idx] + __m512 collected = _mm512_i32gather_ps( + indices_to_read_from, tab, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm512_add_ps(partialSums[j], collected); + } + tab += ksub * 16; + } + + // horizontal sum for partialSum + result0 += _mm512_reduce_add_ps(partialSums[0]); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 decoder0(code0 + m, nbits); + for (; m < M; m++) { + result0 += tab[decoder0.decode()]; + tab += ksub; + } + } + + return result0; + } + + void distance_four_codes_avx512( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + result0 = 0; + result1 = 0; + result2 = 0; + result3 = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + constexpr intptr_t N = 4; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + const __m512i vksub = _mm512_set1_epi32(ksub); + __m512i offsets_0 = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m512 partialSums[N]; + for (intptr_t j = 0; j < N; j++) { + partialSums[j] = _mm512_setzero_ps(); + } + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); + mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); + mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); + mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); + + // process first 8 codes + for (intptr_t j = 0; j < N; j++) { + const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m512i indices_to_read_from = + _mm512_add_epi32(idx1, offsets_0); + + // gather 16 values, similar to 16 operations of tab[idx] + __m512 collected = _mm512_i32gather_ps( + indices_to_read_from, tab, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm512_add_ps(partialSums[j], collected); + } + tab += ksub * 16; + } + + // horizontal sum for partialSum + result0 += _mm512_reduce_add_ps(partialSums[0]); + result1 += _mm512_reduce_add_ps(partialSums[1]); + result2 += _mm512_reduce_add_ps(partialSums[2]); + result3 += _mm512_reduce_add_ps(partialSums[3]); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 decoder0(code0 + m, nbits); + PQDecoder8 decoder1(code1 + m, nbits); + PQDecoder8 decoder2(code2 + m, nbits); + PQDecoder8 decoder3(code3 + m, nbits); + for (; m < M; m++) { + result0 += tab[decoder0.decode()]; + result1 += tab[decoder1.decode()]; + result2 += tab[decoder2.decode()]; + result3 += tab[decoder3.decode()]; + tab += ksub; + } + } + } +}; + +// explicit template instanciations +// template struct PQCodeDistance; + +// these two will automatically use the generic implementation +template struct PQCodeDistance; +template struct PQCodeDistance; + +} // namespace faiss diff --git a/faiss/impl/pq_code_distance/code_distance-generic.cpp b/faiss/impl/pq_code_distance/code_distance-generic.cpp new file mode 100644 index 0000000000..892d4b216f --- /dev/null +++ b/faiss/impl/pq_code_distance/code_distance-generic.cpp @@ -0,0 +1,20 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace faiss { + +// explicit template instanciations +template struct PQCodeDistance; +template struct PQCodeDistance; +template struct PQCodeDistance; + +} // namespace faiss diff --git a/faiss/impl/code_distance/code_distance-sve.h b/faiss/impl/pq_code_distance/code_distance-sve.cpp similarity index 99% rename from faiss/impl/code_distance/code_distance-sve.h rename to faiss/impl/pq_code_distance/code_distance-sve.cpp index 82f7746be6..9638e864ff 100644 --- a/faiss/impl/code_distance/code_distance-sve.h +++ b/faiss/impl/pq_code_distance/code_distance-sve.cpp @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. */ -#pragma once - #ifdef __ARM_FEATURE_SVE #include @@ -15,7 +13,7 @@ #include #include -#include +#include namespace faiss { diff --git a/faiss/impl/pq_code_distance/code_distance.h b/faiss/impl/pq_code_distance/code_distance.h new file mode 100644 index 0000000000..585890cb40 --- /dev/null +++ b/faiss/impl/pq_code_distance/code_distance.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +#include + +// This directory contains functions to compute a distance +// from a given PQ code to a query vector, given that the +// distances to a query vector for pq.M codebooks are precomputed. +// +// The code was originally the part of IndexIVFPQ.cpp. +// The baseline implementation can be found in +// code_distance-generic.h, distance_single_code_generic(). + +// The reason for this somewhat unusual structure is that +// custom implementations may need to fall off to generic +// implementation in certain cases. So, say, avx2 header file +// needs to reference the generic header file. This is +// why the names of the functions for custom implementations +// have this _generic or _avx2 suffix. + +namespace faiss { + +// definiton and default implementation +template +struct PQCodeDistance { + using PQDecoder = PQDecoderT; + + /// Returns the distance to a single code. + static float distance_single_code( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // the code + const uint8_t* code) { + PQDecoderT decoder(code, nbits); + const size_t ksub = 1 << nbits; + + const float* tab = sim_table; + float result = 0; + + for (size_t m = 0; m < M; m++) { + result += tab[decoder.decode()]; + tab += ksub; + } + + return result; + } + + /// Combines 4 operations of distance_single_code() + /// General-purpose version. + static void distance_four_codes( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + PQDecoderT decoder0(code0, nbits); + PQDecoderT decoder1(code1, nbits); + PQDecoderT decoder2(code2, nbits); + PQDecoderT decoder3(code3, nbits); + const size_t ksub = 1 << nbits; + + const float* tab = sim_table; + result0 = 0; + result1 = 0; + result2 = 0; + result3 = 0; + + for (size_t m = 0; m < M; m++) { + result0 += tab[decoder0.decode()]; + result1 += tab[decoder1.decode()]; + result2 += tab[decoder2.decode()]; + result3 += tab[decoder3.decode()]; + tab += ksub; + } + } +}; + +} // namespace faiss diff --git a/faiss/impl/scalar_quantizer/codecs.h b/faiss/impl/scalar_quantizer/codecs.h new file mode 100644 index 0000000000..b5c20d464b --- /dev/null +++ b/faiss/impl/scalar_quantizer/codecs.h @@ -0,0 +1,115 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace faiss { + +namespace scalar_quantizer { + +/******************************************************************* + * Codec: converts between values in [0, 1] and an index in a code + * array. The "i" parameter is the vector component index (not byte + * index). + */ + +template +struct Codec8bit {}; + +template +struct Codec4bit {}; + +template +struct Codec6bit {}; + +template <> +struct Codec8bit { + static FAISS_ALWAYS_INLINE void encode_component( + float x, + uint8_t* code, + int i) { + code[i] = (int)(255 * x); + } + + static FAISS_ALWAYS_INLINE float decode_component( + const uint8_t* code, + int i) { + return (code[i] + 0.5f) / 255.0f; + } +}; +template <> +struct Codec4bit { + static FAISS_ALWAYS_INLINE void encode_component( + float x, + uint8_t* code, + int i) { + code[i / 2] |= (int)(x * 15.0) << ((i & 1) << 2); + } + + static FAISS_ALWAYS_INLINE float decode_component( + const uint8_t* code, + int i) { + return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f; + } +}; + +template <> +struct Codec6bit { + static FAISS_ALWAYS_INLINE void encode_component( + float x, + uint8_t* code, + int i) { + int bits = (int)(x * 63.0); + code += (i >> 2) * 3; + switch (i & 3) { + case 0: + code[0] |= bits; + break; + case 1: + code[0] |= bits << 6; + code[1] |= bits >> 2; + break; + case 2: + code[1] |= bits << 4; + code[2] |= bits >> 4; + break; + case 3: + code[2] |= bits << 2; + break; + } + } + + static FAISS_ALWAYS_INLINE float decode_component( + const uint8_t* code, + int i) { + uint8_t bits; + code += (i >> 2) * 3; + switch (i & 3) { + case 0: + bits = code[0] & 0x3f; + break; + case 1: + bits = code[0] >> 6; + bits |= (code[1] & 0xf) << 2; + break; + case 2: + bits = code[1] >> 4; + bits |= (code[2] & 3) << 4; + break; + case 3: + bits = code[2] >> 2; + break; + } + return (bits + 0.5f) / 63.0f; + } +}; + +} // namespace scalar_quantizer +} // namespace faiss diff --git a/faiss/impl/scalar_quantizer/distance_computers.h b/faiss/impl/scalar_quantizer/distance_computers.h new file mode 100644 index 0000000000..c50f443ec4 --- /dev/null +++ b/faiss/impl/scalar_quantizer/distance_computers.h @@ -0,0 +1,277 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +namespace faiss { + +namespace scalar_quantizer { + +/******************************************************************* + * Similarities: accumulates the element-wise similarities + *******************************************************************/ + +template +struct SimilarityL2 {}; + +template +struct SimilarityIP {}; + +template <> +struct SimilarityL2 { + static constexpr SIMDLevel SIMD_LEVEL = SIMDLevel::NONE; + static constexpr int simdwidth = 1; + static constexpr MetricType metric_type = METRIC_L2; + + const float *y, *yi; + + explicit SimilarityL2(const float* y) : y(y) {} + + /******* scalar accumulator *******/ + + float accu; + + FAISS_ALWAYS_INLINE void begin() { + accu = 0; + yi = y; + } + + FAISS_ALWAYS_INLINE void add_component(float x) { + float tmp = *yi++ - x; + accu += tmp * tmp; + } + + FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) { + float tmp = x1 - x2; + accu += tmp * tmp; + } + + FAISS_ALWAYS_INLINE float result() { + return accu; + } +}; + +template <> +struct SimilarityIP { + static constexpr int simdwidth = 1; + static constexpr SIMDLevel SIMD_LEVEL = SIMDLevel::NONE; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + const float *y, *yi; + + float accu; + + explicit SimilarityIP(const float* y) : y(y) {} + + FAISS_ALWAYS_INLINE void begin() { + accu = 0; + yi = y; + } + + FAISS_ALWAYS_INLINE void add_component(float x) { + accu += *yi++ * x; + } + + FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) { + accu += x1 * x2; + } + + FAISS_ALWAYS_INLINE float result() { + return accu; + } +}; + +/******************************************************************* + * Distance computers: compute distances between a query and a code + *******************************************************************/ + +using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer; + +template +struct DCTemplate : SQDistanceComputer {}; + +template +struct DCTemplate : SQDistanceComputer { + using Sim = Similarity; + + Quantizer quant; + + DCTemplate(size_t d, const std::vector& trained) + : quant(d, trained) {} + + float compute_distance(const float* x, const uint8_t* code) const { + Similarity sim(x); + sim.begin(); + for (size_t i = 0; i < quant.d; i++) { + float xi = quant.reconstruct_component(code, i); + sim.add_component(xi); + } + return sim.result(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + Similarity sim(nullptr); + sim.begin(); + for (size_t i = 0; i < quant.d; i++) { + float x1 = quant.reconstruct_component(code1, i); + float x2 = quant.reconstruct_component(code2, i); + sim.add_component_2(x1, x2); + } + return sim.result(); + } + + void set_query(const float* x) final { + q = x; + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_distance(q, code); + } +}; + +/******************************************************************* + * DistanceComputerByte: computes distances in the integer domain + *******************************************************************/ + +template +struct DistanceComputerByte : SQDistanceComputer {}; + +template +struct DistanceComputerByte : SQDistanceComputer { + using Sim = Similarity; + + int d; + std::vector tmp; + + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + + int compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + int accu = 0; + for (int i = 0; i < d; i++) { + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + accu += int(code1[i]) * code2[i]; + } else { + int diff = int(code1[i]) - code2[i]; + accu += diff * diff; + } + } + return accu; + } + + void set_query(const float* x) final { + for (int i = 0; i < d; i++) { + tmp[i] = int(x[i]); + } + } + + int compute_distance(const float* x, const uint8_t* code) { + set_query(x); + return compute_code_distance(tmp.data(), code); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_code_distance(tmp.data(), code); + } +}; + +/******************************************************************* + * select_distance_computer: runtime selection of template + * specialization + *******************************************************************/ + +template +SQDistanceComputer* select_distance_computer( + QuantizerType qtype, + size_t d, + const std::vector& trained) { + constexpr SIMDLevel SL = Sim::SIMD_LEVEL; + constexpr QScaling NU = QScaling::NON_UNIFORM; + constexpr QScaling U = QScaling::UNIFORM; + switch (qtype) { + case ScalarQuantizer::QT_8bit_uniform: + return new DCTemplate, U, SL>, Sim, SL>( + d, trained); + + case ScalarQuantizer::QT_4bit_uniform: + return new DCTemplate, U, SL>, Sim, SL>( + d, trained); + + case ScalarQuantizer::QT_8bit: + return new DCTemplate, NU, SL>, Sim, SL>( + d, trained); + + case ScalarQuantizer::QT_6bit: + return new DCTemplate, NU, SL>, Sim, SL>( + d, trained); + + case ScalarQuantizer::QT_4bit: + return new DCTemplate, NU, SL>, Sim, SL>( + d, trained); + + case ScalarQuantizer::QT_fp16: + return new DCTemplate, Sim, SL>(d, trained); + + case ScalarQuantizer::QT_bf16: + return new DCTemplate, Sim, SL>(d, trained); + + case ScalarQuantizer::QT_8bit_direct: + return new DCTemplate, Sim, SL>(d, trained); + case ScalarQuantizer::QT_8bit_direct_signed: + return new DCTemplate, Sim, SL>( + d, trained); + } + FAISS_THROW_MSG("unknown qtype"); + return nullptr; +} + +template +SQDistanceComputer* select_distance_computer_1( + MetricType metric_type, + QuantizerType qtype, + size_t d, + const std::vector& trained) { + if (metric_type == METRIC_L2) { + return select_distance_computer>(qtype, d, trained); + } else if (metric_type == METRIC_INNER_PRODUCT) { + return select_distance_computer>(qtype, d, trained); + } else { + FAISS_THROW_MSG("unsuppored metric type"); + } +} + +// prevent implicit instantiation of the template +extern template SQDistanceComputer* select_distance_computer_1( + MetricType metric_type, + QuantizerType qtype, + size_t d, + const std::vector& trained); + +extern template SQDistanceComputer* select_distance_computer_1< + SIMDLevel::AVX512F>( + MetricType metric_type, + QuantizerType qtype, + size_t d, + const std::vector& trained); + +} // namespace scalar_quantizer +} // namespace faiss diff --git a/faiss/impl/scalar_quantizer/impl-avx2.cpp b/faiss/impl/scalar_quantizer/impl-avx2.cpp new file mode 100644 index 0000000000..a044ad9aaf --- /dev/null +++ b/faiss/impl/scalar_quantizer/impl-avx2.cpp @@ -0,0 +1,431 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#ifdef __F16C__ +#define USE_F16C +#else +#warning \ + "Cannot enable AVX optimizations in scalar quantizer if -mf16c is not set as well" +#endif + +namespace faiss { + +namespace scalar_quantizer { + +/****************************************** Specialization of codecs */ + +template <> +struct Codec8bit : Codec8bit { + static FAISS_ALWAYS_INLINE simd8float32 + decode_8_components(const uint8_t* code, int i) { + const uint64_t c8 = *(uint64_t*)(code + i); + + const __m128i i8 = _mm_set1_epi64x(c8); + const __m256i i32 = _mm256_cvtepu8_epi32(i8); + const __m256 f8 = _mm256_cvtepi32_ps(i32); + const __m256 half_one_255 = _mm256_set1_ps(0.5f / 255.f); + const __m256 one_255 = _mm256_set1_ps(1.f / 255.f); + return simd8float32(_mm256_fmadd_ps(f8, one_255, half_one_255)); + } +}; + +template <> +struct Codec4bit : Codec4bit { + static FAISS_ALWAYS_INLINE simd8float32 + decode_8_components(const uint8_t* code, int i) { + uint32_t c4 = *(uint32_t*)(code + (i >> 1)); + uint32_t mask = 0x0f0f0f0f; + uint32_t c4ev = c4 & mask; + uint32_t c4od = (c4 >> 4) & mask; + + // the 8 lower bytes of c8 contain the values + __m128i c8 = + _mm_unpacklo_epi8(_mm_set1_epi32(c4ev), _mm_set1_epi32(c4od)); + __m128i c4lo = _mm_cvtepu8_epi32(c8); + __m128i c4hi = _mm_cvtepu8_epi32(_mm_srli_si128(c8, 4)); + __m256i i8 = _mm256_castsi128_si256(c4lo); + i8 = _mm256_insertf128_si256(i8, c4hi, 1); + __m256 f8 = _mm256_cvtepi32_ps(i8); + __m256 half = _mm256_set1_ps(0.5f); + f8 = _mm256_add_ps(f8, half); + __m256 one_255 = _mm256_set1_ps(1.f / 15.f); + return simd8float32(_mm256_mul_ps(f8, one_255)); + } +}; + +template <> +struct Codec6bit : Codec6bit { + /* Load 6 bytes that represent 8 6-bit values, return them as a + * 8*32 bit vector register */ + static FAISS_ALWAYS_INLINE __m256i load6(const uint16_t* code16) { + const __m128i perm = _mm_set_epi8( + -1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0); + const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0); + + // load 6 bytes + __m128i c1 = + _mm_set_epi16(0, 0, 0, 0, 0, code16[2], code16[1], code16[0]); + + // put in 8 * 32 bits + __m128i c2 = _mm_shuffle_epi8(c1, perm); + __m256i c3 = _mm256_cvtepi16_epi32(c2); + + // shift and mask out useless bits + __m256i c4 = _mm256_srlv_epi32(c3, shifts); + __m256i c5 = _mm256_and_si256(_mm256_set1_epi32(63), c4); + return c5; + } + + static FAISS_ALWAYS_INLINE simd8float32 + decode_8_components(const uint8_t* code, int i) { + // // Faster code for Intel CPUs or AMD Zen3+, just keeping it here + // // for the reference, maybe, it becomes used oned day. + // const uint16_t* data16 = (const uint16_t*)(code + (i >> 2) * 3); + // const uint32_t* data32 = (const uint32_t*)data16; + // const uint64_t val = *data32 + ((uint64_t)data16[2] << 32); + // const uint64_t vext = _pdep_u64(val, 0x3F3F3F3F3F3F3F3FULL); + // const __m128i i8 = _mm_set1_epi64x(vext); + // const __m256i i32 = _mm256_cvtepi8_epi32(i8); + // const __m256 f8 = _mm256_cvtepi32_ps(i32); + // const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f); + // const __m256 one_255 = _mm256_set1_ps(1.f / 63.f); + // return _mm256_fmadd_ps(f8, one_255, half_one_255); + + __m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3)); + __m256 f8 = _mm256_cvtepi32_ps(i8); + // this could also be done with bit manipulations but it is + // not obviously faster + const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f); + const __m256 one_255 = _mm256_set1_ps(1.f / 63.f); + return simd8float32(_mm256_fmadd_ps(f8, one_255, half_one_255)); + } +}; + +/****************************************** Specialization of quantizers */ + +template +struct QuantizerT + : QuantizerT { + QuantizerT(size_t d, const std::vector& trained) + : QuantizerT( + d, + trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m256 xi = Codec::decode_8_components(code, i).f; + return simd8float32(_mm256_fmadd_ps( + xi, _mm256_set1_ps(this->vdiff), _mm256_set1_ps(this->vmin))); + } +}; + +template +struct QuantizerT + : QuantizerT { + QuantizerT(size_t d, const std::vector& trained) + : QuantizerT( + d, + trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m256 xi = Codec::decode_8_components(code, i).f; + return simd8float32(_mm256_fmadd_ps( + xi, + _mm256_loadu_ps(this->vdiff + i), + _mm256_loadu_ps(this->vmin + i))); + } +}; + +template <> +struct QuantizerFP16 : QuantizerFP16 { + QuantizerFP16(size_t d, const std::vector& trained) + : QuantizerFP16(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m128i codei = _mm_loadu_si128((const __m128i*)(code + 2 * i)); + return simd8float32(_mm256_cvtph_ps(codei)); + } +}; + +template <> +struct QuantizerBF16 : QuantizerBF16 { + QuantizerBF16(size_t d, const std::vector& trained) + : QuantizerBF16(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m128i code_128i = _mm_loadu_si128((const __m128i*)(code + 2 * i)); + __m256i code_256i = _mm256_cvtepu16_epi32(code_128i); + code_256i = _mm256_slli_epi32(code_256i, 16); + return simd8float32(_mm256_castsi256_ps(code_256i)); + } +}; + +template <> +struct Quantizer8bitDirect + : Quantizer8bitDirect { + Quantizer8bitDirect(size_t d, const std::vector& trained) + : Quantizer8bitDirect(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8 + __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32 + return simd8float32(_mm256_cvtepi32_ps(y8)); // 8 * float32 + } +}; + +template <> +struct Quantizer8bitDirectSigned + : Quantizer8bitDirectSigned { + Quantizer8bitDirectSigned(size_t d, const std::vector& trained) + : Quantizer8bitDirectSigned(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8 + __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32 + __m256i c8 = _mm256_set1_epi32(128); + __m256i z8 = _mm256_sub_epi32(y8, c8); // subtract 128 from all lanes + return simd8float32(_mm256_cvtepi32_ps(z8)); // 8 * float32 + } +}; + +/****************************************** Specialization of similarities */ + +template <> +struct SimilarityL2 { + static constexpr SIMDLevel SIMD_LEVEL = SIMDLevel::AVX2; + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_L2; + + const float *y, *yi; + + explicit SimilarityL2(const float* y) : y(y) {} + simd8float32 accu8; + + FAISS_ALWAYS_INLINE void begin_8() { + accu8.clear(); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_8_components(simd8float32 x) { + __m256 yiv = _mm256_loadu_ps(yi); + yi += 8; + __m256 tmp = _mm256_sub_ps(yiv, x.f); + accu8 = simd8float32(_mm256_fmadd_ps(tmp, tmp, accu8.f)); + } + + FAISS_ALWAYS_INLINE void add_8_components_2( + simd8float32 x, + simd8float32 y_2) { + __m256 tmp = _mm256_sub_ps(y_2.f, x.f); + accu8 = simd8float32(_mm256_fmadd_ps(tmp, tmp, accu8.f)); + } + + FAISS_ALWAYS_INLINE float result_8() { + const __m128 sum = _mm_add_ps( + _mm256_castps256_ps128(accu8.f), + _mm256_extractf128_ps(accu8.f, 1)); + const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)); + const __m128 v1 = _mm_add_ps(sum, v0); + __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); + const __m128 v3 = _mm_add_ps(v1, v2); + return _mm_cvtss_f32(v3); + } +}; + +template <> +struct SimilarityIP { + static constexpr SIMDLevel SIMD_LEVEL = SIMDLevel::AVX2; + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + + const float *y, *yi; + + float accu; + + explicit SimilarityIP(const float* y) : y(y) {} + + simd8float32 accu8; + + FAISS_ALWAYS_INLINE void begin_8() { + accu8.clear(); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_8_components(simd8float32 x) { + __m256 yiv = _mm256_loadu_ps(yi); + yi += 8; + accu8.f = _mm256_fmadd_ps(yiv, x.f, accu8.f); + } + + FAISS_ALWAYS_INLINE void add_8_components_2( + simd8float32 x1, + simd8float32 x2) { + accu8.f = _mm256_fmadd_ps(x1.f, x2.f, accu8.f); + } + + FAISS_ALWAYS_INLINE float result_8() { + const __m128 sum = _mm_add_ps( + _mm256_castps256_ps128(accu8.f), + _mm256_extractf128_ps(accu8.f, 1)); + const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)); + const __m128 v1 = _mm_add_ps(sum, v0); + __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); + const __m128 v3 = _mm_add_ps(v1, v2); + return _mm_cvtss_f32(v3); + } +}; + +/****************************************** Specialization of distance computers + */ + +template +struct DCTemplate : SQDistanceComputer { + using Sim = Similarity; + + Quantizer quant; + + DCTemplate(size_t d, const std::vector& trained) + : quant(d, trained) {} + + float compute_distance(const float* x, const uint8_t* code) const { + Similarity sim(x); + sim.begin_8(); + for (size_t i = 0; i < quant.d; i += 8) { + simd8float32 xi = quant.reconstruct_8_components(code, i); + sim.add_8_components(xi); + } + return sim.result_8(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + Similarity sim(nullptr); + sim.begin_8(); + for (size_t i = 0; i < quant.d; i += 8) { + simd8float32 x1 = quant.reconstruct_8_components(code1, i); + simd8float32 x2 = quant.reconstruct_8_components(code2, i); + sim.add_8_components_2(x1, x2); + } + return sim.result_8(); + } + + void set_query(const float* x) final { + q = x; + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_distance(q, code); + } +}; + +template +struct DistanceComputerByte : SQDistanceComputer { + using Sim = Similarity; + + int d; + std::vector tmp; + + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + + int compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + // __m256i accu = _mm256_setzero_ps (); + __m256i accu = _mm256_setzero_si256(); + for (int i = 0; i < d; i += 16) { + // load 16 bytes, convert to 16 uint16_t + __m256i c1 = _mm256_cvtepu8_epi16( + _mm_loadu_si128((__m128i*)(code1 + i))); + __m256i c2 = _mm256_cvtepu8_epi16( + _mm_loadu_si128((__m128i*)(code2 + i))); + __m256i prod32; + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + prod32 = _mm256_madd_epi16(c1, c2); + } else { + __m256i diff = _mm256_sub_epi16(c1, c2); + prod32 = _mm256_madd_epi16(diff, diff); + } + accu = _mm256_add_epi32(accu, prod32); + } + __m128i sum = _mm256_extractf128_si256(accu, 0); + sum = _mm_add_epi32(sum, _mm256_extractf128_si256(accu, 1)); + sum = _mm_hadd_epi32(sum, sum); + sum = _mm_hadd_epi32(sum, sum); + return _mm_cvtsi128_si32(sum); + } + + void set_query(const float* x) final { + /* + for (int i = 0; i < d; i += 8) { + __m256 xi = _mm256_loadu_ps (x + i); + __m256i ci = _mm256_cvtps_epi32(xi); + */ + for (int i = 0; i < d; i++) { + tmp[i] = int(x[i]); + } + } + + int compute_distance(const float* x, const uint8_t* code) { + set_query(x); + return compute_code_distance(tmp.data(), code); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_code_distance(tmp.data(), code); + } +}; + +// explicit instantiation + +template ScalarQuantizer::SQuantizer* select_quantizer_1( + QuantizerType qtype, + size_t d, + const std::vector& trained); + +template SQDistanceComputer* select_distance_computer_1( + MetricType metric_type, + QuantizerType qtype, + size_t d, + const std::vector& trained); + +template InvertedListScanner* sel0_InvertedListScanner( + MetricType mt, + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool by_residual); + +} // namespace scalar_quantizer + +} // namespace faiss diff --git a/faiss/impl/scalar_quantizer/impl-avx512.cpp b/faiss/impl/scalar_quantizer/impl-avx512.cpp new file mode 100644 index 0000000000..ad782c61cb --- /dev/null +++ b/faiss/impl/scalar_quantizer/impl-avx512.cpp @@ -0,0 +1,403 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include + +#if defined(__AVX512F__) && defined(__F16C__) +#define USE_AVX512_F16C +#else +#warning "Wrong compiler flags for AVX512_F16C" +#endif + +namespace faiss { + +namespace scalar_quantizer { + +/******************************** Codec specializations */ + +template <> +struct Codec8bit : Codec8bit { + static FAISS_ALWAYS_INLINE simd16float32 + decode_16_components(const uint8_t* code, int i) { + const __m128i c16 = _mm_loadu_si128((__m128i*)(code + i)); + const __m512i i32 = _mm512_cvtepu8_epi32(c16); + const __m512 f16 = _mm512_cvtepi32_ps(i32); + const __m512 half_one_255 = _mm512_set1_ps(0.5f / 255.f); + const __m512 one_255 = _mm512_set1_ps(1.f / 255.f); + return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255)); + } +}; + +template <> +struct Codec4bit : Codec4bit { + static FAISS_ALWAYS_INLINE simd16float32 + decode_16_components(const uint8_t* code, int i) { + uint64_t c8 = *(uint64_t*)(code + (i >> 1)); + uint64_t mask = 0x0f0f0f0f0f0f0f0f; + uint64_t c8ev = c8 & mask; + uint64_t c8od = (c8 >> 4) & mask; + + __m128i c16 = + _mm_unpacklo_epi8(_mm_set1_epi64x(c8ev), _mm_set1_epi64x(c8od)); + __m256i c8lo = _mm256_cvtepu8_epi32(c16); + __m256i c8hi = _mm256_cvtepu8_epi32(_mm_srli_si128(c16, 8)); + __m512i i16 = _mm512_castsi256_si512(c8lo); + i16 = _mm512_inserti32x8(i16, c8hi, 1); + __m512 f16 = _mm512_cvtepi32_ps(i16); + const __m512 half_one_255 = _mm512_set1_ps(0.5f / 15.f); + const __m512 one_255 = _mm512_set1_ps(1.f / 15.f); + return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255)); + } +}; + +template <> +struct Codec6bit : Codec6bit { + static FAISS_ALWAYS_INLINE simd16float32 + decode_16_components(const uint8_t* code, int i) { + // pure AVX512 implementation (not necessarily the fastest). + // see: + // https://github.com/zilliztech/knowhere/blob/main/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h + + // clang-format off + + // 16 components, 16x6 bit=12 bytes + const __m128i bit_6v = + _mm_maskz_loadu_epi8(0b0000111111111111, code + (i >> 2) * 3); + const __m256i bit_6v_256 = _mm256_broadcast_i32x4(bit_6v); + + // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F + // 00 01 02 03 + const __m256i shuffle_mask = _mm256_setr_epi16( + 0xFF00, 0x0100, 0x0201, 0xFF02, + 0xFF03, 0x0403, 0x0504, 0xFF05, + 0xFF06, 0x0706, 0x0807, 0xFF08, + 0xFF09, 0x0A09, 0x0B0A, 0xFF0B); + const __m256i shuffled = _mm256_shuffle_epi8(bit_6v_256, shuffle_mask); + + // 0: xxxxxxxx xx543210 + // 1: xxxx5432 10xxxxxx + // 2: xxxxxx54 3210xxxx + // 3: xxxxxxxx 543210xx + const __m256i shift_right_v = _mm256_setr_epi16( + 0x0U, 0x6U, 0x4U, 0x2U, + 0x0U, 0x6U, 0x4U, 0x2U, + 0x0U, 0x6U, 0x4U, 0x2U, + 0x0U, 0x6U, 0x4U, 0x2U); + __m256i shuffled_shifted = _mm256_srlv_epi16(shuffled, shift_right_v); + + // remove unneeded bits + shuffled_shifted = + _mm256_and_si256(shuffled_shifted, _mm256_set1_epi16(0x003F)); + + // scale + const __m512 f8 = + _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(shuffled_shifted)); + const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f); + const __m512 one_255 = _mm512_set1_ps(1.f / 63.f); + return simd16float32(_mm512_fmadd_ps(f8, one_255, half_one_255)); + + // clang-format on + } +}; + +/******************************** Quantizer specializations */ + +template +struct QuantizerT + : QuantizerT { + QuantizerT(size_t d, const std::vector& trained) + : QuantizerT( + d, + trained) {} + + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m512 xi = Codec::decode_16_components(code, i).f; + return simd16float32(_mm512_fmadd_ps( + xi, _mm512_set1_ps(this->vdiff), _mm512_set1_ps(this->vmin))); + } +}; + +template +struct QuantizerT + : QuantizerT { + QuantizerT(size_t d, const std::vector& trained) + : QuantizerT( + d, + trained) {} + + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m512 xi = Codec::decode_16_components(code, i).f; + return simd16float32(_mm512_fmadd_ps( + xi, + _mm512_loadu_ps(this->vdiff + i), + _mm512_loadu_ps(this->vmin + i))); + } +}; + +template <> +struct QuantizerFP16 : QuantizerFP16 { + QuantizerFP16(size_t d, const std::vector& trained) + : QuantizerFP16(d, trained) {} + + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m256i codei = _mm256_loadu_si256((const __m256i*)(code + 2 * i)); + return simd16float32(_mm512_cvtph_ps(codei)); + } +}; + +template <> +struct QuantizerBF16 : QuantizerBF16 { + QuantizerBF16(size_t d, const std::vector& trained) + : QuantizerBF16(d, trained) {} + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m256i code_256i = _mm256_loadu_si256((const __m256i*)(code + 2 * i)); + __m512i code_512i = _mm512_cvtepu16_epi32(code_256i); + code_512i = _mm512_slli_epi32(code_512i, 16); + return simd16float32(_mm512_castsi512_ps(code_512i)); + } +}; + +template <> +struct Quantizer8bitDirect + : Quantizer8bitDirect { + Quantizer8bitDirect(size_t d, const std::vector& trained) + : Quantizer8bitDirect(d, trained) {} + + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8 + __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32 + return simd16float32(_mm512_cvtepi32_ps(y16)); // 16 * float32 + } +}; + +template <> +struct Quantizer8bitDirectSigned + : Quantizer8bitDirectSigned { + Quantizer8bitDirectSigned(size_t d, const std::vector& trained) + : Quantizer8bitDirectSigned(d, trained) {} + + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8 + __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32 + __m512i c16 = _mm512_set1_epi32(128); + __m512i z16 = _mm512_sub_epi32(y16, c16); // subtract 128 from all lanes + return simd16float32(_mm512_cvtepi32_ps(z16)); // 16 * float32 + } +}; + +/****************************************** Specialization of similarities */ + +template <> +struct SimilarityL2 { + static constexpr SIMDLevel SIMD_LEVEL = SIMDLevel::AVX512F; + static constexpr int simdwidth = 16; + static constexpr MetricType metric_type = METRIC_L2; + + const float *y, *yi; + + explicit SimilarityL2(const float* y) : y(y) {} + simd16float32 accu16; + + FAISS_ALWAYS_INLINE void begin_16() { + accu16.clear(); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_16_components(simd16float32 x) { + __m512 yiv = _mm512_loadu_ps(yi); + yi += 16; + __m512 tmp = _mm512_sub_ps(yiv, x.f); + accu16 = simd16float32(_mm512_fmadd_ps(tmp, tmp, accu16.f)); + } + + FAISS_ALWAYS_INLINE void add_16_components_2( + simd16float32 x, + simd16float32 y_2) { + __m512 tmp = _mm512_sub_ps(y_2.f, x.f); + accu16 = simd16float32(_mm512_fmadd_ps(tmp, tmp, accu16.f)); + } + + FAISS_ALWAYS_INLINE float result_16() { + // performs better than dividing into _mm256 and adding + return _mm512_reduce_add_ps(accu16.f); + } +}; + +template <> +struct SimilarityIP { + static constexpr SIMDLevel SIMD_LEVEL = SIMDLevel::AVX512F; + static constexpr int simdwidth = 16; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + + const float *y, *yi; + + float accu; + + explicit SimilarityIP(const float* y) : y(y) {} + + simd16float32 accu16; + + FAISS_ALWAYS_INLINE void begin_16() { + accu16.clear(); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_16_components(simd16float32 x) { + __m512 yiv = _mm512_loadu_ps(yi); + yi += 16; + accu16.f = _mm512_fmadd_ps(yiv, x.f, accu16.f); + } + + FAISS_ALWAYS_INLINE void add_16_components_2( + simd16float32 x1, + simd16float32 x2) { + accu16.f = _mm512_fmadd_ps(x1.f, x2.f, accu16.f); + } + + FAISS_ALWAYS_INLINE float result_16() { + // performs better than dividing into _mm256 and adding + return _mm512_reduce_add_ps(accu16.f); + } +}; + +/****************************************** Specialization of distance computers + */ + +template +struct DCTemplate + : SQDistanceComputer { // Update to handle 16 lanes + using Sim = Similarity; + + Quantizer quant; + + DCTemplate(size_t d, const std::vector& trained) + : quant(d, trained) {} + + float compute_distance(const float* x, const uint8_t* code) const { + Similarity sim(x); + sim.begin_16(); + for (size_t i = 0; i < quant.d; i += 16) { + simd16float32 xi = quant.reconstruct_16_components(code, i); + sim.add_16_components(xi); + } + return sim.result_16(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + Similarity sim(nullptr); + sim.begin_16(); + for (size_t i = 0; i < quant.d; i += 16) { + simd16float32 x1 = quant.reconstruct_16_components(code1, i); + simd16float32 x2 = quant.reconstruct_16_components(code2, i); + sim.add_16_components_2(x1, x2); + } + return sim.result_16(); + } + + void set_query(const float* x) final { + q = x; + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_distance(q, code); + } +}; + +template +struct DistanceComputerByte + : SQDistanceComputer { + using Sim = Similarity; + + int d; + std::vector tmp; + + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + + int compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + __m512i accu = _mm512_setzero_si512(); + for (int i = 0; i < d; i += 32) { // Process 32 bytes at a time + __m512i c1 = _mm512_cvtepu8_epi16( + _mm256_loadu_si256((__m256i*)(code1 + i))); + __m512i c2 = _mm512_cvtepu8_epi16( + _mm256_loadu_si256((__m256i*)(code2 + i))); + __m512i prod32; + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + prod32 = _mm512_madd_epi16(c1, c2); + } else { + __m512i diff = _mm512_sub_epi16(c1, c2); + prod32 = _mm512_madd_epi16(diff, diff); + } + accu = _mm512_add_epi32(accu, prod32); + } + // Horizontally add elements of accu + return _mm512_reduce_add_epi32(accu); + } + + void set_query(const float* x) final { + for (int i = 0; i < d; i++) { + tmp[i] = int(x[i]); + } + } + + int compute_distance(const float* x, const uint8_t* code) { + set_query(x); + return compute_code_distance(tmp.data(), code); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_code_distance(tmp.data(), code); + } +}; + +// explicit instantiation + +template ScalarQuantizer::SQuantizer* select_quantizer_1( + QuantizerType qtype, + size_t d, + const std::vector& trained); + +template SQDistanceComputer* select_distance_computer_1( + MetricType metric_type, + QuantizerType qtype, + size_t d, + const std::vector& trained); + +template InvertedListScanner* sel0_InvertedListScanner( + MetricType mt, + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool by_residual); + +} // namespace scalar_quantizer + +} // namespace faiss diff --git a/faiss/impl/scalar_quantizer/impl-neon.cpp b/faiss/impl/scalar_quantizer/impl-neon.cpp new file mode 100644 index 0000000000..7daeeae5b6 --- /dev/null +++ b/faiss/impl/scalar_quantizer/impl-neon.cpp @@ -0,0 +1,375 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#if defined(__aarch64__) +#if defined(__GNUC__) && __GNUC__ < 8 +#warning \ + "Cannot enable NEON optimizations in scalar quantizer if the compiler is GCC<8" +#else +#define USE_NEON +#endif +#endif + +namespace faiss { + +namespace scalar_quantizer { +/******************************** Codec specializations */ + +template <> +struct Codec8bit { + static FAISS_ALWAYS_INLINE decode_8_components(const uint8_t* code, int i) { + float32_t result[8] = {}; + for (size_t j = 0; j < 8; j++) { + result[j] = decode_component(code, i + j); + } + float32x4_t res1 = vld1q_f32(result); + float32x4_t res2 = vld1q_f32(result + 4); + return simd8float32(float32x4x2_t{res1, res2}); + } +}; + +template <> +struct Codec4bit { + static FAISS_ALWAYS_INLINE simd8float32 + decode_8_components(const uint8_t* code, int i) { + float32_t result[8] = {}; + for (size_t j = 0; j < 8; j++) { + result[j] = decode_component(code, i + j); + } + float32x4_t res1 = vld1q_f32(result); + float32x4_t res2 = vld1q_f32(result + 4); + return simd8float32({res1, res2}); + } +}; + +template <> +struct Codec6bit { + static FAISS_ALWAYS_INLINE simd8float32 + decode_8_components(const uint8_t* code, int i) { + float32_t result[8] = {}; + for (size_t j = 0; j < 8; j++) { + result[j] = decode_component(code, i + j); + } + float32x4_t res1 = vld1q_f32(result); + float32x4_t res2 = vld1q_f32(result + 4); + return simd8float32(float32x4x2_t({res1, res2})); + } +}; +/******************************** Quantizatoin specializations */ + +template +struct QuantizerT + : QuantizerT { + QuantizerT(size_t d, const std::vector& trained) + : QuantizerT(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + float32x4x2_t xi = Codec::decode_8_components(code, i); + return simd8float32(float32x4x2_t( + {vfmaq_f32( + vdupq_n_f32(this->vmin), + xi.val[0], + vdupq_n_f32(this->vdiff)), + vfmaq_f32( + vdupq_n_f32(this->vmin), + xi.val[1], + vdupq_n_f32(this->vdiff))})); + } +}; + +template +struct QuantizerT + : QuantizerT { + QuantizerT(size_t d, const std::vector& trained) + : QuantizerT(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + float32x4x2_t xi = Codec::decode_8_components(code, i).data; + + float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i); + float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i); + + return simd8float32( + {vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]), + vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1])}); + } +}; + +template <> +struct QuantizerFP16 : QuantizerFP16 { + QuantizerFP16(size_t d, const std::vector& trained) + : QuantizerFP16<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i)); + return simd8float32( + {vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])), + vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1]))}); + } +}; + +template <> +struct QuantizerBF16 : QuantizerBF16 { + QuantizerBF16(size_t d, const std::vector& trained) + : QuantizerBF16<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i)); + return simd8float32( + {vreinterpretq_f32_u32( + vshlq_n_u32(vmovl_u16(codei.val[0]), 16)), + vreinterpretq_f32_u32( + vshlq_n_u32(vmovl_u16(codei.val[1]), 16))}); + } +}; + +template <> +struct Quantizer8bitDirect + : Quantizer8bitDirect { + Quantizer8bitDirect(size_t d, const std::vector& trained) + : Quantizer8bitDirect<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i)); + uint16x8_t y8 = vmovl_u8(x8); + uint16x4_t y8_0 = vget_low_u16(y8); + uint16x4_t y8_1 = vget_high_u16(y8); + + // convert uint16 -> uint32 -> fp32 + return simd8float32( + {vcvtq_f32_u32(vmovl_u16(y8_0)), + vcvtq_f32_u32(vmovl_u16(y8_1))}); + } +}; + +template <> +struct Quantizer8bitDirectSigned + : Quantizer8bitDirectSigned { + Quantizer8bitDirectSigned(size_t d, const std::vector& trained) + : Quantizer8bitDirectSigned<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i)); + uint16x8_t y8 = vmovl_u8(x8); // convert uint8 -> uint16 + uint16x4_t y8_0 = vget_low_u16(y8); + uint16x4_t y8_1 = vget_high_u16(y8); + + float32x4_t z8_0 = vcvtq_f32_u32( + vmovl_u16(y8_0)); // convert uint16 -> uint32 -> fp32 + float32x4_t z8_1 = vcvtq_f32_u32(vmovl_u16(y8_1)); + + // subtract 128 to convert into signed numbers + return simd8float32( + {vsubq_f32(z8_0, vmovq_n_f32(128.0)), + vsubq_f32(z8_1, vmovq_n_f32(128.0))}); + } +}; + +/****************************************** Specialization of similarities */ + +template <> +struct SimilarityL2 { + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_L2; + + const float *y, *yi; + explicit SimilarityL2(const float* y) : y(y) {} + simd8float32 accu8; + + FAISS_ALWAYS_INLINE void begin_8() { + accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; + yi = y; + } + + FAISS_ALWAYS_INLINE void add_8_components(simd8float32 x) { + float32x4x2_t yiv = vld1q_f32_x2(yi); + yi += 8; + + float32x4_t sub0 = vsubq_f32(yiv.val[0], x.val[0]); + float32x4_t sub1 = vsubq_f32(yiv.val[1], x.val[1]); + + float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0); + float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1); + + accu8 = simd8float32({accu8_0, accu8_1}); + } + + FAISS_ALWAYS_INLINE void add_8_components_2( + simd8float32 x, + simd8float32 y) { + float32x4_t sub0 = vsubq_f32(y.val[0], x.val[0]); + float32x4_t sub1 = vsubq_f32(y.val[1], x.val[1]); + + float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0); + float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1); + + accu8 = simd8float32({accu8_0, accu8_1}); + } + + FAISS_ALWAYS_INLINE float result_8() { + float32x4_t sum_0 = vpaddq_f32(accu8.data.val[0], accu8.data.val[0]); + float32x4_t sum_1 = vpaddq_f32(accu8.data.val[1], accu8.data.val[1]); + + float32x4_t sum2_0 = vpaddq_f32(sum_0, sum_0); + float32x4_t sum2_1 = vpaddq_f32(sum_1, sum_1); + return vgetq_lane_f32(sum2_0, 0) + vgetq_lane_f32(sum2_1, 0); + } +}; + +template <> +struct SimilarityIP { + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + + const float *y, *yi; + + explicit SimilarityIP(const float* y) : y(y) {} + float32x4x2_t accu8; + + FAISS_ALWAYS_INLINE void begin_8() { + accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; + yi = y; + } + + FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) { + float32x4x2_t yiv = vld1q_f32_x2(yi); + yi += 8; + + float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]); + float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]); + accu8 = {accu8_0, accu8_1}; + } + + FAISS_ALWAYS_INLINE void add_8_components_2( + float32x4x2_t x1, + float32x4x2_t x2) { + float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]); + float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]); + accu8 = {accu8_0, accu8_1}; + } + + FAISS_ALWAYS_INLINE float result_8() { + float32x4x2_t sum = { + vpaddq_f32(accu8.val[0], accu8.val[0]), + vpaddq_f32(accu8.val[1], accu8.val[1])}; + + float32x4x2_t sum2 = { + vpaddq_f32(sum.val[0], sum.val[0]), + vpaddq_f32(sum.val[1], sum.val[1])}; + return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0); + } +}; + +/****************************************** Specialization of distance computers + */ + +// this is the same code as the AVX2 version... Possible to mutualize? +template +struct DCTemplate : SQDistanceComputer { + using Sim = Similarity; + + Quantizer quant; + + DCTemplate(size_t d, const std::vector& trained) + : quant(d, trained) {} + + float compute_distance(const float* x, const uint8_t* code) const { + Similarity sim(x); + sim.begin_8(); + for (size_t i = 0; i < quant.d; i += 8) { + simd8float32 xi = quant.reconstruct_8_components(code, i); + sim.add_8_components(xi); + } + return sim.result_8(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + Similarity sim(nullptr); + sim.begin_8(); + for (size_t i = 0; i < quant.d; i += 8) { + simd8float32 x1 = quant.reconstruct_8_components(code1, i); + simd8float32 x2 = quant.reconstruct_8_components(code2, i); + sim.add_8_components_2(x1, x2); + } + return sim.result_8(); + } + + void set_query(const float* x) final { + q = x; + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_distance(q, code); + } +}; + +template +struct DistanceComputerByte + : SQDistanceComputer { + using Sim = Similarity; + + int d; + std::vector tmp; + + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + + int compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + int accu = 0; + for (int i = 0; i < d; i++) { + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + accu += int(code1[i]) * code2[i]; + } else { + int diff = int(code1[i]) - code2[i]; + accu += diff * diff; + } + } + return accu; + } + + void set_query(const float* x) final { + for (int i = 0; i < d; i++) { + tmp[i] = int(x[i]); + } + } + + int compute_distance(const float* x, const uint8_t* code) { + set_query(x); + return compute_code_distance(tmp.data(), code); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_code_distance(tmp.data(), code); + } +}; + +} // namespace scalar_quantizer + +} // namespace faiss diff --git a/faiss/impl/scalar_quantizer/quantizers.h b/faiss/impl/scalar_quantizer/quantizers.h new file mode 100644 index 0000000000..a9865722d6 --- /dev/null +++ b/faiss/impl/scalar_quantizer/quantizers.h @@ -0,0 +1,293 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include + +#include +#include + +namespace faiss { + +namespace scalar_quantizer { + +using QuantizerType = ScalarQuantizer::QuantizerType; + +/******************************************************************* + * Quantizer: normalizes scalar vector components, then passes them + * through a codec + *******************************************************************/ + +enum class QScaling { UNIFORM = 0, NON_UNIFORM = 1 }; + +template +struct QuantizerT {}; + +template +struct QuantizerT + : ScalarQuantizer::SQuantizer { + const size_t d; + const float vmin, vdiff; + + QuantizerT(size_t d, const std::vector& trained) + : d(d), vmin(trained[0]), vdiff(trained[1]) {} + + void encode_vector(const float* x, uint8_t* code) const final { + for (size_t i = 0; i < d; i++) { + float xi = 0; + if (vdiff != 0) { + xi = (x[i] - vmin) / vdiff; + if (xi < 0) { + xi = 0; + } + if (xi > 1.0) { + xi = 1.0; + } + } + Codec::encode_component(xi, code, i); + } + } + + void decode_vector(const uint8_t* code, float* x) const final { + for (size_t i = 0; i < d; i++) { + float xi = Codec::decode_component(code, i); + x[i] = vmin + xi * vdiff; + } + } + + FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) + const { + float xi = Codec::decode_component(code, i); + return vmin + xi * vdiff; + } +}; + +template +struct QuantizerT + : ScalarQuantizer::SQuantizer { + const size_t d; + const float *vmin, *vdiff; + + QuantizerT(size_t d, const std::vector& trained) + : d(d), vmin(trained.data()), vdiff(trained.data() + d) {} + + void encode_vector(const float* x, uint8_t* code) const final { + for (size_t i = 0; i < d; i++) { + float xi = 0; + if (vdiff[i] != 0) { + xi = (x[i] - vmin[i]) / vdiff[i]; + if (xi < 0) { + xi = 0; + } + if (xi > 1.0) { + xi = 1.0; + } + } + Codec::encode_component(xi, code, i); + } + } + + void decode_vector(const uint8_t* code, float* x) const final { + for (size_t i = 0; i < d; i++) { + float xi = Codec::decode_component(code, i); + x[i] = vmin[i] + xi * vdiff[i]; + } + } + + FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) + const { + float xi = Codec::decode_component(code, i); + return vmin[i] + xi * vdiff[i]; + } +}; + +/******************************************************************* + * Quantizers that are not based on codecs + *******************************************************************/ + +/******************************************************************* + * FP16 quantizer + *******************************************************************/ + +template +struct QuantizerFP16 {}; + +template <> +struct QuantizerFP16 : ScalarQuantizer::SQuantizer { + const size_t d; + + QuantizerFP16(size_t d, const std::vector& /* unused */) : d(d) {} + + void encode_vector(const float* x, uint8_t* code) const final { + for (size_t i = 0; i < d; i++) { + ((uint16_t*)code)[i] = encode_fp16(x[i]); + } + } + + void decode_vector(const uint8_t* code, float* x) const final { + for (size_t i = 0; i < d; i++) { + x[i] = decode_fp16(((uint16_t*)code)[i]); + } + } + + FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) + const { + return decode_fp16(((uint16_t*)code)[i]); + } +}; + +/******************************************************************* + * BF16 quantizer + *******************************************************************/ + +template +struct QuantizerBF16 {}; + +template <> +struct QuantizerBF16 : ScalarQuantizer::SQuantizer { + const size_t d; + + QuantizerBF16(size_t d, const std::vector& /* unused */) : d(d) {} + + void encode_vector(const float* x, uint8_t* code) const final { + for (size_t i = 0; i < d; i++) { + ((uint16_t*)code)[i] = encode_bf16(x[i]); + } + } + + void decode_vector(const uint8_t* code, float* x) const final { + for (size_t i = 0; i < d; i++) { + x[i] = decode_bf16(((uint16_t*)code)[i]); + } + } + + FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) + const { + return decode_bf16(((uint16_t*)code)[i]); + } +}; + +/******************************************************************* + * 8bit_direct quantizer + *******************************************************************/ + +template +struct Quantizer8bitDirect {}; + +template <> +struct Quantizer8bitDirect : ScalarQuantizer::SQuantizer { + const size_t d; + + Quantizer8bitDirect(size_t d, const std::vector& /* unused */) + : d(d) {} + + void encode_vector(const float* x, uint8_t* code) const final { + for (size_t i = 0; i < d; i++) { + code[i] = (uint8_t)x[i]; + } + } + + void decode_vector(const uint8_t* code, float* x) const final { + for (size_t i = 0; i < d; i++) { + x[i] = code[i]; + } + } + + FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) + const { + return code[i]; + } +}; + +/******************************************************************* + * 8bit_direct_signed quantizer + *******************************************************************/ + +template +struct Quantizer8bitDirectSigned {}; + +template <> +struct Quantizer8bitDirectSigned + : ScalarQuantizer::SQuantizer { + const size_t d; + + Quantizer8bitDirectSigned(size_t d, const std::vector& /* unused */) + : d(d) {} + + void encode_vector(const float* x, uint8_t* code) const final { + for (size_t i = 0; i < d; i++) { + code[i] = (uint8_t)(x[i] + 128); + } + } + + void decode_vector(const uint8_t* code, float* x) const final { + for (size_t i = 0; i < d; i++) { + x[i] = code[i] - 128; + } + } + + FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) + const { + return code[i] - 128; + } +}; + +template +ScalarQuantizer::SQuantizer* select_quantizer_1( + QuantizerType qtype, + size_t d, + const std::vector& trained) { + // constexpr SIMDLevel SL = INSTANCIATE_SIMD_LEVEL; + constexpr QScaling NU = QScaling::NON_UNIFORM; + constexpr QScaling U = QScaling::UNIFORM; + switch (qtype) { + case ScalarQuantizer::QT_8bit: + return new QuantizerT, NU, SL>(d, trained); + + case ScalarQuantizer::QT_6bit: + return new QuantizerT, NU, SL>(d, trained); + case ScalarQuantizer::QT_4bit: + return new QuantizerT, NU, SL>(d, trained); + case ScalarQuantizer::QT_8bit_uniform: + return new QuantizerT, U, SL>(d, trained); + case ScalarQuantizer::QT_4bit_uniform: + return new QuantizerT, U, SL>(d, trained); + case ScalarQuantizer::QT_fp16: + return new QuantizerFP16(d, trained); + case ScalarQuantizer::QT_bf16: + return new QuantizerBF16(d, trained); + case ScalarQuantizer::QT_8bit_direct: + return new Quantizer8bitDirect(d, trained); + case ScalarQuantizer::QT_8bit_direct_signed: + return new Quantizer8bitDirectSigned(d, trained); + default: + FAISS_THROW_MSG("unknown qtype"); + return nullptr; + } +} + +// prevent implicit instanciation +extern template ScalarQuantizer::SQuantizer* select_quantizer_1< + SIMDLevel::AVX2>( + QuantizerType qtype, + size_t d, + const std::vector& trained); + +extern template ScalarQuantizer::SQuantizer* select_quantizer_1< + SIMDLevel::AVX512F>( + QuantizerType qtype, + size_t d, + const std::vector& trained); + +} // namespace scalar_quantizer + +} // namespace faiss diff --git a/faiss/impl/scalar_quantizer/scanners.h b/faiss/impl/scalar_quantizer/scanners.h new file mode 100644 index 0000000000..17f1608582 --- /dev/null +++ b/faiss/impl/scalar_quantizer/scanners.h @@ -0,0 +1,356 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace faiss { + +namespace scalar_quantizer { + +/******************************************************************* + * IndexScalarQuantizer/IndexIVFScalarQuantizer scanner object + * + * It is an InvertedListScanner, but is designed to work with + * IndexScalarQuantizer as well. + ********************************************************************/ + +template +struct IVFSQScannerIP : InvertedListScanner { + DCClass dc; + bool by_residual; + + float accu0; /// added to all distances + + IVFSQScannerIP( + int d, + const std::vector& trained, + size_t code_size, + bool store_pairs, + const IDSelector* sel, + bool by_residual) + : dc(d, trained), by_residual(by_residual), accu0(0) { + this->store_pairs = store_pairs; + this->sel = sel; + this->code_size = code_size; + this->keep_max = true; + } + + void set_query(const float* query) override { + dc.set_query(query); + } + + void set_list(idx_t list_no, float coarse_dis) override { + this->list_no = list_no; + accu0 = by_residual ? coarse_dis : 0; + } + + float distance_to_code(const uint8_t* code) const final { + return accu0 + dc.query_to_code(code); + } + + size_t scan_codes( + size_t list_size, + const uint8_t* codes, + const idx_t* ids, + float* simi, + idx_t* idxi, + size_t k) const override { + size_t nup = 0; + + for (size_t j = 0; j < list_size; j++, codes += code_size) { + if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) { + continue; + } + + float accu = accu0 + dc.query_to_code(codes); + + if (accu > simi[0]) { + int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; + minheap_replace_top(k, simi, idxi, accu, id); + nup++; + } + } + return nup; + } + + void scan_codes_range( + size_t list_size, + const uint8_t* codes, + const idx_t* ids, + float radius, + RangeQueryResult& res) const override { + for (size_t j = 0; j < list_size; j++, codes += code_size) { + if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) { + continue; + } + + float accu = accu0 + dc.query_to_code(codes); + if (accu > radius) { + int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; + res.add(accu, id); + } + } + } +}; + +/* use_sel = 0: don't check selector + * = 1: check on ids[j] + * = 2: check in j directly (normally ids is nullptr and store_pairs) + */ +template +struct IVFSQScannerL2 : InvertedListScanner { + DCClass dc; + + bool by_residual; + const Index* quantizer; + const float* x; /// current query + + std::vector tmp; + + IVFSQScannerL2( + int d, + const std::vector& trained, + size_t code_size, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool by_residual) + : dc(d, trained), + by_residual(by_residual), + quantizer(quantizer), + x(nullptr), + tmp(d) { + this->store_pairs = store_pairs; + this->sel = sel; + this->code_size = code_size; + } + + void set_query(const float* query) override { + x = query; + if (!quantizer) { + dc.set_query(query); + } + } + + void set_list(idx_t list_no, float /*coarse_dis*/) override { + this->list_no = list_no; + if (by_residual) { + // shift of x_in wrt centroid + quantizer->compute_residual(x, tmp.data(), list_no); + dc.set_query(tmp.data()); + } else { + dc.set_query(x); + } + } + + float distance_to_code(const uint8_t* code) const final { + return dc.query_to_code(code); + } + + size_t scan_codes( + size_t list_size, + const uint8_t* codes, + const idx_t* ids, + float* simi, + idx_t* idxi, + size_t k) const override { + size_t nup = 0; + for (size_t j = 0; j < list_size; j++, codes += code_size) { + if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) { + continue; + } + + float dis = dc.query_to_code(codes); + + if (dis < simi[0]) { + int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; + maxheap_replace_top(k, simi, idxi, dis, id); + nup++; + } + } + return nup; + } + + void scan_codes_range( + size_t list_size, + const uint8_t* codes, + const idx_t* ids, + float radius, + RangeQueryResult& res) const override { + for (size_t j = 0; j < list_size; j++, codes += code_size) { + if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) { + continue; + } + + float dis = dc.query_to_code(codes); + if (dis < radius) { + int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; + res.add(dis, id); + } + } + } +}; + +/* Select the right implementation by dispatching to templatized versions */ + +template +InvertedListScanner* sel3_InvertedListScanner( + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool r) { + if (DCClass::Sim::metric_type == METRIC_L2) { + return new IVFSQScannerL2( + sq->d, + sq->trained, + sq->code_size, + quantizer, + store_pairs, + sel, + r); + } else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) { + return new IVFSQScannerIP( + sq->d, sq->trained, sq->code_size, store_pairs, sel, r); + } else { + FAISS_THROW_MSG("unsupported metric type"); + } +} + +template +InvertedListScanner* sel2_InvertedListScanner( + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool r) { + if (sel) { + if (store_pairs) { + return sel3_InvertedListScanner( + sq, quantizer, store_pairs, sel, r); + } else { + return sel3_InvertedListScanner( + sq, quantizer, store_pairs, sel, r); + } + } else { + return sel3_InvertedListScanner( + sq, quantizer, store_pairs, sel, r); + } +} + +template +InvertedListScanner* sel12_InvertedListScanner( + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool r) { + constexpr SIMDLevel SL = Similarity::SIMD_LEVEL; + using QuantizerClass = QuantizerT; + using DCClass = DCTemplate; + return sel2_InvertedListScanner( + sq, quantizer, store_pairs, sel, r); +} + +template +InvertedListScanner* sel1_InvertedListScanner( + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool r) { + constexpr SIMDLevel SL = Sim::SIMD_LEVEL; + constexpr QScaling NU = QScaling::NON_UNIFORM; + constexpr QScaling U = QScaling::UNIFORM; + + switch (sq->qtype) { + case ScalarQuantizer::QT_8bit_uniform: + return sel12_InvertedListScanner, U>( + sq, quantizer, store_pairs, sel, r); + case ScalarQuantizer::QT_4bit_uniform: + return sel12_InvertedListScanner, U>( + sq, quantizer, store_pairs, sel, r); + case ScalarQuantizer::QT_8bit: + return sel12_InvertedListScanner, NU>( + sq, quantizer, store_pairs, sel, r); + case ScalarQuantizer::QT_4bit: + return sel12_InvertedListScanner, NU>( + sq, quantizer, store_pairs, sel, r); + case ScalarQuantizer::QT_6bit: + return sel12_InvertedListScanner, NU>( + sq, quantizer, store_pairs, sel, r); + case ScalarQuantizer::QT_fp16: + return sel2_InvertedListScanner< + DCTemplate, Sim, SL>>( + sq, quantizer, store_pairs, sel, r); + + case ScalarQuantizer::QT_bf16: + return sel2_InvertedListScanner< + DCTemplate, Sim, SL>>( + sq, quantizer, store_pairs, sel, r); + case ScalarQuantizer::QT_8bit_direct: + return sel2_InvertedListScanner< + DCTemplate, Sim, SL>>( + sq, quantizer, store_pairs, sel, r); + + case ScalarQuantizer::QT_8bit_direct_signed: + return sel2_InvertedListScanner< + DCTemplate, Sim, SL>>( + sq, quantizer, store_pairs, sel, r); + default: + FAISS_THROW_MSG("unknown qtype"); + return nullptr; + } +} + +template +InvertedListScanner* sel0_InvertedListScanner( + MetricType mt, + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool by_residual) { + if (mt == METRIC_L2) { + return sel1_InvertedListScanner>( + sq, quantizer, store_pairs, sel, by_residual); + } else if (mt == METRIC_INNER_PRODUCT) { + return sel1_InvertedListScanner>( + sq, quantizer, store_pairs, sel, by_residual); + } else { + FAISS_THROW_MSG("unsupported metric type"); + } +} + +// prevent implicit instantiation of the template when there are +// SIMD optimized versions... +extern template InvertedListScanner* sel0_InvertedListScanner( + MetricType mt, + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool by_residual); + +extern template InvertedListScanner* sel0_InvertedListScanner< + SIMDLevel::AVX512F>( + MetricType mt, + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool by_residual); + +} // namespace scalar_quantizer +} // namespace faiss diff --git a/faiss/impl/scalar_quantizer/training.cpp b/faiss/impl/scalar_quantizer/training.cpp new file mode 100644 index 0000000000..23c51384fd --- /dev/null +++ b/faiss/impl/scalar_quantizer/training.cpp @@ -0,0 +1,188 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace faiss { + +namespace scalar_quantizer { +/******************************************************************* + * Quantizer range training + */ + +static float sqr(float x) { + return x * x; +} + +void train_Uniform( + RangeStat rs, + float rs_arg, + idx_t n, + int k, + const float* x, + std::vector& trained) { + trained.resize(2); + float& vmin = trained[0]; + float& vmax = trained[1]; + + if (rs == ScalarQuantizer::RS_minmax) { + vmin = HUGE_VAL; + vmax = -HUGE_VAL; + for (size_t i = 0; i < n; i++) { + if (x[i] < vmin) + vmin = x[i]; + if (x[i] > vmax) + vmax = x[i]; + } + float vexp = (vmax - vmin) * rs_arg; + vmin -= vexp; + vmax += vexp; + } else if (rs == ScalarQuantizer::RS_meanstd) { + double sum = 0, sum2 = 0; + for (size_t i = 0; i < n; i++) { + sum += x[i]; + sum2 += x[i] * x[i]; + } + float mean = sum / n; + float var = sum2 / n - mean * mean; + float std = var <= 0 ? 1.0 : sqrt(var); + + vmin = mean - std * rs_arg; + vmax = mean + std * rs_arg; + } else if (rs == ScalarQuantizer::RS_quantiles) { + std::vector x_copy(n); + memcpy(x_copy.data(), x, n * sizeof(*x)); + // TODO just do a quickselect + std::sort(x_copy.begin(), x_copy.end()); + int o = int(rs_arg * n); + if (o < 0) + o = 0; + if (o > n - o) + o = n / 2; + vmin = x_copy[o]; + vmax = x_copy[n - 1 - o]; + + } else if (rs == ScalarQuantizer::RS_optim) { + float a, b; + float sx = 0; + { + vmin = HUGE_VAL, vmax = -HUGE_VAL; + for (size_t i = 0; i < n; i++) { + if (x[i] < vmin) + vmin = x[i]; + if (x[i] > vmax) + vmax = x[i]; + sx += x[i]; + } + b = vmin; + a = (vmax - vmin) / (k - 1); + } + int verbose = false; + int niter = 2000; + float last_err = -1; + int iter_last_err = 0; + for (int it = 0; it < niter; it++) { + float sn = 0, sn2 = 0, sxn = 0, err1 = 0; + + for (idx_t i = 0; i < n; i++) { + float xi = x[i]; + float ni = floor((xi - b) / a + 0.5); + if (ni < 0) + ni = 0; + if (ni >= k) + ni = k - 1; + err1 += sqr(xi - (ni * a + b)); + sn += ni; + sn2 += ni * ni; + sxn += ni * xi; + } + + if (err1 == last_err) { + iter_last_err++; + if (iter_last_err == 16) + break; + } else { + last_err = err1; + iter_last_err = 0; + } + + float det = sqr(sn) - sn2 * n; + + b = (sn * sxn - sn2 * sx) / det; + a = (sn * sx - n * sxn) / det; + if (verbose) { + printf("it %d, err1=%g \r", it, err1); + fflush(stdout); + } + } + if (verbose) + printf("\n"); + + vmin = b; + vmax = b + a * (k - 1); + + } else { + FAISS_THROW_MSG("Invalid qtype"); + } + vmax -= vmin; +} + +void train_NonUniform( + RangeStat rs, + float rs_arg, + idx_t n, + int d, + int k, + const float* x, + std::vector& trained) { + trained.resize(2 * d); + float* vmin = trained.data(); + float* vmax = trained.data() + d; + if (rs == ScalarQuantizer::RS_minmax) { + memcpy(vmin, x, sizeof(*x) * d); + memcpy(vmax, x, sizeof(*x) * d); + for (size_t i = 1; i < n; i++) { + const float* xi = x + i * d; + for (size_t j = 0; j < d; j++) { + if (xi[j] < vmin[j]) + vmin[j] = xi[j]; + if (xi[j] > vmax[j]) + vmax[j] = xi[j]; + } + } + float* vdiff = vmax; + for (size_t j = 0; j < d; j++) { + float vexp = (vmax[j] - vmin[j]) * rs_arg; + vmin[j] -= vexp; + vmax[j] += vexp; + vdiff[j] = vmax[j] - vmin[j]; + } + } else { + // transpose + std::vector xt(n * d); + for (size_t i = 1; i < n; i++) { + const float* xi = x + i * d; + for (size_t j = 0; j < d; j++) { + xt[j * n + i] = xi[j]; + } + } + std::vector trained_d(2); +#pragma omp parallel for + for (int j = 0; j < d; j++) { + train_Uniform(rs, rs_arg, n, k, xt.data() + j * n, trained_d); + vmin[j] = trained_d[0]; + vmax[j] = trained_d[1]; + } + } +} + +} // namespace scalar_quantizer + +} // namespace faiss diff --git a/faiss/impl/scalar_quantizer/training.h b/faiss/impl/scalar_quantizer/training.h new file mode 100644 index 0000000000..9eeb39b926 --- /dev/null +++ b/faiss/impl/scalar_quantizer/training.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +/******************************************************************* + * Quantizer range training for the scalar quantizer. This is independent of the + * searching code and needs not to be very optimized (scalar quantizer training + * is very efficient). + */ + +#include + +namespace faiss { + +namespace scalar_quantizer { + +using RangeStat = ScalarQuantizer::RangeStat; + +void train_Uniform( + RangeStat rs, + float rs_arg, + idx_t n, + int k, + const float* x, + std::vector& trained); + +void train_NonUniform( + RangeStat rs, + float rs_arg, + idx_t n, + int d, + int k, + const float* x, + std::vector& trained); +} // namespace scalar_quantizer + +} // namespace faiss diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig index fa22f76d09..5ecfa0e39c 100644 --- a/faiss/python/swigfaiss.swig +++ b/faiss/python/swigfaiss.swig @@ -105,7 +105,7 @@ typedef uint64_t size_t; #include #include #include -#include +#include #include #include @@ -324,6 +324,9 @@ namespace std { %include %include +%ignore faiss::SIMDConfig::level_names; +%include + int get_num_gpus(); void gpu_profiler_start(); void gpu_profiler_stop(); @@ -618,7 +621,7 @@ void gpu_sync_all_devices() // NOTE(matthijs) let's not go into wrapping simdlib struct faiss::simd16uint16 {}; -%include +%include %include %include %include diff --git a/faiss/utils/distances.h b/faiss/utils/distances.h index 80d2cfc699..696b246b90 100644 --- a/faiss/utils/distances.h +++ b/faiss/utils/distances.h @@ -15,6 +15,7 @@ #include #include +#include namespace faiss { @@ -27,9 +28,15 @@ struct IDSelector; /// Squared L2 distance between two vectors float fvec_L2sqr(const float* x, const float* y, size_t d); +template +float fvec_L2sqr(const float* x, const float* y, size_t d); + /// inner product float fvec_inner_product(const float* x, const float* y, size_t d); +template +float fvec_inner_product(const float* x, const float* y, size_t d); + /// L1 distance float fvec_L1(const float* x, const float* y, size_t d); @@ -138,6 +145,9 @@ size_t fvec_L2sqr_ny_nearest_y_transposed( /** squared norm of a vector */ float fvec_norm_L2sqr(const float* x, size_t d); +template +float fvec_norm_L2sqr(const float* x, size_t d); + /** compute the L2 norms for a set of vectors * * @param norms output norms, size nx @@ -473,6 +483,10 @@ void compute_PQ_dis_tables_dsub2( */ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c); +/* specialized version for each SIMD level */ +template +void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c); + /** same as fvec_madd, also return index of the min of the result table * @return index of the min of table c */ diff --git a/faiss/utils/distances_simd.cpp b/faiss/utils/distances_simd.cpp index 1990e46aae..e829ae54ec 100644 --- a/faiss/utils/distances_simd.cpp +++ b/faiss/utils/distances_simd.cpp @@ -19,6 +19,9 @@ #include #include +#define AUTOVEC_LEVEL SIMDLevel::NONE +#include + #ifdef __SSE3__ #include #endif @@ -191,43 +194,19 @@ void fvec_inner_products_ny_ref( * Autovectorized implementations */ -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN -float fvec_inner_product(const float* x, const float* y, size_t d) { - float res = 0.F; - FAISS_PRAGMA_IMPRECISE_LOOP - for (size_t i = 0; i != d; ++i) { - res += x[i] * y[i]; - } - return res; -} -FAISS_PRAGMA_IMPRECISE_FUNCTION_END +// dispatching functions -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN float fvec_norm_L2sqr(const float* x, size_t d) { - // the double in the _ref is suspected to be a typo. Some of the manual - // implementations this replaces used float. - float res = 0; - FAISS_PRAGMA_IMPRECISE_LOOP - for (size_t i = 0; i != d; ++i) { - res += x[i] * x[i]; - } - - return res; + DISPATCH_SIMDLevel(fvec_norm_L2sqr, x, d); } -FAISS_PRAGMA_IMPRECISE_FUNCTION_END -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN float fvec_L2sqr(const float* x, const float* y, size_t d) { - size_t i; - float res = 0; - FAISS_PRAGMA_IMPRECISE_LOOP - for (i = 0; i < d; i++) { - const float tmp = x[i] - y[i]; - res += tmp * tmp; - } - return res; + DISPATCH_SIMDLevel(fvec_L2sqr, x, y, d); +} + +float fvec_inner_product(const float* x, const float* y, size_t d) { + DISPATCH_SIMDLevel(fvec_inner_product, x, y, d); } -FAISS_PRAGMA_IMPRECISE_FUNCTION_END /// Special version of inner product that computes 4 distances /// between x and yi @@ -3246,7 +3225,8 @@ void fvec_inner_products_ny( * heavily optimized table computations ***************************************************************************/ -[[maybe_unused]] static inline void fvec_madd_ref( +template <> +void fvec_madd( size_t n, const float* a, float bf, @@ -3256,94 +3236,6 @@ void fvec_inner_products_ny( c[i] = a[i] + bf * b[i]; } -#if defined(__AVX512F__) - -static inline void fvec_madd_avx512( - const size_t n, - const float* __restrict a, - const float bf, - const float* __restrict b, - float* __restrict c) { - const size_t n16 = n / 16; - const size_t n_for_masking = n % 16; - - const __m512 bfmm = _mm512_set1_ps(bf); - - size_t idx = 0; - for (idx = 0; idx < n16 * 16; idx += 16) { - const __m512 ax = _mm512_loadu_ps(a + idx); - const __m512 bx = _mm512_loadu_ps(b + idx); - const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); - _mm512_storeu_ps(c + idx, abmul); - } - - if (n_for_masking > 0) { - const __mmask16 mask = (1 << n_for_masking) - 1; - - const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx); - const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx); - const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); - _mm512_mask_storeu_ps(c + idx, mask, abmul); - } -} - -#elif defined(__AVX2__) - -static inline void fvec_madd_avx2( - const size_t n, - const float* __restrict a, - const float bf, - const float* __restrict b, - float* __restrict c) { - // - const size_t n8 = n / 8; - const size_t n_for_masking = n % 8; - - const __m256 bfmm = _mm256_set1_ps(bf); - - size_t idx = 0; - for (idx = 0; idx < n8 * 8; idx += 8) { - const __m256 ax = _mm256_loadu_ps(a + idx); - const __m256 bx = _mm256_loadu_ps(b + idx); - const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); - _mm256_storeu_ps(c + idx, abmul); - } - - if (n_for_masking > 0) { - __m256i mask; - switch (n_for_masking) { - case 1: - mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1); - break; - case 2: - mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1); - break; - case 3: - mask = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1); - break; - case 4: - mask = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1); - break; - case 5: - mask = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1); - break; - case 6: - mask = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1); - break; - case 7: - mask = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1); - break; - } - - const __m256 ax = _mm256_maskload_ps(a + idx, mask); - const __m256 bx = _mm256_maskload_ps(b + idx, mask); - const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); - _mm256_maskstore_ps(c + idx, mask, abmul); - } -} - -#endif - #ifdef __SSE3__ [[maybe_unused]] static inline void fvec_madd_sse( @@ -3367,16 +3259,7 @@ static inline void fvec_madd_avx2( } void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) { -#ifdef __AVX512F__ - fvec_madd_avx512(n, a, bf, b, c); -#elif __AVX2__ - fvec_madd_avx2(n, a, bf, b, c); -#else - if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0) - fvec_madd_sse(n, a, bf, b, c); - else - fvec_madd_ref(n, a, bf, b, c); -#endif + DISPATCH_SIMDLevel(fvec_madd, n, a, bf, b, c); } #elif defined(__ARM_FEATURE_SVE) diff --git a/faiss/utils/simd_impl/distances_autovec-inl.h b/faiss/utils/simd_impl/distances_autovec-inl.h new file mode 100644 index 0000000000..ace6dccccb --- /dev/null +++ b/faiss/utils/simd_impl/distances_autovec-inl.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +namespace faiss { + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_norm_L2sqr(const float* x, size_t d) { + // the double in the _ref is suspected to be a typo. Some of the manual + // implementations this replaces used float. + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i != d; ++i) { + res += x[i] * x[i]; + } + + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_L2sqr(const float* x, const float* y, size_t d) { + size_t i; + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (i = 0; i < d; i++) { + const float tmp = x[i] - y[i]; + res += tmp * tmp; + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_inner_product( + const float* x, + const float* y, + size_t d) { + float res = 0.F; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i != d; ++i) { + res += x[i] * y[i]; + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_avx2.cpp b/faiss/utils/simd_impl/distances_avx2.cpp new file mode 100644 index 0000000000..cd52c470f9 --- /dev/null +++ b/faiss/utils/simd_impl/distances_avx2.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#define AUTOVEC_LEVEL SIMDLevel::AVX2 +#include + +namespace faiss { + +template <> +void fvec_madd( + const size_t n, + const float* __restrict a, + const float bf, + const float* __restrict b, + float* __restrict c) { + // + const size_t n8 = n / 8; + const size_t n_for_masking = n % 8; + + const __m256 bfmm = _mm256_set1_ps(bf); + + size_t idx = 0; + for (idx = 0; idx < n8 * 8; idx += 8) { + const __m256 ax = _mm256_loadu_ps(a + idx); + const __m256 bx = _mm256_loadu_ps(b + idx); + const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); + _mm256_storeu_ps(c + idx, abmul); + } + + if (n_for_masking > 0) { + __m256i mask; + switch (n_for_masking) { + case 1: + mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1); + break; + case 2: + mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1); + break; + case 3: + mask = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1); + break; + case 4: + mask = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1); + break; + case 5: + mask = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1); + break; + case 6: + mask = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1); + break; + case 7: + mask = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1); + break; + } + + const __m256 ax = _mm256_maskload_ps(a + idx, mask); + const __m256 bx = _mm256_maskload_ps(b + idx, mask); + const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); + _mm256_maskstore_ps(c + idx, mask, abmul); + } +} + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_avx512.cpp b/faiss/utils/simd_impl/distances_avx512.cpp new file mode 100644 index 0000000000..e87df652fd --- /dev/null +++ b/faiss/utils/simd_impl/distances_avx512.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#define AUTOVEC_LEVEL SIMDLevel::AVX512F +#include + +namespace faiss { + +template <> +void fvec_madd( + const size_t n, + const float* __restrict a, + const float bf, + const float* __restrict b, + float* __restrict c) { + const size_t n16 = n / 16; + const size_t n_for_masking = n % 16; + + const __m512 bfmm = _mm512_set1_ps(bf); + + size_t idx = 0; + for (idx = 0; idx < n16 * 16; idx += 16) { + const __m512 ax = _mm512_loadu_ps(a + idx); + const __m512 bx = _mm512_loadu_ps(b + idx); + const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); + _mm512_storeu_ps(c + idx, abmul); + } + + if (n_for_masking > 0) { + const __mmask16 mask = (1 << n_for_masking) - 1; + + const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx); + const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx); + const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); + _mm512_mask_storeu_ps(c + idx, mask, abmul); + } +} + +} // namespace faiss diff --git a/faiss/utils/simdlib_avx2.h b/faiss/utils/simd_impl/simdlib_avx2.h similarity index 100% rename from faiss/utils/simdlib_avx2.h rename to faiss/utils/simd_impl/simdlib_avx2.h diff --git a/faiss/utils/simdlib_avx512.h b/faiss/utils/simd_impl/simdlib_avx512.h similarity index 86% rename from faiss/utils/simdlib_avx512.h rename to faiss/utils/simd_impl/simdlib_avx512.h index 63b23f9b19..80889dc508 100644 --- a/faiss/utils/simdlib_avx512.h +++ b/faiss/utils/simd_impl/simdlib_avx512.h @@ -14,7 +14,7 @@ #include -#include +#include namespace faiss { @@ -293,4 +293,47 @@ struct simd64uint8 : simd512bit { } }; +struct simd16float32 : simd512bit { + simd16float32() {} + + explicit simd16float32(simd512bit x) : simd512bit(x) {} + + explicit simd16float32(__m512 x) : simd512bit(x) {} + + explicit simd16float32(float x) : simd512bit(_mm512_set1_ps(x)) {} + + explicit simd16float32(const float* x) + : simd16float32(_mm512_loadu_ps(x)) {} + + simd16float32 operator*(simd16float32 other) const { + return simd16float32(_mm512_mul_ps(f, other.f)); + } + + simd16float32 operator+(simd16float32 other) const { + return simd16float32(_mm512_add_ps(f, other.f)); + } + + simd16float32 operator-(simd16float32 other) const { + return simd16float32(_mm512_sub_ps(f, other.f)); + } + + simd16float32& operator+=(const simd16float32& other) { + f = _mm512_add_ps(f, other.f); + return *this; + } + + std::string tostring() const { + float tab[16]; + storeu((void*)tab); + char res[1000]; + char* ptr = res; + for (int i = 0; i < 16; i++) { + ptr += sprintf(ptr, "%g,", tab[i]); + } + // strip last , + ptr[-1] = 0; + return std::string(res); + } +}; + } // namespace faiss diff --git a/faiss/utils/simdlib_emulated.h b/faiss/utils/simd_impl/simdlib_emulated.h similarity index 100% rename from faiss/utils/simdlib_emulated.h rename to faiss/utils/simd_impl/simdlib_emulated.h diff --git a/faiss/utils/simdlib_neon.h b/faiss/utils/simd_impl/simdlib_neon.h similarity index 100% rename from faiss/utils/simdlib_neon.h rename to faiss/utils/simd_impl/simdlib_neon.h diff --git a/faiss/utils/simdlib_ppc64.h b/faiss/utils/simd_impl/simdlib_ppc64.h similarity index 100% rename from faiss/utils/simdlib_ppc64.h rename to faiss/utils/simd_impl/simdlib_ppc64.h diff --git a/faiss/utils/simd_levels.cpp b/faiss/utils/simd_levels.cpp new file mode 100644 index 0000000000..6f245d3078 --- /dev/null +++ b/faiss/utils/simd_levels.cpp @@ -0,0 +1,77 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +namespace faiss { + +SIMDLevel SIMDConfig::level = SIMDLevel::NONE; + +const char* SIMDConfig::level_names[] = + {"NONE", "AVX2", "AVX512F", "ARM_NEON", "ARM_SVE", "PPC"}; + +// it is there to make sure the constructor runs +static SIMDConfig dummy_config; + +SIMDConfig::SIMDConfig() { + char* env_var = getenv("FAISS_SIMD_LEVEL"); + if (env_var) { + int i; + for (i = 0; i <= sizeof(level_names); i++) { + if (strcmp(env_var, level_names[i]) == 0) { + level = (SIMDLevel)i; + break; + } + } + FAISS_THROW_IF_NOT_FMT( + i != sizeof(level_names), + "FAISS_SIMD_LEVEL %s unknown", + env_var); + return; + } + +#ifdef __x86_64__ + { + unsigned int eax, ebx, ecx, edx; + asm volatile("cpuid" + : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx) + : "a"(0)); + +#ifdef COMPILE_SIMD_AVX512F + if (ebx & (1 << 16)) { + level = SIMDLevel::AVX512F; + } else +#endif + +#ifdef COMPILE_SIMD_AVX2 + if (ecx & 32) { + level = SIMDLevel::AVX2; + } else +#endif + level = SIMDLevel::NONE; + } +#endif +} + +void SIMDConfig::set_level(SIMDLevel l) { + level = l; + // this could be used to set function pointers in the future +} + +SIMDLevel SIMDConfig::get_level() { + return level; +} + +std::string SIMDConfig::get_level_name() { + return std::string(level_names[int(level)]); +} + +} // namespace faiss diff --git a/faiss/utils/simd_levels.h b/faiss/utils/simd_levels.h new file mode 100644 index 0000000000..6519151305 --- /dev/null +++ b/faiss/utils/simd_levels.h @@ -0,0 +1,91 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace faiss { + +/* SIMD levels, used as template parameters. All levels are defined, even for + * architectures different of the current one. */ +enum class SIMDLevel { + NONE, + // x86 + AVX2, + AVX512F, + // arm + ARM_NEON, + ARM_SVE, + // ppc + PPC_ALTIVEC, +}; + +/* Current SIMD configuration. This static class manages the current SIMD level + * and intializes it from the cpuid and the FAISS_SIMD_LEVEL + * environment variable */ +struct SIMDConfig { + static SIMDLevel level; + static void set_level(SIMDLevel level); + static SIMDLevel get_level(); + static std::string get_level_name(); + static const char* level_names[]; + + SIMDConfig(); +}; + +/*********************** x86 SIMD */ + +#ifdef COMPILE_SIMD_AVX2 +#define DISPATCH_SIMDLevel_AVX2(f, ...) \ + case SIMDLevel::AVX2: \ + return f(__VA_ARGS__) +#else +#define DISPATCH_SIMDLevel_AVX2(f, ...) +#endif + +#ifdef COMPILE_SIMD_AVX512F +#define DISPATCH_SIMDLevel_AVX512F(f, ...) \ + case SIMDLevel::AVX512F: \ + return f(__VA_ARGS__) +#else +#define DISPATCH_SIMDLevel_AVX512F(f, ...) +#endif + +/*********************** ARM SIMD */ + +#ifdef COMPILE_SIMD_NEON +#define DISPATCH_SIMDLevel_ARM_NEON(f, ...) \ + case SIMDLevel::ARM_NEON: \ + return f(__VA_ARGS__) +#else +#define DISPATCH_SIMDLevel_ARM_NEON(f, ...) +#endif + +#ifdef COMPILE_SIMD_SVE +#define DISPATCH_SIMDLevel_ARM_SVE(f, ...) \ + case SIMDLevel::ARM_SVE: \ + return f(__VA_ARGS__) +#else +#define DISPATCH_SIMDLevel_ARM_SVE(f, ...) +#endif + +/* dispatch function f to f */ + +#define DISPATCH_SIMDLevel(f, ...) \ + switch (SIMDConfig::level) { \ + case SIMDLevel::NONE: \ + return f(__VA_ARGS__); \ + DISPATCH_SIMDLevel_AVX2(f, __VA_ARGS__); \ + DISPATCH_SIMDLevel_AVX512F(f, __VA_ARGS__); \ + DISPATCH_SIMDLevel_ARM_NEON(f, __VA_ARGS__); \ + DISPATCH_SIMDLevel_ARM_SVE(f, __VA_ARGS__); \ + default: \ + FAISS_ASSERT(!"invlalid SIMD level"); \ + } + +} // namespace faiss diff --git a/faiss/utils/simdlib.h b/faiss/utils/simdlib.h index eadfb78ae3..2b8bef4716 100644 --- a/faiss/utils/simdlib.h +++ b/faiss/utils/simdlib.h @@ -16,25 +16,25 @@ #if defined(__AVX512F__) -#include -#include +#include +#include #elif defined(__AVX2__) -#include +#include #elif defined(__aarch64__) -#include +#include #elif defined(__PPC64__) -#include +#include #else // emulated = all operations are implemented as scalars -#include +#include // FIXME: make a SSE version // is this ever going to happen? We will probably rather implement AVX512 diff --git a/faiss/utils/utils.cpp b/faiss/utils/utils.cpp index 0811cb9030..f7172e5fb6 100644 --- a/faiss/utils/utils.cpp +++ b/faiss/utils/utils.cpp @@ -115,16 +115,16 @@ std::string get_compile_options() { options += "OPTIMIZE "; #endif -#ifdef __AVX512F__ - options += "AVX512 "; -#elif defined(__AVX2__) +#ifdef COMPILE_SIMD_AVX2 options += "AVX2 "; -#elif defined(__ARM_FEATURE_SVE) - options += "SVE NEON "; -#elif defined(__aarch64__) +#endif + +#ifdef COMPILE_SIMD_AVX512F + options += "AVX512F "; +#endif + +#ifdef COMPILE_SIMD_NEON options += "NEON "; -#else - options += "GENERIC "; #endif options += gpu_compile_options; diff --git a/tests/test_code_distance.cpp b/tests/test_code_distance.cpp index f1a3939388..7a3a5133fc 100644 --- a/tests/test_code_distance.cpp +++ b/tests/test_code_distance.cpp @@ -21,7 +21,7 @@ #include #include -#include +#include size_t nMismatches( const std::vector& ref, diff --git a/tests/test_pq_encoding.cpp b/tests/test_pq_encoding.cpp index ad5a089883..6335676af5 100644 --- a/tests/test_pq_encoding.cpp +++ b/tests/test_pq_encoding.cpp @@ -13,7 +13,7 @@ #include #include -#include +#include namespace {