diff --git a/faiss/IndexIVFPQ.cpp b/faiss/IndexIVFPQ.cpp index ce1f5ddf1a..c581de791f 100644 --- a/faiss/IndexIVFPQ.cpp +++ b/faiss/IndexIVFPQ.cpp @@ -805,8 +805,9 @@ struct RangeSearchResults { * The scanning functions call their favorite precompute_* * function to precompute the tables they need. *****************************************************/ -template +template struct IVFPQScannerT : QueryTables { + using PQDecoder = typename PQCodeDistance::PQDecoder; const uint8_t* list_codes; const IDType* list_ids; size_t list_size; @@ -882,7 +883,7 @@ struct IVFPQScannerT : QueryTables { float distance_1 = 0; float distance_2 = 0; float distance_3 = 0; - distance_four_codes( + PQCodeDistance::distance_four_codes( pq.M, pq.nbits, sim_table, @@ -905,7 +906,7 @@ struct IVFPQScannerT : QueryTables { if (counter >= 1) { float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -914,7 +915,7 @@ struct IVFPQScannerT : QueryTables { } if (counter >= 2) { float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -923,7 +924,7 @@ struct IVFPQScannerT : QueryTables { } if (counter >= 3) { float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -1089,7 +1090,7 @@ struct IVFPQScannerT : QueryTables { float distance_1 = dis0; float distance_2 = dis0; float distance_3 = dis0; - distance_four_codes( + PQCodeDistance::distance_four_codes( pq.M, pq.nbits, sim_table, @@ -1120,7 +1121,7 @@ struct IVFPQScannerT : QueryTables { n_hamming_pass++; float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -1140,7 +1141,7 @@ struct IVFPQScannerT : QueryTables { n_hamming_pass++; float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -1185,8 +1186,8 @@ struct IVFPQScannerT : QueryTables { * * use_sel: store or ignore 
the IDSelector */ -template -struct IVFPQScanner : IVFPQScannerT, +template +struct IVFPQScanner : IVFPQScannerT, InvertedListScanner { int precompute_mode; const IDSelector* sel; @@ -1196,7 +1197,7 @@ struct IVFPQScanner : IVFPQScannerT, bool store_pairs, int precompute_mode, const IDSelector* sel) - : IVFPQScannerT(ivfpq, nullptr), + : IVFPQScannerT(ivfpq, nullptr), precompute_mode(precompute_mode), sel(sel) { this->store_pairs = store_pairs; @@ -1215,7 +1216,7 @@ struct IVFPQScanner : IVFPQScannerT, float distance_to_code(const uint8_t* code) const override { assert(precompute_mode == 2); float dis = this->dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( this->pq.M, this->pq.nbits, this->sim_table, code); return dis; } @@ -1279,7 +1280,9 @@ struct IVFPQScanner : IVFPQScannerT, } }; -template +/** follow 3 stages of template dispatching */ + +template InvertedListScanner* get_InvertedListScanner1( const IndexIVFPQ& index, bool store_pairs, @@ -1288,32 +1291,47 @@ InvertedListScanner* get_InvertedListScanner1( return new IVFPQScanner< METRIC_INNER_PRODUCT, CMin, - PQDecoder, + PQCodeDistance, use_sel>(index, store_pairs, 2, sel); } else if (index.metric_type == METRIC_L2) { return new IVFPQScanner< METRIC_L2, CMax, - PQDecoder, + PQCodeDistance, use_sel>(index, store_pairs, 2, sel); } return nullptr; } -template +template InvertedListScanner* get_InvertedListScanner2( const IndexIVFPQ& index, bool store_pairs, const IDSelector* sel) { if (index.pq.nbits == 8) { - return get_InvertedListScanner1( - index, store_pairs, sel); + return get_InvertedListScanner1< + PQCodeDistance, + use_sel>(index, store_pairs, sel); } else if (index.pq.nbits == 16) { - return get_InvertedListScanner1( - index, store_pairs, sel); + return get_InvertedListScanner1< + PQCodeDistance, + use_sel>(index, store_pairs, sel); + } else { + return get_InvertedListScanner1< + PQCodeDistance, + use_sel>(index, store_pairs, sel); + } +} + +template +InvertedListScanner* 
get_InvertedListScanner3( + const IndexIVFPQ& index, + bool store_pairs, + const IDSelector* sel) { + if (sel) { + return get_InvertedListScanner2(index, store_pairs, sel); } else { - return get_InvertedListScanner1( - index, store_pairs, sel); + return get_InvertedListScanner2(index, store_pairs, sel); } } @@ -1323,11 +1341,7 @@ InvertedListScanner* IndexIVFPQ::get_InvertedListScanner( bool store_pairs, const IDSelector* sel, const IVFSearchParameters*) const { - if (sel) { - return get_InvertedListScanner2(*this, store_pairs, sel); - } else { - return get_InvertedListScanner2(*this, store_pairs, sel); - } + DISPATCH_SIMDLevel(get_InvertedListScanner3, *this, store_pairs, sel); return nullptr; } diff --git a/faiss/IndexPQ.cpp b/faiss/IndexPQ.cpp index 8193e78b17..99347f9c8c 100644 --- a/faiss/IndexPQ.cpp +++ b/faiss/IndexPQ.cpp @@ -73,7 +73,7 @@ void IndexPQ::train(idx_t n, const float* x) { namespace { -template +template struct PQDistanceComputer : FlatCodesDistanceComputer { size_t d; MetricType metric; @@ -86,7 +86,7 @@ struct PQDistanceComputer : FlatCodesDistanceComputer { float distance_to_code(const uint8_t* code) final { ndis++; - float dis = distance_single_code( + float dis = PQCodeDistance::distance_single_code( pq.M, pq.nbits, precomputed_table.data(), code); return dis; } @@ -95,8 +95,10 @@ struct PQDistanceComputer : FlatCodesDistanceComputer { FAISS_THROW_IF_NOT(sdc); const float* sdci = sdc; float accu = 0; - PQDecoder codei(codes + i * code_size, pq.nbits); - PQDecoder codej(codes + j * code_size, pq.nbits); + typename PQCodeDistance::PQDecoder codei( + codes + i * code_size, pq.nbits); + typename PQCodeDistance::PQDecoder codej( + codes + j * code_size, pq.nbits); for (int l = 0; l < pq.M; l++) { accu += sdci[codei.decode() + (codej.decode() << codei.nbits)]; @@ -132,16 +134,24 @@ struct PQDistanceComputer : FlatCodesDistanceComputer { } }; +template +FlatCodesDistanceComputer* get_FlatCodesDistanceComputer1( + const IndexPQ& index) { + int 
nbits = index.pq.nbits; + if (nbits == 8) { + return new PQDistanceComputer>(index); + } else if (nbits == 16) { + return new PQDistanceComputer>(index); + } else { + return new PQDistanceComputer>( + index); + } +} + } // namespace FlatCodesDistanceComputer* IndexPQ::get_FlatCodesDistanceComputer() const { - if (pq.nbits == 8) { - return new PQDistanceComputer(*this); - } else if (pq.nbits == 16) { - return new PQDistanceComputer(*this); - } else { - return new PQDistanceComputer(*this); - } + DISPATCH_SIMDLevel(get_FlatCodesDistanceComputer1, *this); } /***************************************** diff --git a/faiss/impl/code_distance/code_distance-avx2.cpp b/faiss/impl/code_distance/code_distance-avx2.cpp new file mode 100644 index 0000000000..b140b9ece0 --- /dev/null +++ b/faiss/impl/code_distance/code_distance-avx2.cpp @@ -0,0 +1,486 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include + +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=78782 +#if defined(__GNUC__) && __GNUC__ < 9 +#define _mm_loadu_si64(x) (_mm_loadl_epi64((__m128i_u*)x)) +#endif + +namespace { + +inline float horizontal_sum(const __m128 v) { + const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); + const __m128 v1 = _mm_add_ps(v, v0); + __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); + const __m128 v3 = _mm_add_ps(v1, v2); + return _mm_cvtss_f32(v3); +} + +// Computes a horizontal sum over an __m256 register +inline float horizontal_sum(const __m256 v) { + const __m128 v0 = + _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); + return horizontal_sum(v0); +} + +// processes a single code for M=4, ksub=256, nbits=8 +float inline distance_single_code_avx2_pqdecoder8_m4( + // precomputed distances, layout (4, 256) + const float* sim_table, + const uint8_t* code) { + float result = 0; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + const __m128i vksub = _mm_set1_epi32(ksub); + __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); + offsets_0 = _mm_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m128 partialSum; + + // load 4 uint8 values + const __m128i mm1 = _mm_cvtsi32_si128(*((const int32_t*)code)); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m128i idx1 = _mm_cvtepu8_epi32(mm1); + + // add offsets + const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m128 collected = + _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSum = collected; + } + + // horizontal sum for partialSum + result = horizontal_sum(partialSum); + return result; +} + +// processes a single code for M=8, ksub=256, nbits=8 +float inline distance_single_code_avx2_pqdecoder8_m8( + // precomputed distances, layout (8, 256) + const 
float* sim_table, + const uint8_t* code) { + float result = 0; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + const __m256i vksub = _mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSum; + + // load 8 uint8 values + const __m128i mm1 = _mm_loadu_si64((const __m128i_u*)code); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); + + // add offsets + const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = + _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSum = collected; + } + + // horizontal sum for partialSum + result = horizontal_sum(partialSum); + return result; +} + +// processes four codes for M=4, ksub=256, nbits=8 +inline void distance_four_codes_avx2_pqdecoder8_m4( + // precomputed distances, layout (4, 256) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + constexpr intptr_t N = 4; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + // process 8 values + const __m128i vksub = _mm_set1_epi32(ksub); + __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); + offsets_0 = _mm_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m128 partialSums[N]; + + // load 4 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_cvtsi32_si128(*((const int32_t*)code0)); + mm1[1] = _mm_cvtsi32_si128(*((const int32_t*)code1)); + mm1[2] = _mm_cvtsi32_si128(*((const int32_t*)code2)); + mm1[3] = _mm_cvtsi32_si128(*((const int32_t*)code3)); + + for (intptr_t 
j = 0; j < N; j++) { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m128i idx1 = _mm_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); + + // gather 4 values, similar to 4 operations of tab[idx] + __m128 collected = + _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = collected; + } + + // horizontal sum for partialSum + result0 = horizontal_sum(partialSums[0]); + result1 = horizontal_sum(partialSums[1]); + result2 = horizontal_sum(partialSums[2]); + result3 = horizontal_sum(partialSums[3]); +} + +// processes four codes for M=8, ksub=256, nbits=8 +inline void distance_four_codes_avx2_pqdecoder8_m8( + // precomputed distances, layout (8, 256) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + constexpr intptr_t N = 4; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + // process 8 values + const __m256i vksub = _mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSums[N]; + + // load 8 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si64((const __m128i_u*)code0); + mm1[1] = _mm_loadu_si64((const __m128i_u*)code1); + mm1[2] = _mm_loadu_si64((const __m128i_u*)code2); + mm1[3] = _mm_loadu_si64((const __m128i_u*)code3); + + for (intptr_t j = 0; j < N; j++) { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 
collected = + _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = collected; + } + + // horizontal sum for partialSum + result0 = horizontal_sum(partialSums[0]); + result1 = horizontal_sum(partialSums[1]); + result2 = horizontal_sum(partialSums[2]); + result3 = horizontal_sum(partialSums[3]); +} + +} // namespace + +namespace faiss { + +template <> +struct PQCodeDistance { + float distance_single_code( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + const uint8_t* code) { + if (M == 4) { + return distance_single_code_avx2_pqdecoder8_m4(sim_table, code); + } + if (M == 8) { + return distance_single_code_avx2_pqdecoder8_m8(sim_table, code); + } + + float result = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + + const __m256i vksub = _mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSum = _mm256_setzero_ps(); + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + const __m128i mm1 = + _mm_loadu_si128((const __m128i_u*)(code + m)); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + tab += ksub * 8; + + // collect partial sums + partialSum = _mm256_add_ps(partialSum, collected); + } + + // move high 8 uint8 to low ones + const __m128i mm2 = + _mm_unpackhi_epi64(mm1, 
_mm_setzero_si128()); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + tab += ksub * 8; + + // collect partial sums + partialSum = _mm256_add_ps(partialSum, collected); + } + } + + // horizontal sum for partialSum + result += horizontal_sum(partialSum); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 decoder(code + m, nbits); + + for (; m < M; m++) { + result += tab[decoder.decode()]; + tab += ksub; + } + } + + return result; + } + + // Combines 4 operations of distance_single_code() + void distance_four_codes( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + if (M == 4) { + distance_four_codes_avx2_pqdecoder8_m4( + sim_table, + code0, + code1, + code2, + code3, + result0, + result1, + result2, + result3); + return; + } + if (M == 8) { + distance_four_codes_avx2_pqdecoder8_m8( + sim_table, + code0, + code1, + code2, + code3, + result0, + result1, + result2, + result3); + return; + } + + result0 = 0; + result1 = 0; + result2 = 0; + result3 = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + constexpr intptr_t N = 4; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + const __m256i vksub = _mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + 
offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSums[N]; + for (intptr_t j = 0; j < N; j++) { + partialSums[j] = _mm256_setzero_ps(); + } + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); + mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); + mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); + mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); + + // process first 8 codes + for (intptr_t j = 0; j < N; j++) { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm256_add_ps(partialSums[j], collected); + } + tab += ksub * 8; + + // process next 8 codes + for (intptr_t j = 0; j < N; j++) { + // move high 8 uint8 to low ones + const __m128i mm2 = + _mm_unpackhi_epi64(mm1[j], _mm_setzero_si128()); + + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm256_add_ps(partialSums[j], collected); + } + + tab += ksub * 8; + } + + // horizontal sum for partialSum + result0 += horizontal_sum(partialSums[0]); + result1 += horizontal_sum(partialSums[1]); + result2 += horizontal_sum(partialSums[2]); + result3 += horizontal_sum(partialSums[3]); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 
decoder0(code0 + m, nbits); + PQDecoder8 decoder1(code1 + m, nbits); + PQDecoder8 decoder2(code2 + m, nbits); + PQDecoder8 decoder3(code3 + m, nbits); + for (; m < M; m++) { + result0 += tab[decoder0.decode()]; + result1 += tab[decoder1.decode()]; + result2 += tab[decoder2.decode()]; + result3 += tab[decoder3.decode()]; + tab += ksub; + } + } + } +}; + +// explicit template instanciations +// template struct PQCodeDistance; + +// these two will automatically use the generic implementation +template struct PQCodeDistance; +template struct PQCodeDistance; + +} // namespace faiss diff --git a/faiss/impl/code_distance/code_distance-avx2.h b/faiss/impl/code_distance/code_distance-avx2.h deleted file mode 100644 index 53380b6e46..0000000000 --- a/faiss/impl/code_distance/code_distance-avx2.h +++ /dev/null @@ -1,534 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#pragma once - -#ifdef __AVX2__ - -#include - -#include - -#include -#include - -// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=78782 -#if defined(__GNUC__) && __GNUC__ < 9 -#define _mm_loadu_si64(x) (_mm_loadl_epi64((__m128i_u*)x)) -#endif - -namespace { - -inline float horizontal_sum(const __m128 v) { - const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); - const __m128 v1 = _mm_add_ps(v, v0); - __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); - const __m128 v3 = _mm_add_ps(v1, v2); - return _mm_cvtss_f32(v3); -} - -// Computes a horizontal sum over an __m256 register -inline float horizontal_sum(const __m256 v) { - const __m128 v0 = - _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); - return horizontal_sum(v0); -} - -// processes a single code for M=4, ksub=256, nbits=8 -float inline distance_single_code_avx2_pqdecoder8_m4( - // precomputed distances, layout (4, 256) - const float* sim_table, - const uint8_t* code) { - float result = 0; - - const float* tab = sim_table; - constexpr size_t ksub = 1 << 8; - - const __m128i vksub = _mm_set1_epi32(ksub); - __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); - offsets_0 = _mm_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m128 partialSum; - - // load 4 uint8 values - const __m128i mm1 = _mm_cvtsi32_si128(*((const int32_t*)code)); - { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m128i idx1 = _mm_cvtepu8_epi32(mm1); - - // add offsets - const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m128 collected = - _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSum = collected; - } - - // horizontal sum for partialSum - result = horizontal_sum(partialSum); - return result; -} - -// processes a single code for M=8, ksub=256, nbits=8 -float inline distance_single_code_avx2_pqdecoder8_m8( - // 
precomputed distances, layout (8, 256) - const float* sim_table, - const uint8_t* code) { - float result = 0; - - const float* tab = sim_table; - constexpr size_t ksub = 1 << 8; - - const __m256i vksub = _mm256_set1_epi32(ksub); - __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m256 partialSum; - - // load 8 uint8 values - const __m128i mm1 = _mm_loadu_si64((const __m128i_u*)code); - { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); - - // add offsets - const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = - _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSum = collected; - } - - // horizontal sum for partialSum - result = horizontal_sum(partialSum); - return result; -} - -// processes four codes for M=4, ksub=256, nbits=8 -inline void distance_four_codes_avx2_pqdecoder8_m4( - // precomputed distances, layout (4, 256) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - constexpr intptr_t N = 4; - - const float* tab = sim_table; - constexpr size_t ksub = 1 << 8; - - // process 8 values - const __m128i vksub = _mm_set1_epi32(ksub); - __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); - offsets_0 = _mm_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m128 partialSums[N]; - - // load 4 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_cvtsi32_si128(*((const int32_t*)code0)); - mm1[1] = _mm_cvtsi32_si128(*((const int32_t*)code1)); - mm1[2] = _mm_cvtsi32_si128(*((const int32_t*)code2)); - mm1[3] = 
_mm_cvtsi32_si128(*((const int32_t*)code3)); - - for (intptr_t j = 0; j < N; j++) { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m128i idx1 = _mm_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); - - // gather 4 values, similar to 4 operations of tab[idx] - __m128 collected = - _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSums[j] = collected; - } - - // horizontal sum for partialSum - result0 = horizontal_sum(partialSums[0]); - result1 = horizontal_sum(partialSums[1]); - result2 = horizontal_sum(partialSums[2]); - result3 = horizontal_sum(partialSums[3]); -} - -// processes four codes for M=8, ksub=256, nbits=8 -inline void distance_four_codes_avx2_pqdecoder8_m8( - // precomputed distances, layout (8, 256) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - constexpr intptr_t N = 4; - - const float* tab = sim_table; - constexpr size_t ksub = 1 << 8; - - // process 8 values - const __m256i vksub = _mm256_set1_epi32(ksub); - __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m256 partialSums[N]; - - // load 8 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_loadu_si64((const __m128i_u*)code0); - mm1[1] = _mm_loadu_si64((const __m128i_u*)code1); - mm1[2] = _mm_loadu_si64((const __m128i_u*)code2); - mm1[3] = _mm_loadu_si64((const __m128i_u*)code3); - - for (intptr_t j = 0; j < N; j++) { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); - - 
// gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = - _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSums[j] = collected; - } - - // horizontal sum for partialSum - result0 = horizontal_sum(partialSums[0]); - result1 = horizontal_sum(partialSums[1]); - result2 = horizontal_sum(partialSums[2]); - result3 = horizontal_sum(partialSums[3]); -} - -} // namespace - -namespace faiss { - -template -typename std::enable_if::value, float>:: - type inline distance_single_code_avx2( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - const uint8_t* code) { - // default implementation - return distance_single_code_generic(M, nbits, sim_table, code); -} - -template -typename std::enable_if::value, float>:: - type inline distance_single_code_avx2( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - const uint8_t* code) { - if (M == 4) { - return distance_single_code_avx2_pqdecoder8_m4(sim_table, code); - } - if (M == 8) { - return distance_single_code_avx2_pqdecoder8_m8(sim_table, code); - } - - float result = 0; - constexpr size_t ksub = 1 << 8; - - size_t m = 0; - const size_t pqM16 = M / 16; - - const float* tab = sim_table; - - if (pqM16 > 0) { - // process 16 values per loop - - const __m256i vksub = _mm256_set1_epi32(ksub); - __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m256 partialSum = _mm256_setzero_ps(); - - // loop - for (m = 0; m < pqM16 * 16; m += 16) { - // load 16 uint8 values - const __m128i mm1 = _mm_loadu_si128((const __m128i_u*)(code + m)); - { - // convert uint8 values (low part of __m128i) to int32 - // values - 
const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); - - // add offsets - const __m256i indices_to_read_from = - _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = _mm256_i32gather_ps( - tab, indices_to_read_from, sizeof(float)); - tab += ksub * 8; - - // collect partial sums - partialSum = _mm256_add_ps(partialSum, collected); - } - - // move high 8 uint8 to low ones - const __m128i mm2 = _mm_unpackhi_epi64(mm1, _mm_setzero_si128()); - { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); - - // add offsets - const __m256i indices_to_read_from = - _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = _mm256_i32gather_ps( - tab, indices_to_read_from, sizeof(float)); - tab += ksub * 8; - - // collect partial sums - partialSum = _mm256_add_ps(partialSum, collected); - } - } - - // horizontal sum for partialSum - result += horizontal_sum(partialSum); - } - - // - if (m < M) { - // process leftovers - PQDecoder8 decoder(code + m, nbits); - - for (; m < M; m++) { - result += tab[decoder.decode()]; - tab += ksub; - } - } - - return result; -} - -template -typename std::enable_if::value, void>:: - type - distance_four_codes_avx2( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_generic( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} - -// Combines 4 operations of distance_single_code() -template -typename 
std::enable_if::value, void>::type -distance_four_codes_avx2( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - if (M == 4) { - distance_four_codes_avx2_pqdecoder8_m4( - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); - return; - } - if (M == 8) { - distance_four_codes_avx2_pqdecoder8_m8( - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); - return; - } - - result0 = 0; - result1 = 0; - result2 = 0; - result3 = 0; - constexpr size_t ksub = 1 << 8; - - size_t m = 0; - const size_t pqM16 = M / 16; - - constexpr intptr_t N = 4; - - const float* tab = sim_table; - - if (pqM16 > 0) { - // process 16 values per loop - const __m256i vksub = _mm256_set1_epi32(ksub); - __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m256 partialSums[N]; - for (intptr_t j = 0; j < N; j++) { - partialSums[j] = _mm256_setzero_ps(); - } - - // loop - for (m = 0; m < pqM16 * 16; m += 16) { - // load 16 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); - mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); - mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); - mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); - - // process first 8 codes - for (intptr_t j = 0; j < N; j++) { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m256i indices_to_read_from = - _mm256_add_epi32(idx1, 
offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = _mm256_i32gather_ps( - tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSums[j] = _mm256_add_ps(partialSums[j], collected); - } - tab += ksub * 8; - - // process next 8 codes - for (intptr_t j = 0; j < N; j++) { - // move high 8 uint8 to low ones - const __m128i mm2 = - _mm_unpackhi_epi64(mm1[j], _mm_setzero_si128()); - - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); - - // add offsets - const __m256i indices_to_read_from = - _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = _mm256_i32gather_ps( - tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSums[j] = _mm256_add_ps(partialSums[j], collected); - } - - tab += ksub * 8; - } - - // horizontal sum for partialSum - result0 += horizontal_sum(partialSums[0]); - result1 += horizontal_sum(partialSums[1]); - result2 += horizontal_sum(partialSums[2]); - result3 += horizontal_sum(partialSums[3]); - } - - // - if (m < M) { - // process leftovers - PQDecoder8 decoder0(code0 + m, nbits); - PQDecoder8 decoder1(code1 + m, nbits); - PQDecoder8 decoder2(code2 + m, nbits); - PQDecoder8 decoder3(code3 + m, nbits); - for (; m < M; m++) { - result0 += tab[decoder0.decode()]; - result1 += tab[decoder1.decode()]; - result2 += tab[decoder2.decode()]; - result3 += tab[decoder3.decode()]; - tab += ksub; - } - } -} - -} // namespace faiss - -#endif diff --git a/faiss/impl/code_distance/code_distance-avx512.cpp b/faiss/impl/code_distance/code_distance-avx512.cpp new file mode 100644 index 0000000000..891db07975 --- /dev/null +++ b/faiss/impl/code_distance/code_distance-avx512.cpp @@ -0,0 +1,199 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include + +// According to experiments, the AVX-512 version may be SLOWER than +// the AVX2 version, which is somewhat unexpected. +// This version is not used for now, but it may be used later. +// +// TODO: test for AMD CPUs. + +namespace faiss { + +template <> +struct PQCodeDistance { + float distance_single_code( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + const uint8_t* code0) { + float result0 = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + constexpr intptr_t N = 1; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + const __m512i vksub = _mm512_set1_epi32(ksub); + __m512i offsets_0 = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m512 partialSums[N]; + for (intptr_t j = 0; j < N; j++) { + partialSums[j] = _mm512_setzero_ps(); + } + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); + + // process first 8 codes + for (intptr_t j = 0; j < N; j++) { + const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m512i indices_to_read_from = + _mm512_add_epi32(idx1, offsets_0); + + // gather 16 values, similar to 16 operations of tab[idx] + __m512 collected = _mm512_i32gather_ps( + indices_to_read_from, tab, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm512_add_ps(partialSums[j], collected); + } + tab += ksub * 16; + } + + // horizontal sum for partialSum + result0 += _mm512_reduce_add_ps(partialSums[0]); + } + + // 
+ if (m < M) { + // process leftovers + PQDecoder8 decoder0(code0 + m, nbits); + for (; m < M; m++) { + result0 += tab[decoder0.decode()]; + tab += ksub; + } + } + + return result0; + } + + void distance_four_codes_avx512( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + result0 = 0; + result1 = 0; + result2 = 0; + result3 = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + constexpr intptr_t N = 4; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + const __m512i vksub = _mm512_set1_epi32(ksub); + __m512i offsets_0 = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m512 partialSums[N]; + for (intptr_t j = 0; j < N; j++) { + partialSums[j] = _mm512_setzero_ps(); + } + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); + mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); + mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); + mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); + + // process first 8 codes + for (intptr_t j = 0; j < N; j++) { + const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m512i indices_to_read_from = + _mm512_add_epi32(idx1, offsets_0); + + // gather 16 values, similar to 16 operations of tab[idx] + __m512 collected = _mm512_i32gather_ps( + indices_to_read_from, tab, sizeof(float)); + + // collect partial sums + partialSums[j] = 
_mm512_add_ps(partialSums[j], collected); + } + tab += ksub * 16; + } + + // horizontal sum for partialSum + result0 += _mm512_reduce_add_ps(partialSums[0]); + result1 += _mm512_reduce_add_ps(partialSums[1]); + result2 += _mm512_reduce_add_ps(partialSums[2]); + result3 += _mm512_reduce_add_ps(partialSums[3]); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 decoder0(code0 + m, nbits); + PQDecoder8 decoder1(code1 + m, nbits); + PQDecoder8 decoder2(code2 + m, nbits); + PQDecoder8 decoder3(code3 + m, nbits); + for (; m < M; m++) { + result0 += tab[decoder0.decode()]; + result1 += tab[decoder1.decode()]; + result2 += tab[decoder2.decode()]; + result3 += tab[decoder3.decode()]; + tab += ksub; + } + } + } +}; + +// explicit template instanciations +// template struct PQCodeDistance; + +// these two will automatically use the generic implementation +template struct PQCodeDistance; +template struct PQCodeDistance; + +} // namespace faiss diff --git a/faiss/impl/code_distance/code_distance-avx512.h b/faiss/impl/code_distance/code_distance-avx512.h deleted file mode 100644 index d05c41c19c..0000000000 --- a/faiss/impl/code_distance/code_distance-avx512.h +++ /dev/null @@ -1,248 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#ifdef __AVX512F__ - -#include - -#include - -#include -#include - -namespace faiss { - -// According to experiments, the AVX-512 version may be SLOWER than -// the AVX2 version, which is somewhat unexpected. -// This version is not used for now, but it may be used later. -// -// TODO: test for AMD CPUs. 
- -template -typename std::enable_if::value, float>:: - type inline distance_single_code_avx512( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - const uint8_t* code) { - // default implementation - return distance_single_code_generic(M, nbits, sim_table, code); -} - -template -typename std::enable_if::value, float>:: - type inline distance_single_code_avx512( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - const uint8_t* code0) { - float result0 = 0; - constexpr size_t ksub = 1 << 8; - - size_t m = 0; - const size_t pqM16 = M / 16; - - constexpr intptr_t N = 1; - - const float* tab = sim_table; - - if (pqM16 > 0) { - // process 16 values per loop - const __m512i vksub = _mm512_set1_epi32(ksub); - __m512i offsets_0 = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m512 partialSums[N]; - for (intptr_t j = 0; j < N; j++) { - partialSums[j] = _mm512_setzero_ps(); - } - - // loop - for (m = 0; m < pqM16 * 16; m += 16) { - // load 16 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); - - // process first 8 codes - for (intptr_t j = 0; j < N; j++) { - const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m512i indices_to_read_from = - _mm512_add_epi32(idx1, offsets_0); - - // gather 16 values, similar to 16 operations of tab[idx] - __m512 collected = _mm512_i32gather_ps( - indices_to_read_from, tab, sizeof(float)); - - // collect partial sums - partialSums[j] = _mm512_add_ps(partialSums[j], collected); - } - tab += ksub * 16; - } - - // horizontal sum for partialSum - result0 += _mm512_reduce_add_ps(partialSums[0]); - } - 
- // - if (m < M) { - // process leftovers - PQDecoder8 decoder0(code0 + m, nbits); - for (; m < M; m++) { - result0 += tab[decoder0.decode()]; - tab += ksub; - } - } - - return result0; -} - -template -typename std::enable_if::value, void>:: - type - distance_four_codes_avx512( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_generic( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} - -// Combines 4 operations of distance_single_code() -template -typename std::enable_if::value, void>::type -distance_four_codes_avx512( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - result0 = 0; - result1 = 0; - result2 = 0; - result3 = 0; - constexpr size_t ksub = 1 << 8; - - size_t m = 0; - const size_t pqM16 = M / 16; - - constexpr intptr_t N = 4; - - const float* tab = sim_table; - - if (pqM16 > 0) { - // process 16 values per loop - const __m512i vksub = _mm512_set1_epi32(ksub); - __m512i offsets_0 = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m512 partialSums[N]; - for (intptr_t j = 0; j < N; j++) { - partialSums[j] = 
_mm512_setzero_ps(); - } - - // loop - for (m = 0; m < pqM16 * 16; m += 16) { - // load 16 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); - mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); - mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); - mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); - - // process first 8 codes - for (intptr_t j = 0; j < N; j++) { - const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m512i indices_to_read_from = - _mm512_add_epi32(idx1, offsets_0); - - // gather 16 values, similar to 16 operations of tab[idx] - __m512 collected = _mm512_i32gather_ps( - indices_to_read_from, tab, sizeof(float)); - - // collect partial sums - partialSums[j] = _mm512_add_ps(partialSums[j], collected); - } - tab += ksub * 16; - } - - // horizontal sum for partialSum - result0 += _mm512_reduce_add_ps(partialSums[0]); - result1 += _mm512_reduce_add_ps(partialSums[1]); - result2 += _mm512_reduce_add_ps(partialSums[2]); - result3 += _mm512_reduce_add_ps(partialSums[3]); - } - - // - if (m < M) { - // process leftovers - PQDecoder8 decoder0(code0 + m, nbits); - PQDecoder8 decoder1(code1 + m, nbits); - PQDecoder8 decoder2(code2 + m, nbits); - PQDecoder8 decoder3(code3 + m, nbits); - for (; m < M; m++) { - result0 += tab[decoder0.decode()]; - result1 += tab[decoder1.decode()]; - result2 += tab[decoder2.decode()]; - result3 += tab[decoder3.decode()]; - tab += ksub; - } - } -} - -} // namespace faiss - -#endif diff --git a/faiss/impl/code_distance/code_distance-generic.cpp b/faiss/impl/code_distance/code_distance-generic.cpp new file mode 100644 index 0000000000..ac9561ed93 --- /dev/null +++ b/faiss/impl/code_distance/code_distance-generic.cpp @@ -0,0 +1,20 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include + +namespace faiss { + +// explicit template instanciations +template struct PQCodeDistance; +template struct PQCodeDistance; +template struct PQCodeDistance; + +} // namespace faiss diff --git a/faiss/impl/code_distance/code_distance-generic.h b/faiss/impl/code_distance/code_distance-generic.h deleted file mode 100644 index c02551c415..0000000000 --- a/faiss/impl/code_distance/code_distance-generic.h +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include -#include - -namespace faiss { - -/// Returns the distance to a single code. -template -inline float distance_single_code_generic( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // the code - const uint8_t* code) { - PQDecoderT decoder(code, nbits); - const size_t ksub = 1 << nbits; - - const float* tab = sim_table; - float result = 0; - - for (size_t m = 0; m < M; m++) { - result += tab[decoder.decode()]; - tab += ksub; - } - - return result; -} - -/// Combines 4 operations of distance_single_code() -/// General-purpose version. 
-template -inline void distance_four_codes_generic( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - PQDecoderT decoder0(code0, nbits); - PQDecoderT decoder1(code1, nbits); - PQDecoderT decoder2(code2, nbits); - PQDecoderT decoder3(code3, nbits); - const size_t ksub = 1 << nbits; - - const float* tab = sim_table; - result0 = 0; - result1 = 0; - result2 = 0; - result3 = 0; - - for (size_t m = 0; m < M; m++) { - result0 += tab[decoder0.decode()]; - result1 += tab[decoder1.decode()]; - result2 += tab[decoder2.decode()]; - result3 += tab[decoder3.decode()]; - tab += ksub; - } -} - -} // namespace faiss diff --git a/faiss/impl/code_distance/code_distance-sve.h b/faiss/impl/code_distance/code_distance-sve.cpp similarity index 99% rename from faiss/impl/code_distance/code_distance-sve.h rename to faiss/impl/code_distance/code_distance-sve.cpp index 82f7746be6..9a941798ff 100644 --- a/faiss/impl/code_distance/code_distance-sve.h +++ b/faiss/impl/code_distance/code_distance-sve.cpp @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. 
*/ -#pragma once - #ifdef __ARM_FEATURE_SVE #include @@ -15,7 +13,7 @@ #include #include -#include +#include namespace faiss { diff --git a/faiss/impl/code_distance/code_distance.h b/faiss/impl/code_distance/code_distance.h index 8f29abda97..585890cb40 100644 --- a/faiss/impl/code_distance/code_distance.h +++ b/faiss/impl/code_distance/code_distance.h @@ -9,6 +9,10 @@ #include +#include + +#include + // This directory contains functions to compute a distance // from a given PQ code to a query vector, given that the // distances to a query vector for pq.M codebooks are precomputed. @@ -24,163 +28,76 @@ // why the names of the functions for custom implementations // have this _generic or _avx2 suffix. -#ifdef __AVX2__ - -#include - namespace faiss { -template -inline float distance_single_code( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // the code - const uint8_t* code) { - return distance_single_code_avx2(M, nbits, sim_table, code); -} - -template -inline void distance_four_codes( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_avx2( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} +// definiton and default implementation +template +struct PQCodeDistance { + using PQDecoder = PQDecoderT; + + /// Returns the distance to a single code. 
+ static float distance_single_code( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // the code + const uint8_t* code) { + PQDecoderT decoder(code, nbits); + const size_t ksub = 1 << nbits; + + const float* tab = sim_table; + float result = 0; + + for (size_t m = 0; m < M; m++) { + result += tab[decoder.decode()]; + tab += ksub; + } + + return result; + } + + /// Combines 4 operations of distance_single_code() + /// General-purpose version. + static void distance_four_codes( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + PQDecoderT decoder0(code0, nbits); + PQDecoderT decoder1(code1, nbits); + PQDecoderT decoder2(code2, nbits); + PQDecoderT decoder3(code3, nbits); + const size_t ksub = 1 << nbits; + + const float* tab = sim_table; + result0 = 0; + result1 = 0; + result2 = 0; + result3 = 0; + + for (size_t m = 0; m < M; m++) { + result0 += tab[decoder0.decode()]; + result1 += tab[decoder1.decode()]; + result2 += tab[decoder2.decode()]; + result3 += tab[decoder3.decode()]; + tab += ksub; + } + } +}; } // namespace faiss - -#elif defined(__ARM_FEATURE_SVE) - -#include - -namespace faiss { - -template -inline float distance_single_code( - // the product quantizer - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // the code - const uint8_t* code) { - return distance_single_code_sve(M, nbits, sim_table, code); -} - -template -inline void 
distance_four_codes( - // the product quantizer - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_sve( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} - -} // namespace faiss - -#else - -#include - -namespace faiss { - -template -inline float distance_single_code( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // the code - const uint8_t* code) { - return distance_single_code_generic(M, nbits, sim_table, code); -} - -template -inline void distance_four_codes( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_generic( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} - -} // namespace faiss - -#endif diff --git a/faiss/utils/distances.h b/faiss/utils/distances.h index 80d2cfc699..238e4bb40e 100644 --- a/faiss/utils/distances.h +++ b/faiss/utils/distances.h @@ -15,6 +15,7 @@ #include #include +#include namespace faiss { @@ -27,9 +28,15 @@ struct IDSelector; /// Squared L2 distance between two vectors float fvec_L2sqr(const float* x, const float* 
y, size_t d); +template +float fvec_L2sqr(const float* x, const float* y, size_t d); + /// inner product float fvec_inner_product(const float* x, const float* y, size_t d); +template +float fvec_inner_product(const float* x, const float* y, size_t d); + /// L1 distance float fvec_L1(const float* x, const float* y, size_t d); @@ -138,6 +145,9 @@ size_t fvec_L2sqr_ny_nearest_y_transposed( /** squared norm of a vector */ float fvec_norm_L2sqr(const float* x, size_t d); +template +float fvec_norm_L2sqr(const float* x, size_t d); + /** compute the L2 norms for a set of vectors * * @param norms output norms, size nx @@ -473,6 +483,10 @@ void compute_PQ_dis_tables_dsub2( */ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c); +/* same statically */ +template +void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c); + /** same as fvec_madd, also return index of the min of the result table * @return index of the min of table c */ diff --git a/faiss/utils/distances_simd.cpp b/faiss/utils/distances_simd.cpp index 1990e46aae..e829ae54ec 100644 --- a/faiss/utils/distances_simd.cpp +++ b/faiss/utils/distances_simd.cpp @@ -19,6 +19,9 @@ #include #include +#define AUTOVEC_LEVEL SIMDLevel::NONE +#include + #ifdef __SSE3__ #include #endif @@ -191,43 +194,19 @@ void fvec_inner_products_ny_ref( * Autovectorized implementations */ -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN -float fvec_inner_product(const float* x, const float* y, size_t d) { - float res = 0.F; - FAISS_PRAGMA_IMPRECISE_LOOP - for (size_t i = 0; i != d; ++i) { - res += x[i] * y[i]; - } - return res; -} -FAISS_PRAGMA_IMPRECISE_FUNCTION_END +// dispatching functions -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN float fvec_norm_L2sqr(const float* x, size_t d) { - // the double in the _ref is suspected to be a typo. Some of the manual - // implementations this replaces used float. 
- float res = 0; - FAISS_PRAGMA_IMPRECISE_LOOP - for (size_t i = 0; i != d; ++i) { - res += x[i] * x[i]; - } - - return res; + DISPATCH_SIMDLevel(fvec_norm_L2sqr, x, d); } -FAISS_PRAGMA_IMPRECISE_FUNCTION_END -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN float fvec_L2sqr(const float* x, const float* y, size_t d) { - size_t i; - float res = 0; - FAISS_PRAGMA_IMPRECISE_LOOP - for (i = 0; i < d; i++) { - const float tmp = x[i] - y[i]; - res += tmp * tmp; - } - return res; + DISPATCH_SIMDLevel(fvec_L2sqr, x, y, d); +} + +float fvec_inner_product(const float* x, const float* y, size_t d) { + DISPATCH_SIMDLevel(fvec_inner_product, x, y, d); } -FAISS_PRAGMA_IMPRECISE_FUNCTION_END /// Special version of inner product that computes 4 distances /// between x and yi @@ -3246,7 +3225,8 @@ void fvec_inner_products_ny( * heavily optimized table computations ***************************************************************************/ -[[maybe_unused]] static inline void fvec_madd_ref( +template <> +void fvec_madd( size_t n, const float* a, float bf, @@ -3256,94 +3236,6 @@ void fvec_inner_products_ny( c[i] = a[i] + bf * b[i]; } -#if defined(__AVX512F__) - -static inline void fvec_madd_avx512( - const size_t n, - const float* __restrict a, - const float bf, - const float* __restrict b, - float* __restrict c) { - const size_t n16 = n / 16; - const size_t n_for_masking = n % 16; - - const __m512 bfmm = _mm512_set1_ps(bf); - - size_t idx = 0; - for (idx = 0; idx < n16 * 16; idx += 16) { - const __m512 ax = _mm512_loadu_ps(a + idx); - const __m512 bx = _mm512_loadu_ps(b + idx); - const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); - _mm512_storeu_ps(c + idx, abmul); - } - - if (n_for_masking > 0) { - const __mmask16 mask = (1 << n_for_masking) - 1; - - const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx); - const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx); - const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); - _mm512_mask_storeu_ps(c + idx, mask, abmul); - } -} - -#elif 
defined(__AVX2__) - -static inline void fvec_madd_avx2( - const size_t n, - const float* __restrict a, - const float bf, - const float* __restrict b, - float* __restrict c) { - // - const size_t n8 = n / 8; - const size_t n_for_masking = n % 8; - - const __m256 bfmm = _mm256_set1_ps(bf); - - size_t idx = 0; - for (idx = 0; idx < n8 * 8; idx += 8) { - const __m256 ax = _mm256_loadu_ps(a + idx); - const __m256 bx = _mm256_loadu_ps(b + idx); - const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); - _mm256_storeu_ps(c + idx, abmul); - } - - if (n_for_masking > 0) { - __m256i mask; - switch (n_for_masking) { - case 1: - mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1); - break; - case 2: - mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1); - break; - case 3: - mask = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1); - break; - case 4: - mask = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1); - break; - case 5: - mask = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1); - break; - case 6: - mask = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1); - break; - case 7: - mask = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1); - break; - } - - const __m256 ax = _mm256_maskload_ps(a + idx, mask); - const __m256 bx = _mm256_maskload_ps(b + idx, mask); - const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); - _mm256_maskstore_ps(c + idx, mask, abmul); - } -} - -#endif - #ifdef __SSE3__ [[maybe_unused]] static inline void fvec_madd_sse( @@ -3367,16 +3259,7 @@ static inline void fvec_madd_avx2( } void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) { -#ifdef __AVX512F__ - fvec_madd_avx512(n, a, bf, b, c); -#elif __AVX2__ - fvec_madd_avx2(n, a, bf, b, c); -#else - if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0) - fvec_madd_sse(n, a, bf, b, c); - else - fvec_madd_ref(n, a, bf, b, c); -#endif + DISPATCH_SIMDLevel(fvec_madd, n, a, bf, b, c); } #elif defined(__ARM_FEATURE_SVE) diff --git a/faiss/utils/simd_impl/distances_autovec-inl.h 
b/faiss/utils/simd_impl/distances_autovec-inl.h new file mode 100644 index 0000000000..ace6dccccb --- /dev/null +++ b/faiss/utils/simd_impl/distances_autovec-inl.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +namespace faiss { + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_norm_L2sqr(const float* x, size_t d) { + // the double in the _ref is suspected to be a typo. Some of the manual + // implementations this replaces used float. + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i != d; ++i) { + res += x[i] * x[i]; + } + + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_L2sqr(const float* x, const float* y, size_t d) { + size_t i; + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (i = 0; i < d; i++) { + const float tmp = x[i] - y[i]; + res += tmp * tmp; + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_inner_product( + const float* x, + const float* y, + size_t d) { + float res = 0.F; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i != d; ++i) { + res += x[i] * y[i]; + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_avx2.cpp b/faiss/utils/simd_impl/distances_avx2.cpp new file mode 100644 index 0000000000..cd52c470f9 --- /dev/null +++ b/faiss/utils/simd_impl/distances_avx2.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +#define AUTOVEC_LEVEL SIMDLevel::AVX2 +#include + +namespace faiss { + +template <> +void fvec_madd( + const size_t n, + const float* __restrict a, + const float bf, + const float* __restrict b, + float* __restrict c) { + // + const size_t n8 = n / 8; + const size_t n_for_masking = n % 8; + + const __m256 bfmm = _mm256_set1_ps(bf); + + size_t idx = 0; + for (idx = 0; idx < n8 * 8; idx += 8) { + const __m256 ax = _mm256_loadu_ps(a + idx); + const __m256 bx = _mm256_loadu_ps(b + idx); + const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); + _mm256_storeu_ps(c + idx, abmul); + } + + if (n_for_masking > 0) { + __m256i mask; + switch (n_for_masking) { + case 1: + mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1); + break; + case 2: + mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1); + break; + case 3: + mask = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1); + break; + case 4: + mask = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1); + break; + case 5: + mask = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1); + break; + case 6: + mask = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1); + break; + case 7: + mask = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1); + break; + } + + const __m256 ax = _mm256_maskload_ps(a + idx, mask); + const __m256 bx = _mm256_maskload_ps(b + idx, mask); + const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); + _mm256_maskstore_ps(c + idx, mask, abmul); + } +} + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_avx512.cpp b/faiss/utils/simd_impl/distances_avx512.cpp new file mode 100644 index 0000000000..e87df652fd --- /dev/null +++ b/faiss/utils/simd_impl/distances_avx512.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +#define AUTOVEC_LEVEL SIMDLevel::AVX512F +#include + +namespace faiss { + +template <> +void fvec_madd( + const size_t n, + const float* __restrict a, + const float bf, + const float* __restrict b, + float* __restrict c) { + const size_t n16 = n / 16; + const size_t n_for_masking = n % 16; + + const __m512 bfmm = _mm512_set1_ps(bf); + + size_t idx = 0; + for (idx = 0; idx < n16 * 16; idx += 16) { + const __m512 ax = _mm512_loadu_ps(a + idx); + const __m512 bx = _mm512_loadu_ps(b + idx); + const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); + _mm512_storeu_ps(c + idx, abmul); + } + + if (n_for_masking > 0) { + const __mmask16 mask = (1 << n_for_masking) - 1; + + const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx); + const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx); + const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); + _mm512_mask_storeu_ps(c + idx, mask, abmul); + } +} + +} // namespace faiss diff --git a/faiss/utils/simdlib_avx2.h b/faiss/utils/simd_impl/simdlib_avx2.h similarity index 100% rename from faiss/utils/simdlib_avx2.h rename to faiss/utils/simd_impl/simdlib_avx2.h diff --git a/faiss/utils/simdlib_avx512.h b/faiss/utils/simd_impl/simdlib_avx512.h similarity index 100% rename from faiss/utils/simdlib_avx512.h rename to faiss/utils/simd_impl/simdlib_avx512.h diff --git a/faiss/utils/simdlib_emulated.h b/faiss/utils/simd_impl/simdlib_emulated.h similarity index 100% rename from faiss/utils/simdlib_emulated.h rename to faiss/utils/simd_impl/simdlib_emulated.h diff --git a/faiss/utils/simdlib_neon.h b/faiss/utils/simd_impl/simdlib_neon.h similarity index 100% rename from faiss/utils/simdlib_neon.h rename to faiss/utils/simd_impl/simdlib_neon.h diff --git a/faiss/utils/simdlib_ppc64.h b/faiss/utils/simd_impl/simdlib_ppc64.h similarity index 100% rename from faiss/utils/simdlib_ppc64.h rename to faiss/utils/simd_impl/simdlib_ppc64.h diff --git a/faiss/utils/simd_levels.cpp b/faiss/utils/simd_levels.cpp new file mode 
100644 index 0000000000..894c9aa271 --- /dev/null +++ b/faiss/utils/simd_levels.cpp @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +namespace faiss { + +SIMDLevel SIMDConfig::level = SIMDLevel::NONE; + +const char* SIMDConfig::level_names[] = + {"NONE", "AVX2", "AVX512F", "ARM_NEON", "ARM_SVE", "PPC"}; + +SIMDConfig::SIMDConfig() { + char* env_var = getenv("FAISS_SIMD_LEVEL"); + if (env_var) { + int i; + for (i = 0; i < (int)(sizeof(level_names) / sizeof(level_names[0])); i++) { + if (strcmp(env_var, level_names[i]) == 0) { + level = (SIMDLevel)i; + break; + } + } + FAISS_THROW_IF_NOT_FMT( + i != (int)(sizeof(level_names) / sizeof(level_names[0])), + "FAISS_SIMD_LEVEL %s unknown", + env_var); + return; + } + +#ifdef __x86_64__ + { + unsigned int eax, ebx, ecx, edx; + asm volatile("cpuid" + : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx) + : "a"(7), "c"(0)); + +#ifdef COMPILE_SIMD_AVX512F + if (ebx & (1 << 16)) { + level = AVX512F; + } else +#endif + +#ifdef COMPILE_SIMD_AVX2 + if (ebx & (1 << 5)) { + level = AVX2; + } else +#endif + level = NONE; + } +#endif +} + +} // namespace faiss diff --git a/faiss/utils/simd_levels.h b/faiss/utils/simd_levels.h new file mode 100644 index 0000000000..db0c8ff98c --- /dev/null +++ b/faiss/utils/simd_levels.h @@ -0,0 +1,79 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree.
+ */ + +#pragma once + +namespace faiss { + +enum SIMDLevel { + NONE, + AVX2, + AVX512F, + ARM_NEON, + ARM_SVE, + PPC_ALTIVEC, +}; + +struct SIMDConfig { + static SIMDLevel level; + static const char* level_names[]; + // initializes the simd_level from the cpuid and the FAISS_SIMD_LEVEL + // environment variable + SIMDConfig(); +}; + +/*********************** x86 SIMD */ + +#ifdef COMPILE_SIMD_AVX2 +#define DISPATCH_SIMDLevel_AVX2(f, ...) \ + case SIMDLevel::AVX2: \ + return f(__VA_ARGS__) +#else +#define DISPATCH_SIMDLevel_AVX2(f, ...) +#endif + +#ifdef COMPILE_SIMD_AVX512F +#define DISPATCH_SIMDLevel_AVX512F(f, ...) \ + case SIMDLevel::AVX512F: \ + return f(__VA_ARGS__) +#else +#define DISPATCH_SIMDLevel_AVX512F(f, ...) +#endif + +/*********************** ARM SIMD */ + +#ifdef COMPILE_SIMD_NEON +#define DISPATCH_SIMDLevel_ARM_NEON(f, ...) \ + case SIMDLevel::ARM_NEON: \ + return f(__VA_ARGS__) +#else +#define DISPATCH_SIMDLevel_ARM_NEON(f, ...) +#endif + +#ifdef COMPILE_SIMD_SVE +#define DISPATCH_SIMDLevel_ARM_SVE(f, ...) \ + case SIMDLevel::ARM_SVE: \ + return f(__VA_ARGS__) +#else +#define DISPATCH_SIMDLevel_ARM_SVE(f, ...) +#endif + +/* dispatch function f to f */ + +#define DISPATCH_SIMDLevel(f, ...) 
\ + switch (SIMDConfig::level) { \ + case SIMDLevel::NONE: \ + return f(__VA_ARGS__); \ + DISPATCH_SIMDLevel_AVX2(f, __VA_ARGS__); \ + DISPATCH_SIMDLevel_AVX512F(f, __VA_ARGS__); \ + DISPATCH_SIMDLevel_ARM_NEON(f, __VA_ARGS__); \ + DISPATCH_SIMDLevel_ARM_SVE(f, __VA_ARGS__); \ + default: \ + assert(!"invalid SIMD level"); \ + } + +} // namespace faiss diff --git a/faiss/utils/simdlib.h b/faiss/utils/simdlib.h index eadfb78ae3..98c38f7a0d 100644 --- a/faiss/utils/simdlib.h +++ b/faiss/utils/simdlib.h @@ -21,20 +21,20 @@ #elif defined(__AVX2__) -#include +#include #elif defined(__aarch64__) -#include +#include #elif defined(__PPC64__) -#include +#include #else // emulated = all operations are implemented as scalars -#include +#include // FIXME: make a SSE version // is this ever going to happen? We will probably rather implement AVX512