From fc4d2e405d5e33c1075d53df578b34c056764bf7 Mon Sep 17 00:00:00 2001 From: matthijs Date: Thu, 22 Jan 2026 12:01:54 -0800 Subject: [PATCH 1/3] SIMDConfig object, [faiss] Adding support for AVX2 and AVX512(F, BW, DQ, VL, CD) detection Summary: * Added support to detect SIMD instruction set for both `AVX2` and `AVX512F, AVX512VL` related levels * Added hardware specific unit tests (e.g.: checks that when unit tests are run on x86 arch then relevant SIMD levels are returned, also respective instructions are executed) * Reason for explicitly running computation and not relying on `__builtin_cpu_supports("avx512f")` [link](https://stackoverflow.com/questions/48677575/does-gccs-builtin-cpu-supports-check-for-os-support) * Also, fixes the bug in existing `AVX2` detection * Incorrect CPUID Bit Check: The function uses `ebx & (1 << 16)` to check for `AVX2` support. This is incorrect because bit 16 in `ebx` is actually used for `AVX-512F`, not `AVX2`. * Correct Bit for AVX2: The correct bit for detecting AVX2 is bit 5 in `ebx` when `eax = 7` and `ecx = 0`. This is based on Intel's documentation for the CPUID instruction. 
* Another bug observed in constructor for SIMDConfig (if env variable is set, the codepath still follows detection via code) * Improving SIMDConfig to take parameters to its constructor to support and enable injection mechanism for better testing* Adding more unit tests for other Hardware * Added variable with SIMDConfig to track all possible supported SIMD Levels Differential Revision: D72937710 Reviewed By: mdouze --- faiss/utils/simd_levels.cpp | 172 ++++++++++++++++++++++ faiss/utils/simd_levels.h | 82 +++++++++++ tests/test_simd_levels.cpp | 280 ++++++++++++++++++++++++++++++++++++ 3 files changed, 534 insertions(+) create mode 100644 faiss/utils/simd_levels.cpp create mode 100644 faiss/utils/simd_levels.h create mode 100644 tests/test_simd_levels.cpp diff --git a/faiss/utils/simd_levels.cpp b/faiss/utils/simd_levels.cpp new file mode 100644 index 0000000000..887225ee3b --- /dev/null +++ b/faiss/utils/simd_levels.cpp @@ -0,0 +1,172 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace faiss { + +SIMDLevel SIMDConfig::level = SIMDLevel::NONE; +std::unordered_set& SIMDConfig::supported_simd_levels() { + static std::unordered_set levels; + return levels; +} + +// it is there to make sure the constructor runs +static SIMDConfig dummy_config; + +SIMDConfig::SIMDConfig(const char** faiss_simd_level_env) { + // added to support dependency injection + const char* env_var = faiss_simd_level_env ? 
*faiss_simd_level_env + : getenv("FAISS_SIMD_LEVEL"); + + // check environment variable for SIMD level is explicitly set + if (!env_var) { + level = auto_detect_simd_level(); + } else { + auto matched_level = to_simd_level(env_var); + if (matched_level.has_value()) { + set_level(matched_level.value()); + supported_simd_levels().clear(); + supported_simd_levels().insert(matched_level.value()); + } else { + fprintf(stderr, + "FAISS_SIMD_LEVEL is set to %s, which is unknown\n", + env_var); + exit(1); + } + } + supported_simd_levels().insert(SIMDLevel::NONE); +} + +void SIMDConfig::set_level(SIMDLevel l) { + level = l; +} + +SIMDLevel SIMDConfig::get_level() { + return level; +} + +std::string SIMDConfig::get_level_name() { + return to_string(level).value_or(""); +} + +bool SIMDConfig::is_simd_level_available(SIMDLevel l) { + return supported_simd_levels().find(l) != supported_simd_levels().end(); +} + +SIMDLevel SIMDConfig::auto_detect_simd_level() { + SIMDLevel level = SIMDLevel::NONE; + +#if defined(__x86_64__) && \ + (defined(COMPILE_SIMD_AVX2) || defined(COMPILE_SIMD_AVX512)) + unsigned int eax, ebx, ecx, edx; + + eax = 1; + ecx = 0; + asm volatile("cpuid" + : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx) + : "a"(eax), "c"(ecx)); + + bool has_avx = (ecx & (1 << 28)) != 0; + + bool has_xsave_osxsave = + (ecx & ((1 << 26) | (1 << 27))) == ((1 << 26) | (1 << 27)); + + bool avx_supported = false; + if (has_avx && has_xsave_osxsave) { + unsigned int xcr0; + asm volatile("xgetbv" : "=a"(xcr0), "=d"(edx) : "c"(0)); + avx_supported = (xcr0 & 6) == 6; + } + + if (avx_supported) { + eax = 7; + ecx = 0; + asm volatile("cpuid" + : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx) + : "a"(eax), "c"(ecx)); + + unsigned int xcr0; + asm volatile("xgetbv" : "=a"(xcr0), "=d"(edx) : "c"(0)); + +#if defined(COMPILE_SIMD_AVX2) || defined(COMPILE_SIMD_AVX512) + bool has_avx2 = (ebx & (1 << 5)) != 0; + if (has_avx2) { + SIMDConfig::supported_simd_levels().insert(SIMDLevel::AVX2); + level = 
SIMDLevel::AVX2; + } + +#if defined(COMPILE_SIMD_AVX512) + bool cpu_has_avx512f = (ebx & (1 << 16)) != 0; + bool os_supports_avx512 = (xcr0 & 0xE0) == 0xE0; + bool has_avx512f = cpu_has_avx512f && os_supports_avx512; + if (has_avx512f) { + bool has_avx512cd = (ebx & (1 << 28)) != 0; + bool has_avx512vl = (ebx & (1 << 31)) != 0; + bool has_avx512dq = (ebx & (1 << 17)) != 0; + bool has_avx512bw = (ebx & (1 << 30)) != 0; + if (has_avx512bw && has_avx512cd && has_avx512vl && has_avx512dq) { + level = SIMDLevel::AVX512; + supported_simd_levels().insert(SIMDLevel::AVX512); + } + } +#endif // defined(COMPILE_SIMD_AVX512) +#endif // defined(COMPILE_SIMD_AVX2)|| defined(COMPILE_SIMD_AVX512) + } +#endif // defined(__x86_64__) && (defined(COMPILE_SIMD_AVX2) || + // defined(COMPILE_SIMD_AVX512)) + +#if defined(__aarch64__) && defined(__ARM_NEON) && \ + defined(COMPILE_SIMD_ARM_NEON) + // ARM NEON is standard on aarch64, so we can assume it's available + supported_simd_levels().insert(SIMDLevel::ARM_NEON); + level = SIMDLevel::ARM_NEON; + + // TODO: Add ARM SVE detection when needed + // For now, we default to ARM_NEON as it's universally supported on aarch64 +#endif + + return level; +} + +std::optional to_string(SIMDLevel level) { + switch (level) { + case SIMDLevel::NONE: + return "NONE"; + case SIMDLevel::AVX2: + return "AVX2"; + case SIMDLevel::AVX512: + return "AVX512"; + case SIMDLevel::ARM_NEON: + return "ARM_NEON"; + default: + return std::nullopt; + } + return std::nullopt; +} + +std::optional to_simd_level(const std::string& level_str) { + if (level_str == "NONE") { + return SIMDLevel::NONE; + } + if (level_str == "AVX2") { + return SIMDLevel::AVX2; + } + if (level_str == "AVX512") { + return SIMDLevel::AVX512; + } + if (level_str == "ARM_NEON") { + return SIMDLevel::ARM_NEON; + } + + return std::nullopt; +} + +} // namespace faiss diff --git a/faiss/utils/simd_levels.h b/faiss/utils/simd_levels.h new file mode 100644 index 0000000000..ad3d0b289d --- /dev/null +++ 
b/faiss/utils/simd_levels.h @@ -0,0 +1,82 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace faiss { + +#define COMPILE_SIMD_NONE + +enum class SIMDLevel { + NONE, + // x86 + AVX2, + AVX512, + // arm & aarch64 + ARM_NEON, + + COUNT +}; + +std::optional to_string(SIMDLevel level); + +std::optional to_simd_level(const std::string& level_str); + +/* Current SIMD configuration. This static class manages the current SIMD level + * and intializes it from the cpuid and the FAISS_SIMD_LEVEL + * environment variable */ +struct SIMDConfig { + static SIMDLevel level; + static std::unordered_set& supported_simd_levels(); + + typedef SIMDLevel (*DetectSIMDLevelFunc)(); + static SIMDLevel auto_detect_simd_level(); + + SIMDConfig(const char** faiss_simd_level_env = nullptr); + + static void set_level(SIMDLevel level); + static SIMDLevel get_level(); + static std::string get_level_name(); + + static bool is_simd_level_available(SIMDLevel level); +}; + +/*********************** x86 SIMD */ + +#ifdef COMPILE_SIMD_AVX2 +#define DISPATCH_SIMDLevel_AVX2(f, ...) \ + case SIMDLevel::AVX2: \ + return f(__VA_ARGS__) +#else +#define DISPATCH_SIMDLevel_AVX2(f, ...) +#endif + +#ifdef COMPILE_SIMD_AVX512 +#define DISPATCH_SIMDLevel_AVX512(f, ...) \ + case SIMDLevel::AVX512F: \ + return f(__VA_ARGS__) +#else +#define DISPATCH_SIMDLevel_AVX512(f, ...) +#endif + +/* dispatch function f to f */ + +#define DISPATCH_SIMDLevel(f, ...) 
\ + switch (SIMDConfig::level) { \ + case SIMDLevel::NONE: \ + return f(__VA_ARGS__); \ + DISPATCH_SIMDLevel_AVX2(f, __VA_ARGS__); \ + DISPATCH_SIMDLevel_AVX512(f, __VA_ARGS__); \ + default: \ + FAISS_ASSERT(!"Invalid SIMD level"); \ + } + +} // namespace faiss diff --git a/tests/test_simd_levels.cpp b/tests/test_simd_levels.cpp new file mode 100644 index 0000000000..4dac2e9877 --- /dev/null +++ b/tests/test_simd_levels.cpp @@ -0,0 +1,280 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#ifdef __x86_64__ +#include +#endif + +#include + +static jmp_buf jmpbuf; +static void sigill_handler(int sig) { + longjmp(jmpbuf, 1); +} + +bool try_execute(void (*func)()) { + signal(SIGILL, sigill_handler); + if (setjmp(jmpbuf) == 0) { + func(); + signal(SIGILL, SIG_DFL); + return true; + } else { + signal(SIGILL, SIG_DFL); + return false; + } +} + +#ifdef __x86_64__ +std::vector run_avx2_computation() { + alignas(32) int result[8]; + alignas(32) int input1[8] = {1, 2, 3, 4, 5, 6, 7, 8}; + alignas(32) int input2[8] = {8, 7, 6, 5, 4, 3, 2, 1}; + + __m256i vec1 = _mm256_load_si256(reinterpret_cast<__m256i*>(input1)); + __m256i vec2 = _mm256_load_si256(reinterpret_cast<__m256i*>(input2)); + __m256i vec_result = _mm256_add_epi32(vec1, vec2); + _mm256_store_si256(reinterpret_cast<__m256i*>(result), vec_result); + + return {result, result + 8}; +} + +std::vector run_avx512f_computation() { + alignas(64) long long result[8]; + alignas(64) long long input1[8] = {1, 2, 3, 4, 5, 6, 7, 8}; + alignas(64) long long input2[8] = {8, 7, 6, 5, 4, 3, 2, 1}; + + __m512i vec1 = _mm512_load_si512(reinterpret_cast(input1)); + __m512i vec2 = _mm512_load_si512(reinterpret_cast(input2)); + __m512i vec_result = _mm512_add_epi64(vec1, vec2); + _mm512_store_si512(reinterpret_cast<__m512i*>(result), vec_result); + + return 
{result, result + 8}; +} + +std::vector run_avx512cd_computation() { + run_avx512f_computation(); + + __m512i indices = _mm512_set_epi32( + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + __m512i conflict_mask = _mm512_conflict_epi32(indices); + + alignas(64) int mask_array[16]; + _mm512_store_epi32(mask_array, conflict_mask); + + return std::vector(); +} + +std::vector run_avx512vl_computation() { + run_avx512f_computation(); + + __m256i vec1 = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + __m256i vec2 = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7); + __m256i result = _mm256_add_epi32(vec1, vec2); + alignas(32) int result_array[8]; + _mm256_store_si256(reinterpret_cast<__m256i*>(result_array), result); + + return std::vector(result_array, result_array + 8); +} + +std::vector run_avx512dq_computation() { + run_avx512f_computation(); + + __m512i vec1 = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0); + __m512i vec2 = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); + __m512i result = _mm512_add_epi64(vec1, vec2); + + alignas(64) long long result_array[8]; + _mm512_store_si512(result_array, result); + + return std::vector(result_array, result_array + 8); +} + +std::vector run_avx512bw_computation() { + run_avx512f_computation(); + + std::vector input1(64, 0); + __m512i vec1 = + _mm512_loadu_si512(reinterpret_cast(input1.data())); + std::vector input2(64, 7); + __m512i vec2 = + _mm512_loadu_si512(reinterpret_cast(input2.data())); + __m512i result = _mm512_add_epi8(vec1, vec2); + + alignas(64) int8_t result_array[64]; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(result_array), result); + + return std::vector(result_array, result_array + 64); +} +#endif // __x86_64__ + +std::pair> try_execute(std::vector (*func)()) { + signal(SIGILL, sigill_handler); + if (setjmp(jmpbuf) == 0) { + auto result = func(); + signal(SIGILL, SIG_DFL); + return std::make_pair(true, result); + } else { + signal(SIGILL, SIG_DFL); + return std::make_pair(false, std::vector()); + } +} + +TEST(SIMDConfig, 
simd_level_auto_detect_architecture_only) { + faiss::SIMDLevel detected_level = + faiss::SIMDConfig::auto_detect_simd_level(); + +#if defined(__x86_64__) && \ + (defined(__AVX2__) || \ + (defined(__AVX512F__) && defined(__AVX512CD__) && \ + defined(__AVX512VL__) && defined(__AVX512BW__) && \ + defined(__AVX512DQ__))) + EXPECT_TRUE( + detected_level == faiss::SIMDLevel::AVX2 || + detected_level == faiss::SIMDLevel::AVX512); +#elif defined(__aarch64__) && defined(__ARM_NEON) + EXPECT_TRUE(detected_level == faiss::SIMDLevel::ARM_NEON); +#else + EXPECT_EQ(detected_level, faiss::SIMDLevel::NONE); +#endif +} + +#ifdef __x86_64__ +TEST(SIMDConfig, successful_avx2_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX2)) { + auto actual_result = try_execute(run_avx2_computation); + EXPECT_TRUE(actual_result.first); + auto expected_result_vector = std::vector(8, 9); + EXPECT_EQ(actual_result.second, expected_result_vector); + } +} + +TEST(SIMDConfig, on_avx512f_supported_we_should_avx2_support_as_well) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + EXPECT_TRUE( + simd_config.is_simd_level_available(faiss::SIMDLevel::AVX2)); + } +} + +TEST(SIMDConfig, successful_avx512f_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + auto actual_result = try_execute(run_avx512f_computation); + EXPECT_TRUE(actual_result.first); + auto expected_result_vector = std::vector(8, 9); + EXPECT_EQ(actual_result.second, expected_result_vector); + } +} + +TEST(SIMDConfig, successful_avx512cd_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + auto actual = try_execute(run_avx512cd_computation); + EXPECT_TRUE(actual.first); + } +} + +TEST(SIMDConfig, 
successful_avx512vl_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + auto actual = try_execute(run_avx512vl_computation); + EXPECT_TRUE(actual.first); + EXPECT_EQ(actual.second, std::vector(8, 7)); + } +} + +TEST(SIMDConfig, successful_avx512dq_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + EXPECT_TRUE( + simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)); + auto actual = try_execute(run_avx512dq_computation); + EXPECT_TRUE(actual.first); + EXPECT_EQ(actual.second, std::vector(8, 7)); + } +} + +TEST(SIMDConfig, successful_avx512bw_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + EXPECT_TRUE( + simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)); + auto actual = try_execute(run_avx512bw_computation); + EXPECT_TRUE(actual.first); + EXPECT_EQ(actual.second, std::vector(64, 7)); + } +} +#endif // __x86_64__ + +TEST(SIMDConfig, override_simd_level) { + const char* faiss_env_var_neon = "ARM_NEON"; + faiss::SIMDConfig simd_neon_config(&faiss_env_var_neon); + EXPECT_EQ(simd_neon_config.level, faiss::SIMDLevel::ARM_NEON); + + EXPECT_EQ(simd_neon_config.supported_simd_levels().size(), 2); + EXPECT_TRUE(simd_neon_config.is_simd_level_available( + faiss::SIMDLevel::ARM_NEON)); + + const char* faiss_env_var_avx512 = "AVX512"; + faiss::SIMDConfig simd_avx512_config(&faiss_env_var_avx512); + EXPECT_EQ(simd_avx512_config.level, faiss::SIMDLevel::AVX512); + EXPECT_EQ(simd_avx512_config.supported_simd_levels().size(), 2); + EXPECT_TRUE(simd_avx512_config.is_simd_level_available( + faiss::SIMDLevel::AVX512)); +} + +TEST(SIMDConfig, simd_config_get_level_name) { + const char* faiss_env_var_neon = "ARM_NEON"; + faiss::SIMDConfig simd_neon_config(&faiss_env_var_neon); + 
EXPECT_EQ(simd_neon_config.level, faiss::SIMDLevel::ARM_NEON); + EXPECT_TRUE(simd_neon_config.is_simd_level_available( + faiss::SIMDLevel::ARM_NEON)); + EXPECT_EQ(faiss_env_var_neon, simd_neon_config.get_level_name()); + + const char* faiss_env_var_avx512 = "AVX512"; + faiss::SIMDConfig simd_avx512_config(&faiss_env_var_avx512); + EXPECT_EQ(simd_avx512_config.level, faiss::SIMDLevel::AVX512); + EXPECT_TRUE(simd_avx512_config.is_simd_level_available( + faiss::SIMDLevel::AVX512)); + EXPECT_EQ(faiss_env_var_avx512, simd_avx512_config.get_level_name()); +} + +TEST(SIMDLevel, get_level_name_from_enum) { + EXPECT_EQ("NONE", to_string(faiss::SIMDLevel::NONE).value_or("")); + EXPECT_EQ("AVX2", to_string(faiss::SIMDLevel::AVX2).value_or("")); + EXPECT_EQ("AVX512", to_string(faiss::SIMDLevel::AVX512).value_or("")); + EXPECT_EQ("ARM_NEON", to_string(faiss::SIMDLevel::ARM_NEON).value_or("")); + + int actual_num_simd_levels = static_cast(faiss::SIMDLevel::COUNT); + EXPECT_EQ(4, actual_num_simd_levels); + // Check that all SIMD levels have a name (except for COUNT which is not a + // real SIMD level) + for (int i = 0; i < actual_num_simd_levels - 1; ++i) { + faiss::SIMDLevel simd_level = static_cast(i); + EXPECT_TRUE(faiss::to_string(simd_level).has_value()); + } +} + +TEST(SIMDLevel, to_simd_level_from_string) { + EXPECT_EQ(faiss::SIMDLevel::NONE, faiss::to_simd_level("NONE")); + EXPECT_EQ(faiss::SIMDLevel::AVX2, faiss::to_simd_level("AVX2")); + EXPECT_EQ(faiss::SIMDLevel::AVX512, faiss::to_simd_level("AVX512")); + EXPECT_EQ(faiss::SIMDLevel::ARM_NEON, faiss::to_simd_level("ARM_NEON")); + EXPECT_FALSE(faiss::to_simd_level("INVALID").has_value()); +} From 687818671f39ffc41f35c3fd2755d989f39b6cee Mon Sep 17 00:00:00 2001 From: matthijs Date: Fri, 23 Jan 2026 03:12:57 -0800 Subject: [PATCH 2/3] dynamic dispatch distances_simd Summary: `fvec_madd` is the first function to test dispatching to AVX and AVX512 distances_simd.cpp is split into specialized files distances_avx2.cpp 
distances_avx512.cpp that are compiled with appropriate flags. Differential Revision: D72937708 Reviewed By: mnorris11 --- faiss/utils/distances.h | 98 + faiss/utils/distances_simd.cpp | 3589 +---------------- faiss/utils/extra_distances-inl.h | 7 - faiss/utils/hamming_distance/generic-inl.h | 1 - faiss/utils/simd_impl/distances_aarch64.cpp | 137 + faiss/utils/simd_impl/distances_arm_sve.cpp | 496 +++ faiss/utils/simd_impl/distances_autovec-inl.h | 153 + faiss/utils/simd_impl/distances_avx.cpp | 99 + faiss/utils/simd_impl/distances_avx2.cpp | 1178 ++++++ faiss/utils/simd_impl/distances_avx512.cpp | 1092 +++++ faiss/utils/simd_impl/distances_sse-inl.h | 385 ++ faiss/utils/simd_levels.cpp | 3 +- faiss/utils/simd_levels.h | 2 +- tests/test_distances_simd.cpp | 532 ++- tests/test_simd_levels.cpp | 156 +- tests/test_simd_perf.cpp | 184 + 16 files changed, 4381 insertions(+), 3731 deletions(-) create mode 100644 faiss/utils/simd_impl/distances_aarch64.cpp create mode 100644 faiss/utils/simd_impl/distances_arm_sve.cpp create mode 100644 faiss/utils/simd_impl/distances_autovec-inl.h create mode 100644 faiss/utils/simd_impl/distances_avx.cpp create mode 100644 faiss/utils/simd_impl/distances_avx2.cpp create mode 100644 faiss/utils/simd_impl/distances_avx512.cpp create mode 100644 faiss/utils/simd_impl/distances_sse-inl.h create mode 100644 tests/test_simd_perf.cpp diff --git a/faiss/utils/distances.h b/faiss/utils/distances.h index 3b6c3d5e1c..e2b7d4d608 100644 --- a/faiss/utils/distances.h +++ b/faiss/utils/distances.h @@ -15,6 +15,7 @@ #include #include +#include namespace faiss { @@ -27,15 +28,27 @@ struct IDSelector; /// Squared L2 distance between two vectors float fvec_L2sqr(const float* x, const float* y, size_t d); +template +float fvec_L2sqr(const float* x, const float* y, size_t d); + /// inner product float fvec_inner_product(const float* x, const float* y, size_t d); +template +float fvec_inner_product(const float* x, const float* y, size_t d); + /// L1 distance 
float fvec_L1(const float* x, const float* y, size_t d); +template +float fvec_L1(const float* x, const float* y, size_t d); + /// infinity distance float fvec_Linf(const float* x, const float* y, size_t d); +template +float fvec_Linf(const float* x, const float* y, size_t d); + /// Special version of inner product that computes 4 distances /// between x and yi, which is performance oriented. void fvec_inner_product_batch_4( @@ -50,6 +63,19 @@ void fvec_inner_product_batch_4( float& dis2, float& dis3); +template +void fvec_inner_product_batch_4( + const float* x, + const float* y0, + const float* y1, + const float* y2, + const float* y3, + const size_t d, + float& dis0, + float& dis1, + float& dis2, + float& dis3); + /// Special version of L2sqr that computes 4 distances /// between x and yi, which is performance oriented. void fvec_L2sqr_batch_4( @@ -64,6 +90,19 @@ void fvec_L2sqr_batch_4( float& dis2, float& dis3); +template +void fvec_L2sqr_batch_4( + const float* x, + const float* y0, + const float* y1, + const float* y2, + const float* y3, + const size_t d, + float& dis0, + float& dis1, + float& dis2, + float& dis3); + /** Compute pairwise distances between sets of vectors * * @param d dimension of the vectors @@ -93,6 +132,14 @@ void fvec_inner_products_ny( size_t d, size_t ny); +template +void fvec_inner_products_ny( + float* ip, /* output inner product */ + const float* x, + const float* y, + size_t d, + size_t ny); + /* compute ny square L2 distance between x and a set of contiguous y vectors */ void fvec_L2sqr_ny( float* dis, @@ -101,6 +148,14 @@ void fvec_L2sqr_ny( size_t d, size_t ny); +template +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny); + /* compute ny square L2 distance between x and a set of transposed contiguous y vectors. 
squared lengths of y should be provided as well */ void fvec_L2sqr_ny_transposed( @@ -112,6 +167,16 @@ void fvec_L2sqr_ny_transposed( size_t d_offset, size_t ny); +template +void fvec_L2sqr_ny_transposed( + float* dis, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny); + /* compute ny square L2 distance between x and a set of contiguous y vectors and return the index of the nearest vector. return 0 if ny == 0. */ @@ -122,6 +187,14 @@ size_t fvec_L2sqr_ny_nearest( size_t d, size_t ny); +template +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny); + /* compute ny square L2 distance between x and a set of transposed contiguous y vectors and return the index of the nearest vector. squared lengths of y should be provided as well @@ -135,9 +208,22 @@ size_t fvec_L2sqr_ny_nearest_y_transposed( size_t d_offset, size_t ny); +template +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny); + /** squared norm of a vector */ float fvec_norm_L2sqr(const float* x, size_t d); +template +float fvec_norm_L2sqr(const float* x, size_t d); + /** compute the L2 norms for a set of vectors * * @param norms output norms, size nx @@ -473,6 +559,10 @@ void compute_PQ_dis_tables_dsub2( */ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c); +/* same statically */ +template +void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c); + /** same as fvec_madd, also return index of the min of the result table * @return index of the min of table c */ @@ -483,4 +573,12 @@ int fvec_madd_and_argmin( const float* b, float* c); +template +int fvec_madd_and_argmin( + size_t n, + const float* a, + float bf, + const float* b, + float* c); + } // namespace faiss diff --git a/faiss/utils/distances_simd.cpp 
b/faiss/utils/distances_simd.cpp index c6ff8b57cb..ab174a5a54 100644 --- a/faiss/utils/distances_simd.cpp +++ b/faiss/utils/distances_simd.cpp @@ -10,7 +10,6 @@ #include #include -#include #include #include #include @@ -19,85 +18,28 @@ #include #include -#ifdef __SSE3__ -#include -#endif - -#if defined(__AVX512F__) -#include -#elif defined(__AVX2__) -#include -#endif - -#ifdef __ARM_FEATURE_SVE -#include -#endif - -#ifdef __aarch64__ -#include -#endif +#define AUTOVEC_LEVEL SIMDLevel::NONE +#include namespace faiss { -#ifdef __AVX__ -#define USE_AVX -#endif - -/********************************************************* - * Optimized distance computations - *********************************************************/ - -/* Functions to compute: - - L2 distance between 2 vectors - - inner product between 2 vectors - - L2 norm of a vector - - The functions should probably not be invoked when a large number of - vectors are be processed in batch (in which case Matrix multiply - is faster), but may be useful for comparing vectors isolated in - memory. - - Works with any vectors of any dimension, even unaligned (in which - case they are slower). 
- +/******* +Functions with SIMDLevel::NONE */ -/********************************************************* - * Reference implementations - */ - -float fvec_L1_ref(const float* x, const float* y, size_t d) { - size_t i; - float res = 0; - for (i = 0; i < d; i++) { - const float tmp = x[i] - y[i]; - res += fabs(tmp); - } - return res; -} - -float fvec_Linf_ref(const float* x, const float* y, size_t d) { - size_t i; - float res = 0; - for (i = 0; i < d; i++) { - res = fmax(res, fabs(x[i] - y[i])); - } - return res; -} - -void fvec_L2sqr_ny_ref( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - for (size_t i = 0; i < ny; i++) { - dis[i] = fvec_L2sqr(x, y, d); - y += d; - } +template <> +void fvec_madd( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + for (size_t i = 0; i < n; i++) + c[i] = a[i] + bf * b[i]; } -void fvec_L2sqr_ny_y_transposed_ref( +template <> +void fvec_L2sqr_ny_transposed( float* dis, const float* x, const float* y, @@ -120,13 +62,50 @@ void fvec_L2sqr_ny_y_transposed_ref( } } -size_t fvec_L2sqr_ny_nearest_ref( +template <> +void fvec_inner_products_ny( + float* ip, + const float* x, + const float* y, + size_t d, + size_t ny) { +// BLAS slower for the use cases here +#if 0 +{ + FINTEGER di = d; + FINTEGER nyi = ny; + float one = 1.0, zero = 0.0; + FINTEGER onei = 1; + sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei); +} +#endif + for (size_t i = 0; i < ny; i++) { + ip[i] = fvec_inner_product(x, y, d); + y += d; + } +} + +template <> +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + for (size_t i = 0; i < ny; i++) { + dis[i] = fvec_L2sqr(x, y, d); + y += d; + } +} + +template <> +size_t fvec_L2sqr_ny_nearest( float* distances_tmp_buffer, const float* x, const float* y, size_t d, size_t ny) { - fvec_L2sqr_ny(distances_tmp_buffer, x, y, d, ny); + fvec_L2sqr_ny(distances_tmp_buffer, x, y, d, ny); size_t nearest_idx = 0; float min_dis 
= HUGE_VALF; @@ -141,7 +120,8 @@ size_t fvec_L2sqr_ny_nearest_ref( return nearest_idx; } -size_t fvec_L2sqr_ny_nearest_y_transposed_ref( +template <> +size_t fvec_L2sqr_ny_nearest_y_transposed( float* distances_tmp_buffer, const float* x, const float* y, @@ -149,7 +129,7 @@ size_t fvec_L2sqr_ny_nearest_y_transposed_ref( size_t d, size_t d_offset, size_t ny) { - fvec_L2sqr_ny_y_transposed_ref( + fvec_L2sqr_ny_transposed( distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); size_t nearest_idx = 0; @@ -165,73 +145,54 @@ size_t fvec_L2sqr_ny_nearest_y_transposed_ref( return nearest_idx; } -void fvec_inner_products_ny_ref( - float* ip, - const float* x, - const float* y, - size_t d, - size_t ny) { - // BLAS slower for the use cases here -#if 0 - { - FINTEGER di = d; - FINTEGER nyi = ny; - float one = 1.0, zero = 0.0; - FINTEGER onei = 1; - sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei); - } -#endif - for (size_t i = 0; i < ny; i++) { - ip[i] = fvec_inner_product(x, y, d); - y += d; +template <> +int fvec_madd_and_argmin( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + float vmin = 1e20; + int imin = -1; + + for (size_t i = 0; i < n; i++) { + c[i] = a[i] + bf * b[i]; + if (c[i] < vmin) { + vmin = c[i]; + imin = i; + } } + return imin; } /********************************************************* - * Autovectorized implementations + * dispatching functions */ -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN -float fvec_inner_product(const float* x, const float* y, size_t d) { - float res = 0.F; - FAISS_PRAGMA_IMPRECISE_LOOP - for (size_t i = 0; i != d; ++i) { - res += x[i] * y[i]; - } - return res; +float fvec_L1(const float* x, const float* y, size_t d) { + DISPATCH_SIMDLevel(fvec_L1, x, y, d); } -FAISS_PRAGMA_IMPRECISE_FUNCTION_END -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN -float fvec_norm_L2sqr(const float* x, size_t d) { - // the double in the _ref is suspected to be a typo. 
Some of the manual - // implementations this replaces used float. - float res = 0; - FAISS_PRAGMA_IMPRECISE_LOOP - for (size_t i = 0; i != d; ++i) { - res += x[i] * x[i]; - } +float fvec_Linf(const float* x, const float* y, size_t d) { + DISPATCH_SIMDLevel(fvec_Linf, x, y, d); +} - return res; +// dispatching functions + +float fvec_norm_L2sqr(const float* x, size_t d) { + DISPATCH_SIMDLevel(fvec_norm_L2sqr, x, d); } -FAISS_PRAGMA_IMPRECISE_FUNCTION_END -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN float fvec_L2sqr(const float* x, const float* y, size_t d) { - size_t i; - float res = 0; - FAISS_PRAGMA_IMPRECISE_LOOP - for (i = 0; i < d; i++) { - const float tmp = x[i] - y[i]; - res += tmp * tmp; - } - return res; + DISPATCH_SIMDLevel(fvec_L2sqr, x, y, d); +} + +float fvec_inner_product(const float* x, const float* y, size_t d) { + DISPATCH_SIMDLevel(fvec_inner_product, x, y, d); } -FAISS_PRAGMA_IMPRECISE_FUNCTION_END /// Special version of inner product that computes 4 distances /// between x and yi -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN void fvec_inner_product_batch_4( const float* __restrict x, const float* __restrict y0, @@ -243,28 +204,22 @@ void fvec_inner_product_batch_4( float& dis1, float& dis2, float& dis3) { - float d0 = 0; - float d1 = 0; - float d2 = 0; - float d3 = 0; - FAISS_PRAGMA_IMPRECISE_LOOP - for (size_t i = 0; i < d; ++i) { - d0 += x[i] * y0[i]; - d1 += x[i] * y1[i]; - d2 += x[i] * y2[i]; - d3 += x[i] * y3[i]; - } - - dis0 = d0; - dis1 = d1; - dis2 = d2; - dis3 = d3; + DISPATCH_SIMDLevel( + fvec_inner_product_batch_4, + x, + y0, + y1, + y2, + y3, + d, + dis0, + dis1, + dis2, + dis3); } -FAISS_PRAGMA_IMPRECISE_FUNCTION_END /// Special version of L2sqr that computes 4 distances /// between x and yi, which is performance oriented. 
-FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN void fvec_L2sqr_batch_4( const float* x, const float* y0, @@ -276,3326 +231,72 @@ void fvec_L2sqr_batch_4( float& dis1, float& dis2, float& dis3) { - float d0 = 0; - float d1 = 0; - float d2 = 0; - float d3 = 0; - FAISS_PRAGMA_IMPRECISE_LOOP - for (size_t i = 0; i < d; ++i) { - const float q0 = x[i] - y0[i]; - const float q1 = x[i] - y1[i]; - const float q2 = x[i] - y2[i]; - const float q3 = x[i] - y3[i]; - d0 += q0 * q0; - d1 += q1 * q1; - d2 += q2 * q2; - d3 += q3 * q3; - } - - dis0 = d0; - dis1 = d1; - dis2 = d2; - dis3 = d3; -} -FAISS_PRAGMA_IMPRECISE_FUNCTION_END - -/********************************************************* - * SSE and AVX implementations - */ - -#ifdef __SSE3__ - -// reads 0 <= d < 4 floats as __m128 -static inline __m128 masked_read(int d, const float* x) { - assert(0 <= d && d < 4); - ALIGNED(16) float buf[4] = {0, 0, 0, 0}; - switch (d) { - case 3: - buf[2] = x[2]; - [[fallthrough]]; - case 2: - buf[1] = x[1]; - [[fallthrough]]; - case 1: - buf[0] = x[0]; - } - return _mm_load_ps(buf); - // cannot use AVX2 _mm_mask_set1_epi32 -} - -namespace { - -/// helper function -inline float horizontal_sum(const __m128 v) { - // say, v is [x0, x1, x2, x3] - - // v0 is [x2, x3, ..., ...] - const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); - // v1 is [x0 + x2, x1 + x3, ..., ...] - const __m128 v1 = _mm_add_ps(v, v0); - // v2 is [x1 + x3, ..., .... ,...] - __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); - // v3 is [x0 + x1 + x2 + x3, ..., ..., ...] 
- const __m128 v3 = _mm_add_ps(v1, v2); - // return v3[0] - return _mm_cvtss_f32(v3); -} - -#ifdef __AVX2__ -/// helper function for AVX2 -inline float horizontal_sum(const __m256 v) { - // add high and low parts - const __m128 v0 = - _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); - // perform horizontal sum on v0 - return horizontal_sum(v0); -} -#endif - -#ifdef __AVX512F__ -/// helper function for AVX512 -inline float horizontal_sum(const __m512 v) { - // performs better than adding the high and low parts - return _mm512_reduce_add_ps(v); -} -#endif - -/// Function that does a component-wise operation between x and y -/// to compute L2 distances. ElementOp can then be used in the fvec_op_ny -/// functions below -struct ElementOpL2 { - static float op(float x, float y) { - float tmp = x - y; - return tmp * tmp; - } - - static __m128 op(__m128 x, __m128 y) { - __m128 tmp = _mm_sub_ps(x, y); - return _mm_mul_ps(tmp, tmp); - } - -#ifdef __AVX2__ - static __m256 op(__m256 x, __m256 y) { - __m256 tmp = _mm256_sub_ps(x, y); - return _mm256_mul_ps(tmp, tmp); - } -#endif - -#ifdef __AVX512F__ - static __m512 op(__m512 x, __m512 y) { - __m512 tmp = _mm512_sub_ps(x, y); - return _mm512_mul_ps(tmp, tmp); - } -#endif -}; - -/// Function that does a component-wise operation between x and y -/// to compute inner products -struct ElementOpIP { - static float op(float x, float y) { - return x * y; - } - - static __m128 op(__m128 x, __m128 y) { - return _mm_mul_ps(x, y); - } - -#ifdef __AVX2__ - static __m256 op(__m256 x, __m256 y) { - return _mm256_mul_ps(x, y); - } -#endif - -#ifdef __AVX512F__ - static __m512 op(__m512 x, __m512 y) { - return _mm512_mul_ps(x, y); - } -#endif -}; - -template -void fvec_op_ny_D1(float* dis, const float* x, const float* y, size_t ny) { - float x0s = x[0]; - __m128 x0 = _mm_set_ps(x0s, x0s, x0s, x0s); - - size_t i; - for (i = 0; i + 3 < ny; i += 4) { - __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = 
_mm_cvtss_f32(accu); - __m128 tmp = _mm_shuffle_ps(accu, accu, 1); - dis[i + 1] = _mm_cvtss_f32(tmp); - tmp = _mm_shuffle_ps(accu, accu, 2); - dis[i + 2] = _mm_cvtss_f32(tmp); - tmp = _mm_shuffle_ps(accu, accu, 3); - dis[i + 3] = _mm_cvtss_f32(tmp); - } - while (i < ny) { // handle non-multiple-of-4 case - dis[i++] = ElementOp::op(x0s, *y++); - } -} - -template -void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) { - __m128 x0 = _mm_set_ps(x[1], x[0], x[1], x[0]); - - size_t i; - for (i = 0; i + 1 < ny; i += 2) { - __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); - y += 4; - accu = _mm_hadd_ps(accu, accu); - dis[i] = _mm_cvtss_f32(accu); - accu = _mm_shuffle_ps(accu, accu, 3); - dis[i + 1] = _mm_cvtss_f32(accu); - } - if (i < ny) { // handle odd case - dis[i] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]); - } + DISPATCH_SIMDLevel( + fvec_L2sqr_batch_4, x, y0, y1, y2, y3, d, dis0, dis1, dis2, dis3); } -#if defined(__AVX512F__) - -template <> -void fvec_op_ny_D2( +void fvec_L2sqr_ny_transposed( float* dis, const float* x, const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D2-vectors per loop. - _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - - for (i = 0; i < ny16 * 16; i += 16) { - _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); - - // load 16x2 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m512 v0; - __m512 v1; - - transpose_16x2( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - v0, - v1); - - // compute distances (dot product) - __m512 distances = _mm512_mul_ps(m0, v0); - distances = _mm512_fmadd_ps(m1, v1, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 32; // move to the next set of 16x2 elements - } - } - - if (i < ny) { - // process leftovers - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float distance = x0 * y[0] + x1 * y[1]; - y += 2; - dis[i] = distance; - } - } + DISPATCH_SIMDLevel( + fvec_L2sqr_ny_transposed, dis, x, y, y_sqlen, d, d_offset, ny); } -template <> -void fvec_op_ny_D2( - float* dis, +void fvec_inner_products_ny( + float* ip, /* output inner product */ const float* x, const float* y, + size_t d, size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D2-vectors per loop. - _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - - for (i = 0; i < ny16 * 16; i += 16) { - _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); - - // load 16x2 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m512 v0; - __m512 v1; - - transpose_16x2( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - v0, - v1); - - // compute differences - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - - // compute squares of differences - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 32; // move to the next set of 16x2 elements - } - } - - if (i < ny) { - // process leftovers - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float sub0 = x0 - y[0]; - float sub1 = x1 - y[1]; - float distance = sub0 * sub0 + sub1 * sub1; - - y += 2; - dis[i] = distance; - } - } + DISPATCH_SIMDLevel(fvec_inner_products_ny, ip, x, y, d, ny); } -#elif defined(__AVX2__) - -template <> -void fvec_op_ny_D2( +void fvec_L2sqr_ny( float* dis, const float* x, const float* y, + size_t d, size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D2-vectors per loop. - _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); - - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - - for (i = 0; i < ny8 * 8; i += 8) { - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - // load 8x2 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m256 v0; - __m256 v1; - - transpose_8x2( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - v0, - v1); - - // compute distances - __m256 distances = _mm256_mul_ps(m0, v0); - distances = _mm256_fmadd_ps(m1, v1, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 16; - } - } - - if (i < ny) { - // process leftovers - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float distance = x0 * y[0] + x1 * y[1]; - y += 2; - dis[i] = distance; - } - } + DISPATCH_SIMDLevel(fvec_L2sqr_ny, dis, x, y, d, ny); } -template <> -void fvec_op_ny_D2( - float* dis, +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, const float* x, const float* y, + size_t d, size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D2-vectors per loop. - _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); - - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - - for (i = 0; i < ny8 * 8; i += 8) { - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - // load 8x2 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m256 v0; - __m256 v1; - - transpose_8x2( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - v0, - v1); + DISPATCH_SIMDLevel( + fvec_L2sqr_ny_nearest, distances_tmp_buffer, x, y, d, ny); +} - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + DISPATCH_SIMDLevel( + fvec_L2sqr_ny_nearest_y_transposed, + distances_tmp_buffer, + x, + y, + y_sqlen, + d, + d_offset, + ny); +} - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 16; - } - } - - if (i < ny) { - // process leftovers - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float sub0 = x0 - y[0]; - float sub1 = x1 - y[1]; - float distance = sub0 * sub0 + sub1 * sub1; - - y += 2; - dis[i] = distance; - } - } -} - -#endif - -template -void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) { - __m128 x0 = _mm_loadu_ps(x); - - for (size_t i = 0; i < ny; i++) { - __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = horizontal_sum(accu); - } -} - -#if defined(__AVX512F__) - -template <> -void fvec_op_ny_D4( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D4-vectors per loop. - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - - for (i = 0; i < ny16 * 16; i += 16) { - // load 16x4 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - - transpose_16x4( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - v0, - v1, - v2, - v3); - - // compute distances - __m512 distances = _mm512_mul_ps(m0, v0); - distances = _mm512_fmadd_ps(m1, v1, distances); - distances = _mm512_fmadd_ps(m2, v2, distances); - distances = _mm512_fmadd_ps(m3, v3, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 64; // move to the next set of 16x4 elements - } - } - - if (i < ny) { - // process leftovers - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = horizontal_sum(accu); - } - } -} - -template <> -void fvec_op_ny_D4( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D4-vectors per loop. - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - - for (i = 0; i < ny16 * 16; i += 16) { - // load 16x4 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - - transpose_16x4( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - v0, - v1, - v2, - v3); - - // compute differences - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - const __m512 d2 = _mm512_sub_ps(m2, v2); - const __m512 d3 = _mm512_sub_ps(m3, v3); - - // compute squares of differences - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - distances = _mm512_fmadd_ps(d2, d2, distances); - distances = _mm512_fmadd_ps(d3, d3, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 64; // move to the next set of 16x4 elements - } - } - - if (i < ny) { - // process leftovers - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = horizontal_sum(accu); - } - } -} - -#elif defined(__AVX2__) - -template <> -void fvec_op_ny_D4( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D4-vectors per loop. - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - - for (i = 0; i < ny8 * 8; i += 8) { - // load 8x4 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - - transpose_8x4( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - v0, - v1, - v2, - v3); - - // compute distances - __m256 distances = _mm256_mul_ps(m0, v0); - distances = _mm256_fmadd_ps(m1, v1, distances); - distances = _mm256_fmadd_ps(m2, v2, distances); - distances = _mm256_fmadd_ps(m3, v3, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 32; - } - } - - if (i < ny) { - // process leftovers - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = horizontal_sum(accu); - } - } -} - -template <> -void fvec_op_ny_D4( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D4-vectors per loop. - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - - for (i = 0; i < ny8 * 8; i += 8) { - // load 8x4 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - - transpose_8x4( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - v0, - v1, - v2, - v3); - - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); - const __m256 d2 = _mm256_sub_ps(m2, v2); - const __m256 d3 = _mm256_sub_ps(m3, v3); - - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - distances = _mm256_fmadd_ps(d2, d2, distances); - distances = _mm256_fmadd_ps(d3, d3, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 32; - } - } - - if (i < ny) { - // process leftovers - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = horizontal_sum(accu); - } - } -} - -#endif - -template -void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) { - __m128 x0 = _mm_loadu_ps(x); - __m128 x1 = _mm_loadu_ps(x + 4); - - for (size_t i = 0; i < ny; i++) { - __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); - y += 4; - accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y))); - y += 4; - accu = _mm_hadd_ps(accu, accu); - accu = _mm_hadd_ps(accu, accu); - dis[i] = _mm_cvtss_f32(accu); - } -} - -#if defined(__AVX512F__) - -template <> -void fvec_op_ny_D8( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D16-vectors per loop. 
- const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - const __m512 m4 = _mm512_set1_ps(x[4]); - const __m512 m5 = _mm512_set1_ps(x[5]); - const __m512 m6 = _mm512_set1_ps(x[6]); - const __m512 m7 = _mm512_set1_ps(x[7]); - - for (i = 0; i < ny16 * 16; i += 16) { - // load 16x8 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. - - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - __m512 v4; - __m512 v5; - __m512 v6; - __m512 v7; - - transpose_16x8( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - _mm512_loadu_ps(y + 4 * 16), - _mm512_loadu_ps(y + 5 * 16), - _mm512_loadu_ps(y + 6 * 16), - _mm512_loadu_ps(y + 7 * 16), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - // compute distances - __m512 distances = _mm512_mul_ps(m0, v0); - distances = _mm512_fmadd_ps(m1, v1, distances); - distances = _mm512_fmadd_ps(m2, v2, distances); - distances = _mm512_fmadd_ps(m3, v3, distances); - distances = _mm512_fmadd_ps(m4, v4, distances); - distances = _mm512_fmadd_ps(m5, v5, distances); - distances = _mm512_fmadd_ps(m6, v6, distances); - distances = _mm512_fmadd_ps(m7, v7, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 128; // 16 floats * 8 rows - } - } - - if (i < ny) { - // process leftovers - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y)); - y += 8; - dis[i] = horizontal_sum(accu); - } - } -} - -template <> -void fvec_op_ny_D8( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D16-vectors per loop. 
- const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - const __m512 m4 = _mm512_set1_ps(x[4]); - const __m512 m5 = _mm512_set1_ps(x[5]); - const __m512 m6 = _mm512_set1_ps(x[6]); - const __m512 m7 = _mm512_set1_ps(x[7]); - - for (i = 0; i < ny16 * 16; i += 16) { - // load 16x8 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. - - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - __m512 v4; - __m512 v5; - __m512 v6; - __m512 v7; - - transpose_16x8( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - _mm512_loadu_ps(y + 4 * 16), - _mm512_loadu_ps(y + 5 * 16), - _mm512_loadu_ps(y + 6 * 16), - _mm512_loadu_ps(y + 7 * 16), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - // compute differences - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - const __m512 d2 = _mm512_sub_ps(m2, v2); - const __m512 d3 = _mm512_sub_ps(m3, v3); - const __m512 d4 = _mm512_sub_ps(m4, v4); - const __m512 d5 = _mm512_sub_ps(m5, v5); - const __m512 d6 = _mm512_sub_ps(m6, v6); - const __m512 d7 = _mm512_sub_ps(m7, v7); - - // compute squares of differences - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - distances = _mm512_fmadd_ps(d2, d2, distances); - distances = _mm512_fmadd_ps(d3, d3, distances); - distances = _mm512_fmadd_ps(d4, d4, distances); - distances = _mm512_fmadd_ps(d5, d5, distances); - distances = _mm512_fmadd_ps(d6, d6, distances); - distances = _mm512_fmadd_ps(d7, d7, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 128; // 16 floats * 8 rows - } - } - - if (i < ny) { - // process leftovers - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpL2::op(x0, 
_mm256_loadu_ps(y)); - y += 8; - dis[i] = horizontal_sum(accu); - } - } -} - -#elif defined(__AVX2__) - -template <> -void fvec_op_ny_D8( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D8-vectors per loop. - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - const __m256 m4 = _mm256_set1_ps(x[4]); - const __m256 m5 = _mm256_set1_ps(x[5]); - const __m256 m6 = _mm256_set1_ps(x[6]); - const __m256 m7 = _mm256_set1_ps(x[7]); - - for (i = 0; i < ny8 * 8; i += 8) { - // load 8x8 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. - - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - __m256 v4; - __m256 v5; - __m256 v6; - __m256 v7; - - transpose_8x8( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - _mm256_loadu_ps(y + 4 * 8), - _mm256_loadu_ps(y + 5 * 8), - _mm256_loadu_ps(y + 6 * 8), - _mm256_loadu_ps(y + 7 * 8), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - // compute distances - __m256 distances = _mm256_mul_ps(m0, v0); - distances = _mm256_fmadd_ps(m1, v1, distances); - distances = _mm256_fmadd_ps(m2, v2, distances); - distances = _mm256_fmadd_ps(m3, v3, distances); - distances = _mm256_fmadd_ps(m4, v4, distances); - distances = _mm256_fmadd_ps(m5, v5, distances); - distances = _mm256_fmadd_ps(m6, v6, distances); - distances = _mm256_fmadd_ps(m7, v7, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 64; - } - } - - if (i < ny) { - // process leftovers - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y)); - y += 8; - dis[i] = horizontal_sum(accu); - } - } -} - -template <> -void fvec_op_ny_D8( - float* dis, - 
const float* x, - const float* y, - size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D8-vectors per loop. - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - const __m256 m4 = _mm256_set1_ps(x[4]); - const __m256 m5 = _mm256_set1_ps(x[5]); - const __m256 m6 = _mm256_set1_ps(x[6]); - const __m256 m7 = _mm256_set1_ps(x[7]); - - for (i = 0; i < ny8 * 8; i += 8) { - // load 8x8 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. - - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - __m256 v4; - __m256 v5; - __m256 v6; - __m256 v7; - - transpose_8x8( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - _mm256_loadu_ps(y + 4 * 8), - _mm256_loadu_ps(y + 5 * 8), - _mm256_loadu_ps(y + 6 * 8), - _mm256_loadu_ps(y + 7 * 8), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); - const __m256 d2 = _mm256_sub_ps(m2, v2); - const __m256 d3 = _mm256_sub_ps(m3, v3); - const __m256 d4 = _mm256_sub_ps(m4, v4); - const __m256 d5 = _mm256_sub_ps(m5, v5); - const __m256 d6 = _mm256_sub_ps(m6, v6); - const __m256 d7 = _mm256_sub_ps(m7, v7); - - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - distances = _mm256_fmadd_ps(d2, d2, distances); - distances = _mm256_fmadd_ps(d3, d3, distances); - distances = _mm256_fmadd_ps(d4, d4, distances); - distances = _mm256_fmadd_ps(d5, d5, distances); - distances = _mm256_fmadd_ps(d6, d6, distances); - distances = _mm256_fmadd_ps(d7, d7, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 64; - } - } - - if (i < ny) { - // process 
leftovers - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y)); - y += 8; - dis[i] = horizontal_sum(accu); - } - } -} - -#endif - -template -void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) { - __m128 x0 = _mm_loadu_ps(x); - __m128 x1 = _mm_loadu_ps(x + 4); - __m128 x2 = _mm_loadu_ps(x + 8); - - for (size_t i = 0; i < ny; i++) { - __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); - y += 4; - accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y))); - y += 4; - accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y))); - y += 4; - dis[i] = horizontal_sum(accu); - } -} - -} // anonymous namespace - -void fvec_L2sqr_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - // optimized for a few special cases - -#define DISPATCH(dval) \ - case dval: \ - fvec_op_ny_D##dval(dis, x, y, ny); \ - return; - - switch (d) { - DISPATCH(1) - DISPATCH(2) - DISPATCH(4) - DISPATCH(8) - DISPATCH(12) - default: - fvec_L2sqr_ny_ref(dis, x, y, d, ny); - return; - } -#undef DISPATCH -} - -void fvec_inner_products_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { -#define DISPATCH(dval) \ - case dval: \ - fvec_op_ny_D##dval(dis, x, y, ny); \ - return; - - switch (d) { - DISPATCH(1) - DISPATCH(2) - DISPATCH(4) - DISPATCH(8) - DISPATCH(12) - default: - fvec_inner_products_ny_ref(dis, x, y, d, ny); - return; - } -#undef DISPATCH -} - -#if defined(__AVX512F__) - -template -void fvec_L2sqr_ny_y_transposed_D( - float* distances, - const float* x, - const float* y, - const float* y_sqlen, - const size_t d_offset, - size_t ny) { - // current index being processed - size_t i = 0; - - // squared length of x - float x_sqlen = 0; - for (size_t j = 0; j < DIM; j++) { - x_sqlen += x[j] * x[j]; - } - - // process 16 vectors per loop - const size_t ny16 = ny / 16; - - if (ny16 > 0) { - // m[i] = (2 * x[i], ... 
2 * x[i]) - __m512 m[DIM]; - for (size_t j = 0; j < DIM; j++) { - m[j] = _mm512_set1_ps(x[j]); - m[j] = _mm512_add_ps(m[j], m[j]); // m[j] = 2 * x[j] - } - - __m512 x_sqlen_ymm = _mm512_set1_ps(x_sqlen); - - for (; i < ny16 * 16; i += 16) { - // Load vectors for 16 dimensions - __m512 v[DIM]; - for (size_t j = 0; j < DIM; j++) { - v[j] = _mm512_loadu_ps(y + j * d_offset); - } - - // Compute dot products - __m512 dp = _mm512_fnmadd_ps(m[0], v[0], x_sqlen_ymm); - for (size_t j = 1; j < DIM; j++) { - dp = _mm512_fnmadd_ps(m[j], v[j], dp); - } - - // Compute y^2 - (2 * x, y) + x^2 - __m512 distances_v = _mm512_add_ps(_mm512_loadu_ps(y_sqlen), dp); - - _mm512_storeu_ps(distances + i, distances_v); - - // Scroll y and y_sqlen forward - y += 16; - y_sqlen += 16; - } - } - - if (i < ny) { - // Process leftovers - for (; i < ny; i++) { - float dp = 0; - for (size_t j = 0; j < DIM; j++) { - dp += x[j] * y[j * d_offset]; - } - - // Compute y^2 - 2 * (x, y), which is sufficient for looking for the - // lowest distance. - const float distance = y_sqlen[0] - 2 * dp + x_sqlen; - distances[i] = distance; - - y += 1; - y_sqlen += 1; - } - } -} - -#elif defined(__AVX2__) - -template -void fvec_L2sqr_ny_y_transposed_D( - float* distances, - const float* x, - const float* y, - const float* y_sqlen, - const size_t d_offset, - size_t ny) { - // current index being processed - size_t i = 0; - - // squared length of x - float x_sqlen = 0; - for (size_t j = 0; j < DIM; j++) { - x_sqlen += x[j] * x[j]; - } - - // process 8 vectors per loop. - const size_t ny8 = ny / 8; - - if (ny8 > 0) { - // m[i] = (2 * x[i], ... 2 * x[i]) - __m256 m[DIM]; - for (size_t j = 0; j < DIM; j++) { - m[j] = _mm256_set1_ps(x[j]); - m[j] = _mm256_add_ps(m[j], m[j]); - } - - __m256 x_sqlen_ymm = _mm256_set1_ps(x_sqlen); - - for (; i < ny8 * 8; i += 8) { - // collect dim 0 for 8 D4-vectors. 
- const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset); - - // compute dot products - // this is x^2 - 2x[0]*y[0] - __m256 dp = _mm256_fnmadd_ps(m[0], v0, x_sqlen_ymm); - - for (size_t j = 1; j < DIM; j++) { - // collect dim j for 8 D4-vectors. - const __m256 vj = _mm256_loadu_ps(y + j * d_offset); - dp = _mm256_fnmadd_ps(m[j], vj, dp); - } - - // we've got x^2 - (2x, y) at this point - - // y^2 - (2x, y) + x^2 - __m256 distances_v = _mm256_add_ps(_mm256_loadu_ps(y_sqlen), dp); - - _mm256_storeu_ps(distances + i, distances_v); - - // scroll y and y_sqlen forward. - y += 8; - y_sqlen += 8; - } - } - - if (i < ny) { - // process leftovers - for (; i < ny; i++) { - float dp = 0; - for (size_t j = 0; j < DIM; j++) { - dp += x[j] * y[j * d_offset]; - } - - // compute y^2 - 2 * (x, y), which is sufficient for looking for the - // lowest distance. - const float distance = y_sqlen[0] - 2 * dp + x_sqlen; - distances[i] = distance; - - y += 1; - y_sqlen += 1; - } - } -} - -#endif - -void fvec_L2sqr_ny_transposed( - float* dis, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - // optimized for a few special cases - -#ifdef __AVX2__ -#define DISPATCH(dval) \ - case dval: \ - return fvec_L2sqr_ny_y_transposed_D( \ - dis, x, y, y_sqlen, d_offset, ny); - - switch (d) { - DISPATCH(1) - DISPATCH(2) - DISPATCH(4) - DISPATCH(8) - default: - return fvec_L2sqr_ny_y_transposed_ref( - dis, x, y, y_sqlen, d, d_offset, ny); - } -#undef DISPATCH -#else - // non-AVX2 case - return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny); -#endif -} - -#if defined(__AVX512F__) - -size_t fvec_L2sqr_ny_nearest_D2( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. 
- - size_t i = 0; - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - const size_t ny16 = ny / 16; - if (ny16 > 0) { - _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - __m512 min_distances = _mm512_set1_ps(HUGE_VALF); - __m512i min_indices = _mm512_set1_epi32(0); - - __m512i current_indices = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - const __m512i indices_increment = _mm512_set1_epi32(16); - - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - - for (; i < ny16 * 16; i += 16) { - _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); - - __m512 v0; - __m512 v1; - - transpose_16x2( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - v0, - v1); - - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - - __mmask16 comparison = - _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); - - min_distances = _mm512_min_ps(distances, min_distances); - min_indices = _mm512_mask_blend_epi32( - comparison, min_indices, current_indices); - - current_indices = - _mm512_add_epi32(current_indices, indices_increment); - - y += 32; - } - - alignas(64) float min_distances_scalar[16]; - alignas(64) uint32_t min_indices_scalar[16]; - _mm512_store_ps(min_distances_scalar, min_distances); - _mm512_store_epi32(min_indices_scalar, min_indices); - - for (size_t j = 0; j < 16; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float sub0 = x0 - y[0]; - float sub1 = x1 - y[1]; - float distance = sub0 * sub0 + sub1 * sub1; - - y += 2; - - if (current_min_distance > distance) { - current_min_distance = distance; - 
current_min_index = i; - } - } - } - - return current_min_index; -} - -size_t fvec_L2sqr_ny_nearest_D4( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. - - size_t i = 0; - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - const size_t ny16 = ny / 16; - - if (ny16 > 0) { - __m512 min_distances = _mm512_set1_ps(HUGE_VALF); - __m512i min_indices = _mm512_set1_epi32(0); - - __m512i current_indices = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - const __m512i indices_increment = _mm512_set1_epi32(16); - - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - - for (; i < ny16 * 16; i += 16) { - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - - transpose_16x4( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - v0, - v1, - v2, - v3); - - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - const __m512 d2 = _mm512_sub_ps(m2, v2); - const __m512 d3 = _mm512_sub_ps(m3, v3); - - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - distances = _mm512_fmadd_ps(d2, d2, distances); - distances = _mm512_fmadd_ps(d3, d3, distances); - - __mmask16 comparison = - _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); - - min_distances = _mm512_min_ps(distances, min_distances); - min_indices = _mm512_mask_blend_epi32( - comparison, min_indices, current_indices); - - current_indices = - _mm512_add_epi32(current_indices, indices_increment); - - y += 64; - } - - alignas(64) float min_distances_scalar[16]; - alignas(64) uint32_t min_indices_scalar[16]; - _mm512_store_ps(min_distances_scalar, min_distances); - _mm512_store_epi32(min_indices_scalar, min_indices); - - for 
(size_t j = 0; j < 16; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); - y += 4; - const float distance = horizontal_sum(accu); - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - } - } - - return current_min_index; -} - -size_t fvec_L2sqr_ny_nearest_D8( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. - - size_t i = 0; - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - const size_t ny16 = ny / 16; - if (ny16 > 0) { - __m512 min_distances = _mm512_set1_ps(HUGE_VALF); - __m512i min_indices = _mm512_set1_epi32(0); - - __m512i current_indices = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - const __m512i indices_increment = _mm512_set1_epi32(16); - - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - - const __m512 m4 = _mm512_set1_ps(x[4]); - const __m512 m5 = _mm512_set1_ps(x[5]); - const __m512 m6 = _mm512_set1_ps(x[6]); - const __m512 m7 = _mm512_set1_ps(x[7]); - - for (; i < ny16 * 16; i += 16) { - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - __m512 v4; - __m512 v5; - __m512 v6; - __m512 v7; - - transpose_16x8( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - _mm512_loadu_ps(y + 4 * 16), - _mm512_loadu_ps(y + 5 * 16), - _mm512_loadu_ps(y + 6 * 16), - _mm512_loadu_ps(y + 7 * 16), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - 
const __m512 d2 = _mm512_sub_ps(m2, v2); - const __m512 d3 = _mm512_sub_ps(m3, v3); - const __m512 d4 = _mm512_sub_ps(m4, v4); - const __m512 d5 = _mm512_sub_ps(m5, v5); - const __m512 d6 = _mm512_sub_ps(m6, v6); - const __m512 d7 = _mm512_sub_ps(m7, v7); - - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - distances = _mm512_fmadd_ps(d2, d2, distances); - distances = _mm512_fmadd_ps(d3, d3, distances); - distances = _mm512_fmadd_ps(d4, d4, distances); - distances = _mm512_fmadd_ps(d5, d5, distances); - distances = _mm512_fmadd_ps(d6, d6, distances); - distances = _mm512_fmadd_ps(d7, d7, distances); - - __mmask16 comparison = - _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); - - min_distances = _mm512_min_ps(distances, min_distances); - min_indices = _mm512_mask_blend_epi32( - comparison, min_indices, current_indices); - - current_indices = - _mm512_add_epi32(current_indices, indices_increment); - - y += 128; - } - - alignas(64) float min_distances_scalar[16]; - alignas(64) uint32_t min_indices_scalar[16]; - _mm512_store_ps(min_distances_scalar, min_distances); - _mm512_store_epi32(min_indices_scalar, min_indices); - - for (size_t j = 0; j < 16; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y)); - y += 8; - const float distance = horizontal_sum(accu); - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - } - } - - return current_min_index; -} - -#elif defined(__AVX2__) - -size_t fvec_L2sqr_ny_nearest_D2( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. 
- - // current index being processed - size_t i = 0; - - // min distance and the index of the closest vector so far - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - // process 8 D2-vectors per loop. - const size_t ny8 = ny / 8; - if (ny8 > 0) { - _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); - - // track min distance and the closest vector independently - // for each of 8 AVX2 components. - __m256 min_distances = _mm256_set1_ps(HUGE_VALF); - __m256i min_indices = _mm256_set1_epi32(0); - - __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - const __m256i indices_increment = _mm256_set1_epi32(8); - - // 1 value per register - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - - for (; i < ny8 * 8; i += 8) { - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - __m256 v0; - __m256 v1; - - transpose_8x2( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - v0, - v1); - - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); - - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - - // compare the new distances to the min distances - // for each of 8 AVX2 components. - __m256 comparison = - _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); - - // update min distances and indices with closest vectors if needed. - min_distances = _mm256_min_ps(distances, min_distances); - min_indices = _mm256_castps_si256(_mm256_blendv_ps( - _mm256_castsi256_ps(current_indices), - _mm256_castsi256_ps(min_indices), - comparison)); - - // update current indices values. Basically, +8 to each of the - // 8 AVX2 components. - current_indices = - _mm256_add_epi32(current_indices, indices_increment); - - // scroll y forward (8 vectors 2 DIM each). 
- y += 16; - } - - // dump values and find the minimum distance / minimum index - float min_distances_scalar[8]; - uint32_t min_indices_scalar[8]; - _mm256_storeu_ps(min_distances_scalar, min_distances); - _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); - - for (size_t j = 0; j < 8; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - // process leftovers. - // the following code is not optimal, but it is rarely invoked. - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float sub0 = x0 - y[0]; - float sub1 = x1 - y[1]; - float distance = sub0 * sub0 + sub1 * sub1; - - y += 2; - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - } - } - - return current_min_index; -} - -size_t fvec_L2sqr_ny_nearest_D4( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. - - // current index being processed - size_t i = 0; - - // min distance and the index of the closest vector so far - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - // process 8 D4-vectors per loop. - const size_t ny8 = ny / 8; - - if (ny8 > 0) { - // track min distance and the closest vector independently - // for each of 8 AVX2 components. 
- __m256 min_distances = _mm256_set1_ps(HUGE_VALF); - __m256i min_indices = _mm256_set1_epi32(0); - - __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - const __m256i indices_increment = _mm256_set1_epi32(8); - - // 1 value per register - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - - for (; i < ny8 * 8; i += 8) { - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - - transpose_8x4( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - v0, - v1, - v2, - v3); - - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); - const __m256 d2 = _mm256_sub_ps(m2, v2); - const __m256 d3 = _mm256_sub_ps(m3, v3); - - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - distances = _mm256_fmadd_ps(d2, d2, distances); - distances = _mm256_fmadd_ps(d3, d3, distances); - - // compare the new distances to the min distances - // for each of 8 AVX2 components. - __m256 comparison = - _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); - - // update min distances and indices with closest vectors if needed. - min_distances = _mm256_min_ps(distances, min_distances); - min_indices = _mm256_castps_si256(_mm256_blendv_ps( - _mm256_castsi256_ps(current_indices), - _mm256_castsi256_ps(min_indices), - comparison)); - - // update current indices values. Basically, +8 to each of the - // 8 AVX2 components. - current_indices = - _mm256_add_epi32(current_indices, indices_increment); - - // scroll y forward (8 vectors 4 DIM each). 
- y += 32; - } - - // dump values and find the minimum distance / minimum index - float min_distances_scalar[8]; - uint32_t min_indices_scalar[8]; - _mm256_storeu_ps(min_distances_scalar, min_distances); - _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); - - for (size_t j = 0; j < 8; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - // process leftovers - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); - y += 4; - const float distance = horizontal_sum(accu); - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - } - } - - return current_min_index; -} - -size_t fvec_L2sqr_ny_nearest_D8( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. - - // current index being processed - size_t i = 0; - - // min distance and the index of the closest vector so far - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - // process 8 D8-vectors per loop. - const size_t ny8 = ny / 8; - if (ny8 > 0) { - // track min distance and the closest vector independently - // for each of 8 AVX2 components. 
- __m256 min_distances = _mm256_set1_ps(HUGE_VALF); - __m256i min_indices = _mm256_set1_epi32(0); - - __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - const __m256i indices_increment = _mm256_set1_epi32(8); - - // 1 value per register - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - - const __m256 m4 = _mm256_set1_ps(x[4]); - const __m256 m5 = _mm256_set1_ps(x[5]); - const __m256 m6 = _mm256_set1_ps(x[6]); - const __m256 m7 = _mm256_set1_ps(x[7]); - - for (; i < ny8 * 8; i += 8) { - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - __m256 v4; - __m256 v5; - __m256 v6; - __m256 v7; - - transpose_8x8( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - _mm256_loadu_ps(y + 4 * 8), - _mm256_loadu_ps(y + 5 * 8), - _mm256_loadu_ps(y + 6 * 8), - _mm256_loadu_ps(y + 7 * 8), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); - const __m256 d2 = _mm256_sub_ps(m2, v2); - const __m256 d3 = _mm256_sub_ps(m3, v3); - const __m256 d4 = _mm256_sub_ps(m4, v4); - const __m256 d5 = _mm256_sub_ps(m5, v5); - const __m256 d6 = _mm256_sub_ps(m6, v6); - const __m256 d7 = _mm256_sub_ps(m7, v7); - - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - distances = _mm256_fmadd_ps(d2, d2, distances); - distances = _mm256_fmadd_ps(d3, d3, distances); - distances = _mm256_fmadd_ps(d4, d4, distances); - distances = _mm256_fmadd_ps(d5, d5, distances); - distances = _mm256_fmadd_ps(d6, d6, distances); - distances = _mm256_fmadd_ps(d7, d7, distances); - - // compare the new distances to the min distances - // for each of 8 AVX2 components. 
- __m256 comparison = - _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); - - // update min distances and indices with closest vectors if needed. - min_distances = _mm256_min_ps(distances, min_distances); - min_indices = _mm256_castps_si256(_mm256_blendv_ps( - _mm256_castsi256_ps(current_indices), - _mm256_castsi256_ps(min_indices), - comparison)); - - // update current indices values. Basically, +8 to each of the - // 8 AVX2 components. - current_indices = - _mm256_add_epi32(current_indices, indices_increment); - - // scroll y forward (8 vectors 8 DIM each). - y += 64; - } - - // dump values and find the minimum distance / minimum index - float min_distances_scalar[8]; - uint32_t min_indices_scalar[8]; - _mm256_storeu_ps(min_distances_scalar, min_distances); - _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); - - for (size_t j = 0; j < 8; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - // process leftovers - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y)); - y += 8; - const float distance = horizontal_sum(accu); - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - } - } - - return current_min_index; -} - -#else -size_t fvec_L2sqr_ny_nearest_D2( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 2, ny); -} - -size_t fvec_L2sqr_ny_nearest_D4( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 4, ny); -} - -size_t fvec_L2sqr_ny_nearest_D8( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 8, ny); -} -#endif - 
-size_t fvec_L2sqr_ny_nearest( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t d, - size_t ny) { - // optimized for a few special cases -#define DISPATCH(dval) \ - case dval: \ - return fvec_L2sqr_ny_nearest_D##dval(distances_tmp_buffer, x, y, ny); - - switch (d) { - DISPATCH(2) - DISPATCH(4) - DISPATCH(8) - default: - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny); - } -#undef DISPATCH -} - -#if defined(__AVX512F__) - -template -size_t fvec_L2sqr_ny_nearest_y_transposed_D( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - const size_t d_offset, - size_t ny) { - // This implementation does not use distances_tmp_buffer. - - // Current index being processed - size_t i = 0; - - // Min distance and the index of the closest vector so far - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - // Process 16 vectors per loop - const size_t ny16 = ny / 16; - - if (ny16 > 0) { - // Track min distance and the closest vector independently - // for each of 16 AVX-512 components. - __m512 min_distances = _mm512_set1_ps(HUGE_VALF); - __m512i min_indices = _mm512_set1_epi32(0); - - __m512i current_indices = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - const __m512i indices_increment = _mm512_set1_epi32(16); - - // m[i] = (2 * x[i], ... 2 * x[i]) - __m512 m[DIM]; - for (size_t j = 0; j < DIM; j++) { - m[j] = _mm512_set1_ps(x[j]); - m[j] = _mm512_add_ps(m[j], m[j]); - } - - for (; i < ny16 * 16; i += 16) { - // Compute dot products - const __m512 v0 = _mm512_loadu_ps(y + 0 * d_offset); - __m512 dp = _mm512_mul_ps(m[0], v0); - for (size_t j = 1; j < DIM; j++) { - const __m512 vj = _mm512_loadu_ps(y + j * d_offset); - dp = _mm512_fmadd_ps(m[j], vj, dp); - } - - // Compute y^2 - (2 * x, y), which is sufficient for looking for the - // lowest distance. - // x^2 is the constant that can be avoided. 
- const __m512 distances = - _mm512_sub_ps(_mm512_loadu_ps(y_sqlen), dp); - - // Compare the new distances to the min distances - __mmask16 comparison = - _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS); - - // Update min distances and indices with closest vectors if needed - min_distances = - _mm512_mask_blend_ps(comparison, distances, min_distances); - min_indices = _mm512_castps_si512(_mm512_mask_blend_ps( - comparison, - _mm512_castsi512_ps(current_indices), - _mm512_castsi512_ps(min_indices))); - - // Update current indices values. Basically, +16 to each of the 16 - // AVX-512 components. - current_indices = - _mm512_add_epi32(current_indices, indices_increment); - - // Scroll y and y_sqlen forward. - y += 16; - y_sqlen += 16; - } - - // Dump values and find the minimum distance / minimum index - float min_distances_scalar[16]; - uint32_t min_indices_scalar[16]; - _mm512_storeu_ps(min_distances_scalar, min_distances); - _mm512_storeu_si512((__m512i*)(min_indices_scalar), min_indices); - - for (size_t j = 0; j < 16; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - // Process leftovers - for (; i < ny; i++) { - float dp = 0; - for (size_t j = 0; j < DIM; j++) { - dp += x[j] * y[j * d_offset]; - } - - // Compute y^2 - 2 * (x, y), which is sufficient for looking for the - // lowest distance. - const float distance = y_sqlen[0] - 2 * dp; - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - - y += 1; - y_sqlen += 1; - } - } - - return current_min_index; -} - -#elif defined(__AVX2__) - -template -size_t fvec_L2sqr_ny_nearest_y_transposed_D( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - const size_t d_offset, - size_t ny) { - // this implementation does not use distances_tmp_buffer. 
- - // current index being processed - size_t i = 0; - - // min distance and the index of the closest vector so far - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - // process 8 vectors per loop. - const size_t ny8 = ny / 8; - - if (ny8 > 0) { - // track min distance and the closest vector independently - // for each of 8 AVX2 components. - __m256 min_distances = _mm256_set1_ps(HUGE_VALF); - __m256i min_indices = _mm256_set1_epi32(0); - - __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - const __m256i indices_increment = _mm256_set1_epi32(8); - - // m[i] = (2 * x[i], ... 2 * x[i]) - __m256 m[DIM]; - for (size_t j = 0; j < DIM; j++) { - m[j] = _mm256_set1_ps(x[j]); - m[j] = _mm256_add_ps(m[j], m[j]); - } - - for (; i < ny8 * 8; i += 8) { - // collect dim 0 for 8 D4-vectors. - const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset); - // compute dot products - __m256 dp = _mm256_mul_ps(m[0], v0); - - for (size_t j = 1; j < DIM; j++) { - // collect dim j for 8 D4-vectors. - const __m256 vj = _mm256_loadu_ps(y + j * d_offset); - dp = _mm256_fmadd_ps(m[j], vj, dp); - } - - // compute y^2 - (2 * x, y), which is sufficient for looking for the - // lowest distance. - // x^2 is the constant that can be avoided. - const __m256 distances = - _mm256_sub_ps(_mm256_loadu_ps(y_sqlen), dp); - - // compare the new distances to the min distances - // for each of 8 AVX2 components. - const __m256 comparison = - _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); - - // update min distances and indices with closest vectors if needed. - min_distances = - _mm256_blendv_ps(distances, min_distances, comparison); - min_indices = _mm256_castps_si256(_mm256_blendv_ps( - _mm256_castsi256_ps(current_indices), - _mm256_castsi256_ps(min_indices), - comparison)); - - // update current indices values. Basically, +8 to each of the - // 8 AVX2 components. 
- current_indices = - _mm256_add_epi32(current_indices, indices_increment); - - // scroll y and y_sqlen forward. - y += 8; - y_sqlen += 8; - } - - // dump values and find the minimum distance / minimum index - float min_distances_scalar[8]; - uint32_t min_indices_scalar[8]; - _mm256_storeu_ps(min_distances_scalar, min_distances); - _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); - - for (size_t j = 0; j < 8; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - // process leftovers - for (; i < ny; i++) { - float dp = 0; - for (size_t j = 0; j < DIM; j++) { - dp += x[j] * y[j * d_offset]; - } - - // compute y^2 - 2 * (x, y), which is sufficient for looking for the - // lowest distance. - const float distance = y_sqlen[0] - 2 * dp; - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - - y += 1; - y_sqlen += 1; - } - } - - return current_min_index; -} - -#endif - -size_t fvec_L2sqr_ny_nearest_y_transposed( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - // optimized for a few special cases -#ifdef __AVX2__ -#define DISPATCH(dval) \ - case dval: \ - return fvec_L2sqr_ny_nearest_y_transposed_D( \ - distances_tmp_buffer, x, y, y_sqlen, d_offset, ny); - - switch (d) { - DISPATCH(1) - DISPATCH(2) - DISPATCH(4) - DISPATCH(8) - default: - return fvec_L2sqr_ny_nearest_y_transposed_ref( - distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); - } -#undef DISPATCH -#else - // non-AVX2 case - return fvec_L2sqr_ny_nearest_y_transposed_ref( - distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); -#endif -} - -#endif - -#ifdef USE_AVX - -float fvec_L1(const float* x, const float* y, size_t d) { - __m256 msum1 = _mm256_setzero_ps(); - // signmask used for absolute value - __m256 
signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL)); - - while (d >= 8) { - __m256 mx = _mm256_loadu_ps(x); - x += 8; - __m256 my = _mm256_loadu_ps(y); - y += 8; - // subtract - const __m256 a_m_b = _mm256_sub_ps(mx, my); - // find sum of absolute value of distances (manhattan distance) - msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b)); - d -= 8; - } - - __m128 msum2 = _mm256_extractf128_ps(msum1, 1); - msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0)); - __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL)); - - if (d >= 4) { - __m128 mx = _mm_loadu_ps(x); - x += 4; - __m128 my = _mm_loadu_ps(y); - y += 4; - const __m128 a_m_b = _mm_sub_ps(mx, my); - msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b)); - d -= 4; - } - - if (d > 0) { - __m128 mx = masked_read(d, x); - __m128 my = masked_read(d, y); - __m128 a_m_b = _mm_sub_ps(mx, my); - msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b)); - } - - msum2 = _mm_hadd_ps(msum2, msum2); - msum2 = _mm_hadd_ps(msum2, msum2); - return _mm_cvtss_f32(msum2); -} - -float fvec_Linf(const float* x, const float* y, size_t d) { - __m256 msum1 = _mm256_setzero_ps(); - // signmask used for absolute value - __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL)); - - while (d >= 8) { - __m256 mx = _mm256_loadu_ps(x); - x += 8; - __m256 my = _mm256_loadu_ps(y); - y += 8; - // subtract - const __m256 a_m_b = _mm256_sub_ps(mx, my); - // find max of absolute value of distances (chebyshev distance) - msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b)); - d -= 8; - } - - __m128 msum2 = _mm256_extractf128_ps(msum1, 1); - msum2 = _mm_max_ps(msum2, _mm256_extractf128_ps(msum1, 0)); - __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL)); - - if (d >= 4) { - __m128 mx = _mm_loadu_ps(x); - x += 4; - __m128 my = _mm_loadu_ps(y); - y += 4; - const __m128 a_m_b = _mm_sub_ps(mx, my); - msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); - d -= 4; - } - - 
if (d > 0) { - __m128 mx = masked_read(d, x); - __m128 my = masked_read(d, y); - __m128 a_m_b = _mm_sub_ps(mx, my); - msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); - } - - msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2); - msum2 = _mm_max_ps(msum2, _mm_shuffle_ps(msum2, msum2, 1)); - return _mm_cvtss_f32(msum2); -} - -#elif defined(__SSE3__) // But not AVX - -float fvec_L1(const float* x, const float* y, size_t d) { - return fvec_L1_ref(x, y, d); -} - -float fvec_Linf(const float* x, const float* y, size_t d) { - return fvec_Linf_ref(x, y, d); -} - -#elif defined(__ARM_FEATURE_SVE) - -struct ElementOpIP { - static svfloat32_t op(svbool_t pg, svfloat32_t x, svfloat32_t y) { - return svmul_f32_x(pg, x, y); - } - static svfloat32_t merge( - svbool_t pg, - svfloat32_t z, - svfloat32_t x, - svfloat32_t y) { - return svmla_f32_x(pg, z, x, y); - } -}; - -template -void fvec_op_ny_sve_d1(float* dis, const float* x, const float* y, size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svdup_n_f32(x[0]); - size_t i = 0; - for (; i + lanes4 < ny; i += lanes4) { - svfloat32_t y0 = svld1_f32(pg, y); - svfloat32_t y1 = svld1_f32(pg, y + lanes); - svfloat32_t y2 = svld1_f32(pg, y + lanes2); - svfloat32_t y3 = svld1_f32(pg, y + lanes3); - y0 = ElementOp::op(pg, x0, y0); - y1 = ElementOp::op(pg, x0, y1); - y2 = ElementOp::op(pg, x0, y2); - y3 = ElementOp::op(pg, x0, y3); - svst1_f32(pg, dis, y0); - svst1_f32(pg, dis + lanes, y1); - svst1_f32(pg, dis + lanes2, y2); - svst1_f32(pg, dis + lanes3, y3); - y += lanes4; - dis += lanes4; - } - const svbool_t pg0 = svwhilelt_b32_u64(i, ny); - const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny); - const svbool_t pg2 = svwhilelt_b32_u64(i + lanes2, ny); - const svbool_t pg3 = svwhilelt_b32_u64(i + lanes3, ny); - svfloat32_t y0 = svld1_f32(pg0, y); - svfloat32_t y1 = 
svld1_f32(pg1, y + lanes); - svfloat32_t y2 = svld1_f32(pg2, y + lanes2); - svfloat32_t y3 = svld1_f32(pg3, y + lanes3); - y0 = ElementOp::op(pg0, x0, y0); - y1 = ElementOp::op(pg1, x0, y1); - y2 = ElementOp::op(pg2, x0, y2); - y3 = ElementOp::op(pg3, x0, y3); - svst1_f32(pg0, dis, y0); - svst1_f32(pg1, dis + lanes, y1); - svst1_f32(pg2, dis + lanes2, y2); - svst1_f32(pg3, dis + lanes3, y3); -} - -template -void fvec_op_ny_sve_d2(float* dis, const float* x, const float* y, size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svdup_n_f32(x[0]); - const svfloat32_t x1 = svdup_n_f32(x[1]); - size_t i = 0; - for (; i + lanes2 < ny; i += lanes2) { - const svfloat32x2_t y0 = svld2_f32(pg, y); - const svfloat32x2_t y1 = svld2_f32(pg, y + lanes2); - svfloat32_t y00 = svget2_f32(y0, 0); - const svfloat32_t y01 = svget2_f32(y0, 1); - svfloat32_t y10 = svget2_f32(y1, 0); - const svfloat32_t y11 = svget2_f32(y1, 1); - y00 = ElementOp::op(pg, x0, y00); - y10 = ElementOp::op(pg, x0, y10); - y00 = ElementOp::merge(pg, y00, x1, y01); - y10 = ElementOp::merge(pg, y10, x1, y11); - svst1_f32(pg, dis, y00); - svst1_f32(pg, dis + lanes, y10); - y += lanes4; - dis += lanes2; - } - const svbool_t pg0 = svwhilelt_b32_u64(i, ny); - const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny); - const svfloat32x2_t y0 = svld2_f32(pg0, y); - const svfloat32x2_t y1 = svld2_f32(pg1, y + lanes2); - svfloat32_t y00 = svget2_f32(y0, 0); - const svfloat32_t y01 = svget2_f32(y0, 1); - svfloat32_t y10 = svget2_f32(y1, 0); - const svfloat32_t y11 = svget2_f32(y1, 1); - y00 = ElementOp::op(pg0, x0, y00); - y10 = ElementOp::op(pg1, x0, y10); - y00 = ElementOp::merge(pg0, y00, x1, y01); - y10 = ElementOp::merge(pg1, y10, x1, y11); - svst1_f32(pg0, dis, y00); - svst1_f32(pg1, dis + lanes, y10); -} - -template -void fvec_op_ny_sve_d4(float* dis, const float* x, const float* y, size_t ny) { 
- const size_t lanes = svcntw(); - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svdup_n_f32(x[0]); - const svfloat32_t x1 = svdup_n_f32(x[1]); - const svfloat32_t x2 = svdup_n_f32(x[2]); - const svfloat32_t x3 = svdup_n_f32(x[3]); - size_t i = 0; - for (; i + lanes < ny; i += lanes) { - const svfloat32x4_t y0 = svld4_f32(pg, y); - svfloat32_t y00 = svget4_f32(y0, 0); - const svfloat32_t y01 = svget4_f32(y0, 1); - svfloat32_t y02 = svget4_f32(y0, 2); - const svfloat32_t y03 = svget4_f32(y0, 3); - y00 = ElementOp::op(pg, x0, y00); - y02 = ElementOp::op(pg, x2, y02); - y00 = ElementOp::merge(pg, y00, x1, y01); - y02 = ElementOp::merge(pg, y02, x3, y03); - y00 = svadd_f32_x(pg, y00, y02); - svst1_f32(pg, dis, y00); - y += lanes4; - dis += lanes; - } - const svbool_t pg0 = svwhilelt_b32_u64(i, ny); - const svfloat32x4_t y0 = svld4_f32(pg0, y); - svfloat32_t y00 = svget4_f32(y0, 0); - const svfloat32_t y01 = svget4_f32(y0, 1); - svfloat32_t y02 = svget4_f32(y0, 2); - const svfloat32_t y03 = svget4_f32(y0, 3); - y00 = ElementOp::op(pg0, x0, y00); - y02 = ElementOp::op(pg0, x2, y02); - y00 = ElementOp::merge(pg0, y00, x1, y01); - y02 = ElementOp::merge(pg0, y02, x3, y03); - y00 = svadd_f32_x(pg0, y00, y02); - svst1_f32(pg0, dis, y00); -} - -template -void fvec_op_ny_sve_d8(float* dis, const float* x, const float* y, size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes4 = lanes * 4; - const size_t lanes8 = lanes * 8; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svdup_n_f32(x[0]); - const svfloat32_t x1 = svdup_n_f32(x[1]); - const svfloat32_t x2 = svdup_n_f32(x[2]); - const svfloat32_t x3 = svdup_n_f32(x[3]); - const svfloat32_t x4 = svdup_n_f32(x[4]); - const svfloat32_t x5 = svdup_n_f32(x[5]); - const svfloat32_t x6 = svdup_n_f32(x[6]); - const svfloat32_t x7 = svdup_n_f32(x[7]); - size_t i = 0; - for (; i + lanes < ny; i += lanes) { - const svfloat32x4_t ya = svld4_f32(pg, y); - const 
svfloat32x4_t yb = svld4_f32(pg, y + lanes4); - const svfloat32_t ya0 = svget4_f32(ya, 0); - const svfloat32_t ya1 = svget4_f32(ya, 1); - const svfloat32_t ya2 = svget4_f32(ya, 2); - const svfloat32_t ya3 = svget4_f32(ya, 3); - const svfloat32_t yb0 = svget4_f32(yb, 0); - const svfloat32_t yb1 = svget4_f32(yb, 1); - const svfloat32_t yb2 = svget4_f32(yb, 2); - const svfloat32_t yb3 = svget4_f32(yb, 3); - svfloat32_t y0 = svuzp1(ya0, yb0); - const svfloat32_t y1 = svuzp1(ya1, yb1); - svfloat32_t y2 = svuzp1(ya2, yb2); - const svfloat32_t y3 = svuzp1(ya3, yb3); - svfloat32_t y4 = svuzp2(ya0, yb0); - const svfloat32_t y5 = svuzp2(ya1, yb1); - svfloat32_t y6 = svuzp2(ya2, yb2); - const svfloat32_t y7 = svuzp2(ya3, yb3); - y0 = ElementOp::op(pg, x0, y0); - y2 = ElementOp::op(pg, x2, y2); - y4 = ElementOp::op(pg, x4, y4); - y6 = ElementOp::op(pg, x6, y6); - y0 = ElementOp::merge(pg, y0, x1, y1); - y2 = ElementOp::merge(pg, y2, x3, y3); - y4 = ElementOp::merge(pg, y4, x5, y5); - y6 = ElementOp::merge(pg, y6, x7, y7); - y0 = svadd_f32_x(pg, y0, y2); - y4 = svadd_f32_x(pg, y4, y6); - y0 = svadd_f32_x(pg, y0, y4); - svst1_f32(pg, dis, y0); - y += lanes8; - dis += lanes; - } - const svbool_t pg0 = svwhilelt_b32_u64(i, ny); - const svbool_t pga = svwhilelt_b32_u64(i * 2, ny * 2); - const svbool_t pgb = svwhilelt_b32_u64(i * 2 + lanes, ny * 2); - const svfloat32x4_t ya = svld4_f32(pga, y); - const svfloat32x4_t yb = svld4_f32(pgb, y + lanes4); - const svfloat32_t ya0 = svget4_f32(ya, 0); - const svfloat32_t ya1 = svget4_f32(ya, 1); - const svfloat32_t ya2 = svget4_f32(ya, 2); - const svfloat32_t ya3 = svget4_f32(ya, 3); - const svfloat32_t yb0 = svget4_f32(yb, 0); - const svfloat32_t yb1 = svget4_f32(yb, 1); - const svfloat32_t yb2 = svget4_f32(yb, 2); - const svfloat32_t yb3 = svget4_f32(yb, 3); - svfloat32_t y0 = svuzp1(ya0, yb0); - const svfloat32_t y1 = svuzp1(ya1, yb1); - svfloat32_t y2 = svuzp1(ya2, yb2); - const svfloat32_t y3 = svuzp1(ya3, yb3); - svfloat32_t y4 = 
svuzp2(ya0, yb0); - const svfloat32_t y5 = svuzp2(ya1, yb1); - svfloat32_t y6 = svuzp2(ya2, yb2); - const svfloat32_t y7 = svuzp2(ya3, yb3); - y0 = ElementOp::op(pg0, x0, y0); - y2 = ElementOp::op(pg0, x2, y2); - y4 = ElementOp::op(pg0, x4, y4); - y6 = ElementOp::op(pg0, x6, y6); - y0 = ElementOp::merge(pg0, y0, x1, y1); - y2 = ElementOp::merge(pg0, y2, x3, y3); - y4 = ElementOp::merge(pg0, y4, x5, y5); - y6 = ElementOp::merge(pg0, y6, x7, y7); - y0 = svadd_f32_x(pg0, y0, y2); - y4 = svadd_f32_x(pg0, y4, y6); - y0 = svadd_f32_x(pg0, y0, y4); - svst1_f32(pg0, dis, y0); - y += lanes8; - dis += lanes; -} - -template -void fvec_op_ny_sve_lanes1( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svld1_f32(pg, x); - size_t i = 0; - for (; i + 3 < ny; i += 4) { - svfloat32_t y0 = svld1_f32(pg, y); - svfloat32_t y1 = svld1_f32(pg, y + lanes); - svfloat32_t y2 = svld1_f32(pg, y + lanes2); - svfloat32_t y3 = svld1_f32(pg, y + lanes3); - y += lanes4; - y0 = ElementOp::op(pg, x0, y0); - y1 = ElementOp::op(pg, x0, y1); - y2 = ElementOp::op(pg, x0, y2); - y3 = ElementOp::op(pg, x0, y3); - dis[i] = svaddv_f32(pg, y0); - dis[i + 1] = svaddv_f32(pg, y1); - dis[i + 2] = svaddv_f32(pg, y2); - dis[i + 3] = svaddv_f32(pg, y3); - } - for (; i < ny; ++i) { - svfloat32_t y0 = svld1_f32(pg, y); - y += lanes; - y0 = ElementOp::op(pg, x0, y0); - dis[i] = svaddv_f32(pg, y0); - } -} - -template -void fvec_op_ny_sve_lanes2( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svld1_f32(pg, x); - const svfloat32_t x1 = svld1_f32(pg, x + lanes); - size_t i = 0; - for (; i + 1 < 
ny; i += 2) { - svfloat32_t y00 = svld1_f32(pg, y); - const svfloat32_t y01 = svld1_f32(pg, y + lanes); - svfloat32_t y10 = svld1_f32(pg, y + lanes2); - const svfloat32_t y11 = svld1_f32(pg, y + lanes3); - y += lanes4; - y00 = ElementOp::op(pg, x0, y00); - y10 = ElementOp::op(pg, x0, y10); - y00 = ElementOp::merge(pg, y00, x1, y01); - y10 = ElementOp::merge(pg, y10, x1, y11); - dis[i] = svaddv_f32(pg, y00); - dis[i + 1] = svaddv_f32(pg, y10); - } - if (i < ny) { - svfloat32_t y0 = svld1_f32(pg, y); - const svfloat32_t y1 = svld1_f32(pg, y + lanes); - y0 = ElementOp::op(pg, x0, y0); - y0 = ElementOp::merge(pg, y0, x1, y1); - dis[i] = svaddv_f32(pg, y0); - } -} - -template -void fvec_op_ny_sve_lanes3( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svld1_f32(pg, x); - const svfloat32_t x1 = svld1_f32(pg, x + lanes); - const svfloat32_t x2 = svld1_f32(pg, x + lanes2); - for (size_t i = 0; i < ny; ++i) { - svfloat32_t y0 = svld1_f32(pg, y); - const svfloat32_t y1 = svld1_f32(pg, y + lanes); - svfloat32_t y2 = svld1_f32(pg, y + lanes2); - y += lanes3; - y0 = ElementOp::op(pg, x0, y0); - y0 = ElementOp::merge(pg, y0, x1, y1); - y0 = ElementOp::merge(pg, y0, x2, y2); - dis[i] = svaddv_f32(pg, y0); - } -} - -template -void fvec_op_ny_sve_lanes4( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svld1_f32(pg, x); - const svfloat32_t x1 = svld1_f32(pg, x + lanes); - const svfloat32_t x2 = svld1_f32(pg, x + lanes2); - const svfloat32_t x3 = svld1_f32(pg, x + lanes3); - for (size_t i = 0; i < ny; ++i) { - svfloat32_t y0 = svld1_f32(pg, y); - const svfloat32_t y1 = svld1_f32(pg, y + 
lanes); - svfloat32_t y2 = svld1_f32(pg, y + lanes2); - const svfloat32_t y3 = svld1_f32(pg, y + lanes3); - y += lanes4; - y0 = ElementOp::op(pg, x0, y0); - y2 = ElementOp::op(pg, x2, y2); - y0 = ElementOp::merge(pg, y0, x1, y1); - y2 = ElementOp::merge(pg, y2, x3, y3); - y0 = svadd_f32_x(pg, y0, y2); - dis[i] = svaddv_f32(pg, y0); - } -} - -void fvec_L2sqr_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - fvec_L2sqr_ny_ref(dis, x, y, d, ny); -} - -void fvec_L2sqr_ny_transposed( - float* dis, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny); -} - -size_t fvec_L2sqr_ny_nearest( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t d, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny); -} - -size_t fvec_L2sqr_ny_nearest_y_transposed( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_nearest_y_transposed_ref( - distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); -} - -float fvec_L1(const float* x, const float* y, size_t d) { - return fvec_L1_ref(x, y, d); -} - -float fvec_Linf(const float* x, const float* y, size_t d) { - return fvec_Linf_ref(x, y, d); -} - -void fvec_inner_products_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - const size_t lanes = svcntw(); - switch (d) { - case 1: - fvec_op_ny_sve_d1(dis, x, y, ny); - break; - case 2: - fvec_op_ny_sve_d2(dis, x, y, ny); - break; - case 4: - fvec_op_ny_sve_d4(dis, x, y, ny); - break; - case 8: - fvec_op_ny_sve_d8(dis, x, y, ny); - break; - default: - if (d == lanes) - fvec_op_ny_sve_lanes1(dis, x, y, ny); - else if (d == lanes * 2) - fvec_op_ny_sve_lanes2(dis, x, y, ny); - else if (d == lanes * 3) - fvec_op_ny_sve_lanes3(dis, x, y, ny); - else if (d 
== lanes * 4) - fvec_op_ny_sve_lanes4(dis, x, y, ny); - else - fvec_inner_products_ny_ref(dis, x, y, d, ny); - break; - } -} - -#elif defined(__aarch64__) - -// not optimized for ARM -void fvec_L2sqr_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - fvec_L2sqr_ny_ref(dis, x, y, d, ny); -} - -void fvec_L2sqr_ny_transposed( - float* dis, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny); -} - -size_t fvec_L2sqr_ny_nearest( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t d, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny); -} - -size_t fvec_L2sqr_ny_nearest_y_transposed( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_nearest_y_transposed_ref( - distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); -} - -float fvec_L1(const float* x, const float* y, size_t d) { - return fvec_L1_ref(x, y, d); -} - -float fvec_Linf(const float* x, const float* y, size_t d) { - return fvec_Linf_ref(x, y, d); -} - -void fvec_inner_products_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - fvec_inner_products_ny_ref(dis, x, y, d, ny); -} - -#else -// scalar implementation - -float fvec_L1(const float* x, const float* y, size_t d) { - return fvec_L1_ref(x, y, d); -} - -float fvec_Linf(const float* x, const float* y, size_t d) { - return fvec_Linf_ref(x, y, d); -} - -void fvec_L2sqr_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - fvec_L2sqr_ny_ref(dis, x, y, d, ny); -} - -void fvec_L2sqr_ny_transposed( - float* dis, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, 
d_offset, ny); -} - -size_t fvec_L2sqr_ny_nearest( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t d, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny); -} - -size_t fvec_L2sqr_ny_nearest_y_transposed( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_nearest_y_transposed_ref( - distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); -} - -void fvec_inner_products_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - fvec_inner_products_ny_ref(dis, x, y, d, ny); -} - -#endif - -/*************************************************************************** - * heavily optimized table computations - ***************************************************************************/ - -[[maybe_unused]] static inline void fvec_madd_ref( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - for (size_t i = 0; i < n; i++) { - c[i] = a[i] + bf * b[i]; - } -} - -#if defined(__AVX512F__) - -static inline void fvec_madd_avx512( - const size_t n, - const float* __restrict a, - const float bf, - const float* __restrict b, - float* __restrict c) { - const size_t n16 = n / 16; - const size_t n_for_masking = n % 16; - - const __m512 bfmm = _mm512_set1_ps(bf); - - size_t idx = 0; - for (idx = 0; idx < n16 * 16; idx += 16) { - const __m512 ax = _mm512_loadu_ps(a + idx); - const __m512 bx = _mm512_loadu_ps(b + idx); - const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); - _mm512_storeu_ps(c + idx, abmul); - } - - if (n_for_masking > 0) { - const __mmask16 mask = (1 << n_for_masking) - 1; - - const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx); - const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx); - const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); - _mm512_mask_storeu_ps(c + idx, mask, abmul); - } -} - -#elif defined(__AVX2__) - -static inline void fvec_madd_avx2( 
- const size_t n, - const float* __restrict a, - const float bf, - const float* __restrict b, - float* __restrict c) { - // - const size_t n8 = n / 8; - const size_t n_for_masking = n % 8; - - const __m256 bfmm = _mm256_set1_ps(bf); - - size_t idx = 0; - for (idx = 0; idx < n8 * 8; idx += 8) { - const __m256 ax = _mm256_loadu_ps(a + idx); - const __m256 bx = _mm256_loadu_ps(b + idx); - const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); - _mm256_storeu_ps(c + idx, abmul); - } - - if (n_for_masking > 0) { - __m256i mask; - switch (n_for_masking) { - case 1: - mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1); - break; - case 2: - mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1); - break; - case 3: - mask = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1); - break; - case 4: - mask = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1); - break; - case 5: - mask = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1); - break; - case 6: - mask = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1); - break; - case 7: - mask = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1); - break; - } - - const __m256 ax = _mm256_maskload_ps(a + idx, mask); - const __m256 bx = _mm256_maskload_ps(b + idx, mask); - const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); - _mm256_maskstore_ps(c + idx, mask, abmul); - } -} - -#endif - -#ifdef __SSE3__ - -[[maybe_unused]] static inline void fvec_madd_sse( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - n >>= 2; - __m128 bf4 = _mm_set_ps1(bf); - __m128* a4 = (__m128*)a; - __m128* b4 = (__m128*)b; - __m128* c4 = (__m128*)c; - - while (n--) { - *c4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4)); - b4++; - a4++; - c4++; - } -} - -void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) { -#ifdef __AVX512F__ - fvec_madd_avx512(n, a, bf, b, c); -#elif __AVX2__ - fvec_madd_avx2(n, a, bf, b, c); -#else - if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0) - fvec_madd_sse(n, a, bf, b, c); - else - fvec_madd_ref(n, a, 
bf, b, c); -#endif -} - -#elif defined(__ARM_FEATURE_SVE) - -void fvec_madd( - const size_t n, - const float* __restrict a, - const float bf, - const float* __restrict b, - float* __restrict c) { - const size_t lanes = static_cast(svcntw()); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const size_t lanes4 = lanes * 4; - size_t i = 0; - for (; i + lanes4 < n; i += lanes4) { - const auto mask = svptrue_b32(); - const auto ai0 = svld1_f32(mask, a + i); - const auto ai1 = svld1_f32(mask, a + i + lanes); - const auto ai2 = svld1_f32(mask, a + i + lanes2); - const auto ai3 = svld1_f32(mask, a + i + lanes3); - const auto bi0 = svld1_f32(mask, b + i); - const auto bi1 = svld1_f32(mask, b + i + lanes); - const auto bi2 = svld1_f32(mask, b + i + lanes2); - const auto bi3 = svld1_f32(mask, b + i + lanes3); - const auto ci0 = svmla_n_f32_x(mask, ai0, bi0, bf); - const auto ci1 = svmla_n_f32_x(mask, ai1, bi1, bf); - const auto ci2 = svmla_n_f32_x(mask, ai2, bi2, bf); - const auto ci3 = svmla_n_f32_x(mask, ai3, bi3, bf); - svst1_f32(mask, c + i, ci0); - svst1_f32(mask, c + i + lanes, ci1); - svst1_f32(mask, c + i + lanes2, ci2); - svst1_f32(mask, c + i + lanes3, ci3); - } - const auto mask0 = svwhilelt_b32_u64(i, n); - const auto mask1 = svwhilelt_b32_u64(i + lanes, n); - const auto mask2 = svwhilelt_b32_u64(i + lanes2, n); - const auto mask3 = svwhilelt_b32_u64(i + lanes3, n); - const auto ai0 = svld1_f32(mask0, a + i); - const auto ai1 = svld1_f32(mask1, a + i + lanes); - const auto ai2 = svld1_f32(mask2, a + i + lanes2); - const auto ai3 = svld1_f32(mask3, a + i + lanes3); - const auto bi0 = svld1_f32(mask0, b + i); - const auto bi1 = svld1_f32(mask1, b + i + lanes); - const auto bi2 = svld1_f32(mask2, b + i + lanes2); - const auto bi3 = svld1_f32(mask3, b + i + lanes3); - const auto ci0 = svmla_n_f32_x(mask0, ai0, bi0, bf); - const auto ci1 = svmla_n_f32_x(mask1, ai1, bi1, bf); - const auto ci2 = svmla_n_f32_x(mask2, ai2, bi2, bf); - const auto ci3 
= svmla_n_f32_x(mask3, ai3, bi3, bf); - svst1_f32(mask0, c + i, ci0); - svst1_f32(mask1, c + i + lanes, ci1); - svst1_f32(mask2, c + i + lanes2, ci2); - svst1_f32(mask3, c + i + lanes3, ci3); -} - -#elif defined(__aarch64__) - -void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) { - const size_t n_simd = n - (n & 3); - const float32x4_t bfv = vdupq_n_f32(bf); - size_t i; - for (i = 0; i < n_simd; i += 4) { - const float32x4_t ai = vld1q_f32(a + i); - const float32x4_t bi = vld1q_f32(b + i); - const float32x4_t ci = vfmaq_f32(ai, bfv, bi); - vst1q_f32(c + i, ci); - } - for (; i < n; ++i) - c[i] = a[i] + bf * b[i]; -} - -#else - -void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) { - fvec_madd_ref(n, a, bf, b, c); -} - -#endif - -static inline int fvec_madd_and_argmin_ref( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - float vmin = 1e20; - int imin = -1; - - for (size_t i = 0; i < n; i++) { - c[i] = a[i] + bf * b[i]; - if (c[i] < vmin) { - vmin = c[i]; - imin = i; - } - } - return imin; -} - -#ifdef __SSE3__ - -static inline int fvec_madd_and_argmin_sse( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - n >>= 2; - __m128 bf4 = _mm_set_ps1(bf); - __m128 vmin4 = _mm_set_ps1(1e20); - __m128i imin4 = _mm_set1_epi32(-1); - __m128i idx4 = _mm_set_epi32(3, 2, 1, 0); - __m128i inc4 = _mm_set1_epi32(4); - __m128* a4 = (__m128*)a; - __m128* b4 = (__m128*)b; - __m128* c4 = (__m128*)c; - - while (n--) { - __m128 vc4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4)); - *c4 = vc4; - __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); - // imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower! 
- - imin4 = _mm_or_si128( - _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); - vmin4 = _mm_min_ps(vmin4, vc4); - b4++; - a4++; - c4++; - idx4 = _mm_add_epi32(idx4, inc4); - } - - // 4 values -> 2 - { - idx4 = _mm_shuffle_epi32(imin4, 3 << 2 | 2); - __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 3 << 2 | 2); - __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); - imin4 = _mm_or_si128( - _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); - vmin4 = _mm_min_ps(vmin4, vc4); - } - // 2 values -> 1 - { - idx4 = _mm_shuffle_epi32(imin4, 1); - __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 1); - __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); - imin4 = _mm_or_si128( - _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); - // vmin4 = _mm_min_ps (vmin4, vc4); - } - return _mm_cvtsi128_si32(imin4); -} - -int fvec_madd_and_argmin( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0) { - return fvec_madd_and_argmin_sse(n, a, bf, b, c); - } else { - return fvec_madd_and_argmin_ref(n, a, bf, b, c); - } -} - -#elif defined(__aarch64__) - -int fvec_madd_and_argmin( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - float32x4_t vminv = vdupq_n_f32(1e20); - uint32x4_t iminv = vdupq_n_u32(static_cast(-1)); - size_t i; - { - const size_t n_simd = n - (n & 3); - const uint32_t iota[] = {0, 1, 2, 3}; - uint32x4_t iv = vld1q_u32(iota); - const uint32x4_t incv = vdupq_n_u32(4); - const float32x4_t bfv = vdupq_n_f32(bf); - for (i = 0; i < n_simd; i += 4) { - const float32x4_t ai = vld1q_f32(a + i); - const float32x4_t bi = vld1q_f32(b + i); - const float32x4_t ci = vfmaq_f32(ai, bfv, bi); - vst1q_f32(c + i, ci); - const uint32x4_t less_than = vcltq_f32(ci, vminv); - vminv = vminq_f32(ci, vminv); - iminv = vorrq_u32( - vandq_u32(less_than, iv), - vandq_u32(vmvnq_u32(less_than), iminv)); - iv = vaddq_u32(iv, incv); - } - } - float vmin = 
vminvq_f32(vminv); - uint32_t imin; - { - const float32x4_t vminy = vdupq_n_f32(vmin); - const uint32x4_t equals = vceqq_f32(vminv, vminy); - imin = vminvq_u32(vorrq_u32( - vandq_u32(equals, iminv), - vandq_u32( - vmvnq_u32(equals), - vdupq_n_u32(std::numeric_limits::max())))); - } - for (; i < n; ++i) { - c[i] = a[i] + bf * b[i]; - if (c[i] < vmin) { - vmin = c[i]; - imin = static_cast(i); - } - } - return static_cast(imin); -} - -#else +void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) { + DISPATCH_SIMDLevel(fvec_madd, n, a, bf, b, c); +} int fvec_madd_and_argmin( size_t n, @@ -3603,11 +304,9 @@ int fvec_madd_and_argmin( float bf, const float* b, float* c) { - return fvec_madd_and_argmin_ref(n, a, bf, b, c); + DISPATCH_SIMDLevel(fvec_madd_and_argmin, n, a, bf, b, c); } -#endif - /*************************************************************************** * PQ tables computations ***************************************************************************/ diff --git a/faiss/utils/extra_distances-inl.h b/faiss/utils/extra_distances-inl.h index 4fcab576fd..572f87b980 100644 --- a/faiss/utils/extra_distances-inl.h +++ b/faiss/utils/extra_distances-inl.h @@ -60,13 +60,6 @@ inline float VectorDistance::operator()( const float* x, const float* y) const { return fvec_Linf(x, y, d); - /* - float vmax = 0; - for (size_t i = 0; i < d; i++) { - float diff = fabs (x[i] - y[i]); - if (diff > vmax) vmax = diff; - } - return vmax;*/ } template <> diff --git a/faiss/utils/hamming_distance/generic-inl.h b/faiss/utils/hamming_distance/generic-inl.h index b8e7b42c9c..f7abd52390 100644 --- a/faiss/utils/hamming_distance/generic-inl.h +++ b/faiss/utils/hamming_distance/generic-inl.h @@ -312,7 +312,6 @@ struct HammingComputerDefault { const uint8_t* a = a8 + 8 * quotient8; const uint8_t* b = b8 + 8 * quotient8; switch (remainder8) { - [[fallthrough]]; case 7: accu += hamdis_tab_ham_bytes[a[6] ^ b[6]]; [[fallthrough]]; diff --git 
a/faiss/utils/simd_impl/distances_aarch64.cpp b/faiss/utils/simd_impl/distances_aarch64.cpp new file mode 100644 index 0000000000..33ad9bbc4f --- /dev/null +++ b/faiss/utils/simd_impl/distances_aarch64.cpp @@ -0,0 +1,137 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#define AUTOVEC_LEVEL SIMDLevel::ARM_NEON +#include + +namespace faiss { + +template <> +void fvec_madd( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + const size_t n_simd = n - (n & 3); + const float32x4_t bfv = vdupq_n_f32(bf); + size_t i; + for (i = 0; i < n_simd; i += 4) { + const float32x4_t ai = vld1q_f32(a + i); + const float32x4_t bi = vld1q_f32(b + i); + const float32x4_t ci = vfmaq_f32(ai, bfv, bi); + vst1q_f32(c + i, ci); + } + for (; i < n; ++i) + c[i] = a[i] + bf * b[i]; +} + +template <> +void fvec_L2sqr_ny_transposed( + float* dis, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny); + +template <> +void fvec_inner_products_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_inner_products_ny(dis, x, y, d, ny); +} + +template <> +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny(dis, x, y, d, ny); +} + +template <> +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny_nearest(distances_tmp_buffer, x, y, d, ny); +} + +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + return fvec_L2sqr_ny_nearest_y_transposed_ref( + distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); +} + +template <> +int fvec_madd_and_argmin( + size_t n, + const 
float* a, + float bf, + const float* b, + float* c) { + float32x4_t vminv = vdupq_n_f32(1e20); + uint32x4_t iminv = vdupq_n_u32(static_cast(-1)); + size_t i; + { + const size_t n_simd = n - (n & 3); + const uint32_t iota[] = {0, 1, 2, 3}; + uint32x4_t iv = vld1q_u32(iota); + const uint32x4_t incv = vdupq_n_u32(4); + const float32x4_t bfv = vdupq_n_f32(bf); + for (i = 0; i < n_simd; i += 4) { + const float32x4_t ai = vld1q_f32(a + i); + const float32x4_t bi = vld1q_f32(b + i); + const float32x4_t ci = vfmaq_f32(ai, bfv, bi); + vst1q_f32(c + i, ci); + const uint32x4_t less_than = vcltq_f32(ci, vminv); + vminv = vminq_f32(ci, vminv); + iminv = vorrq_u32( + vandq_u32(less_than, iv), + vandq_u32(vmvnq_u32(less_than), iminv)); + iv = vaddq_u32(iv, incv); + } + } + float vmin = vminvq_f32(vminv); + uint32_t imin; + { + const float32x4_t vminy = vdupq_n_f32(vmin); + const uint32x4_t equals = vceqq_f32(vminv, vminy); + imin = vminvq_u32(vorrq_u32( + vandq_u32(equals, iminv), + vandq_u32( + vmvnq_u32(equals), + vdupq_n_u32(std::numeric_limits::max())))); + } + for (; i < n; ++i) { + c[i] = a[i] + bf * b[i]; + if (c[i] < vmin) { + vmin = c[i]; + imin = static_cast(i); + } + } + return static_cast(imin); +} + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_arm_sve.cpp b/faiss/utils/simd_impl/distances_arm_sve.cpp new file mode 100644 index 0000000000..3bd4227da0 --- /dev/null +++ b/faiss/utils/simd_impl/distances_arm_sve.cpp @@ -0,0 +1,496 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#define AUTOVEC_LEVEL SIMDLevel::ARM_SVE +#include + +namespace faiss { + +template <> +void fvec_madd( + const size_t n, + const float* __restrict a, + const float bf, + const float* __restrict b, + float* __restrict c) { + const size_t lanes = static_cast(svcntw()); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + size_t i = 0; + for (; i + lanes4 < n; i += lanes4) { + const auto mask = svptrue_b32(); + const auto ai0 = svld1_f32(mask, a + i); + const auto ai1 = svld1_f32(mask, a + i + lanes); + const auto ai2 = svld1_f32(mask, a + i + lanes2); + const auto ai3 = svld1_f32(mask, a + i + lanes3); + const auto bi0 = svld1_f32(mask, b + i); + const auto bi1 = svld1_f32(mask, b + i + lanes); + const auto bi2 = svld1_f32(mask, b + i + lanes2); + const auto bi3 = svld1_f32(mask, b + i + lanes3); + const auto ci0 = svmla_n_f32_x(mask, ai0, bi0, bf); + const auto ci1 = svmla_n_f32_x(mask, ai1, bi1, bf); + const auto ci2 = svmla_n_f32_x(mask, ai2, bi2, bf); + const auto ci3 = svmla_n_f32_x(mask, ai3, bi3, bf); + svst1_f32(mask, c + i, ci0); + svst1_f32(mask, c + i + lanes, ci1); + svst1_f32(mask, c + i + lanes2, ci2); + svst1_f32(mask, c + i + lanes3, ci3); + } + const auto mask0 = svwhilelt_b32_u64(i, n); + const auto mask1 = svwhilelt_b32_u64(i + lanes, n); + const auto mask2 = svwhilelt_b32_u64(i + lanes2, n); + const auto mask3 = svwhilelt_b32_u64(i + lanes3, n); + const auto ai0 = svld1_f32(mask0, a + i); + const auto ai1 = svld1_f32(mask1, a + i + lanes); + const auto ai2 = svld1_f32(mask2, a + i + lanes2); + const auto ai3 = svld1_f32(mask3, a + i + lanes3); + const auto bi0 = svld1_f32(mask0, b + i); + const auto bi1 = svld1_f32(mask1, b + i + lanes); + const auto bi2 = svld1_f32(mask2, b + i + lanes2); + const auto bi3 = svld1_f32(mask3, b + i + lanes3); + const auto ci0 = svmla_n_f32_x(mask0, ai0, bi0, bf); + const auto ci1 = svmla_n_f32_x(mask1, ai1, bi1, bf); + const auto ci2 = 
svmla_n_f32_x(mask2, ai2, bi2, bf); + const auto ci3 = svmla_n_f32_x(mask3, ai3, bi3, bf); + svst1_f32(mask0, c + i, ci0); + svst1_f32(mask1, c + i + lanes, ci1); + svst1_f32(mask2, c + i + lanes2, ci2); + svst1_f32(mask3, c + i + lanes3, ci3); +} + +template <> +void fvec_L2sqr_ny_transposed( + float* dis, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny); + +struct ElementOpIP { + static svfloat32_t op(svbool_t pg, svfloat32_t x, svfloat32_t y) { + return svmul_f32_x(pg, x, y); + } + static svfloat32_t merge( + svbool_t pg, + svfloat32_t z, + svfloat32_t x, + svfloat32_t y) { + return svmla_f32_x(pg, z, x, y); + } +}; + +template +void fvec_op_ny_sve_d1(float* dis, const float* x, const float* y, size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svdup_n_f32(x[0]); + size_t i = 0; + for (; i + lanes4 < ny; i += lanes4) { + svfloat32_t y0 = svld1_f32(pg, y); + svfloat32_t y1 = svld1_f32(pg, y + lanes); + svfloat32_t y2 = svld1_f32(pg, y + lanes2); + svfloat32_t y3 = svld1_f32(pg, y + lanes3); + y0 = ElementOp::op(pg, x0, y0); + y1 = ElementOp::op(pg, x0, y1); + y2 = ElementOp::op(pg, x0, y2); + y3 = ElementOp::op(pg, x0, y3); + svst1_f32(pg, dis, y0); + svst1_f32(pg, dis + lanes, y1); + svst1_f32(pg, dis + lanes2, y2); + svst1_f32(pg, dis + lanes3, y3); + y += lanes4; + dis += lanes4; + } + const svbool_t pg0 = svwhilelt_b32_u64(i, ny); + const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny); + const svbool_t pg2 = svwhilelt_b32_u64(i + lanes2, ny); + const svbool_t pg3 = svwhilelt_b32_u64(i + lanes3, ny); + svfloat32_t y0 = svld1_f32(pg0, y); + svfloat32_t y1 = svld1_f32(pg1, y + lanes); + svfloat32_t y2 = svld1_f32(pg2, y + lanes2); + svfloat32_t y3 = svld1_f32(pg3, y + lanes3); + y0 = ElementOp::op(pg0, x0, y0); + y1 = ElementOp::op(pg1, x0, y1); + 
y2 = ElementOp::op(pg2, x0, y2); + y3 = ElementOp::op(pg3, x0, y3); + svst1_f32(pg0, dis, y0); + svst1_f32(pg1, dis + lanes, y1); + svst1_f32(pg2, dis + lanes2, y2); + svst1_f32(pg3, dis + lanes3, y3); +} + +template +void fvec_op_ny_sve_d2(float* dis, const float* x, const float* y, size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svdup_n_f32(x[0]); + const svfloat32_t x1 = svdup_n_f32(x[1]); + size_t i = 0; + for (; i + lanes2 < ny; i += lanes2) { + const svfloat32x2_t y0 = svld2_f32(pg, y); + const svfloat32x2_t y1 = svld2_f32(pg, y + lanes2); + svfloat32_t y00 = svget2_f32(y0, 0); + const svfloat32_t y01 = svget2_f32(y0, 1); + svfloat32_t y10 = svget2_f32(y1, 0); + const svfloat32_t y11 = svget2_f32(y1, 1); + y00 = ElementOp::op(pg, x0, y00); + y10 = ElementOp::op(pg, x0, y10); + y00 = ElementOp::merge(pg, y00, x1, y01); + y10 = ElementOp::merge(pg, y10, x1, y11); + svst1_f32(pg, dis, y00); + svst1_f32(pg, dis + lanes, y10); + y += lanes4; + dis += lanes2; + } + const svbool_t pg0 = svwhilelt_b32_u64(i, ny); + const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny); + const svfloat32x2_t y0 = svld2_f32(pg0, y); + const svfloat32x2_t y1 = svld2_f32(pg1, y + lanes2); + svfloat32_t y00 = svget2_f32(y0, 0); + const svfloat32_t y01 = svget2_f32(y0, 1); + svfloat32_t y10 = svget2_f32(y1, 0); + const svfloat32_t y11 = svget2_f32(y1, 1); + y00 = ElementOp::op(pg0, x0, y00); + y10 = ElementOp::op(pg1, x0, y10); + y00 = ElementOp::merge(pg0, y00, x1, y01); + y10 = ElementOp::merge(pg1, y10, x1, y11); + svst1_f32(pg0, dis, y00); + svst1_f32(pg1, dis + lanes, y10); +} + +template +void fvec_op_ny_sve_d4(float* dis, const float* x, const float* y, size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svdup_n_f32(x[0]); + const svfloat32_t x1 = svdup_n_f32(x[1]); 
+ const svfloat32_t x2 = svdup_n_f32(x[2]); + const svfloat32_t x3 = svdup_n_f32(x[3]); + size_t i = 0; + for (; i + lanes < ny; i += lanes) { + const svfloat32x4_t y0 = svld4_f32(pg, y); + svfloat32_t y00 = svget4_f32(y0, 0); + const svfloat32_t y01 = svget4_f32(y0, 1); + svfloat32_t y02 = svget4_f32(y0, 2); + const svfloat32_t y03 = svget4_f32(y0, 3); + y00 = ElementOp::op(pg, x0, y00); + y02 = ElementOp::op(pg, x2, y02); + y00 = ElementOp::merge(pg, y00, x1, y01); + y02 = ElementOp::merge(pg, y02, x3, y03); + y00 = svadd_f32_x(pg, y00, y02); + svst1_f32(pg, dis, y00); + y += lanes4; + dis += lanes; + } + const svbool_t pg0 = svwhilelt_b32_u64(i, ny); + const svfloat32x4_t y0 = svld4_f32(pg0, y); + svfloat32_t y00 = svget4_f32(y0, 0); + const svfloat32_t y01 = svget4_f32(y0, 1); + svfloat32_t y02 = svget4_f32(y0, 2); + const svfloat32_t y03 = svget4_f32(y0, 3); + y00 = ElementOp::op(pg0, x0, y00); + y02 = ElementOp::op(pg0, x2, y02); + y00 = ElementOp::merge(pg0, y00, x1, y01); + y02 = ElementOp::merge(pg0, y02, x3, y03); + y00 = svadd_f32_x(pg0, y00, y02); + svst1_f32(pg0, dis, y00); +} + +template +void fvec_op_ny_sve_d8(float* dis, const float* x, const float* y, size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes4 = lanes * 4; + const size_t lanes8 = lanes * 8; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svdup_n_f32(x[0]); + const svfloat32_t x1 = svdup_n_f32(x[1]); + const svfloat32_t x2 = svdup_n_f32(x[2]); + const svfloat32_t x3 = svdup_n_f32(x[3]); + const svfloat32_t x4 = svdup_n_f32(x[4]); + const svfloat32_t x5 = svdup_n_f32(x[5]); + const svfloat32_t x6 = svdup_n_f32(x[6]); + const svfloat32_t x7 = svdup_n_f32(x[7]); + size_t i = 0; + for (; i + lanes < ny; i += lanes) { + const svfloat32x4_t ya = svld4_f32(pg, y); + const svfloat32x4_t yb = svld4_f32(pg, y + lanes4); + const svfloat32_t ya0 = svget4_f32(ya, 0); + const svfloat32_t ya1 = svget4_f32(ya, 1); + const svfloat32_t ya2 = svget4_f32(ya, 2); + const 
svfloat32_t ya3 = svget4_f32(ya, 3); + const svfloat32_t yb0 = svget4_f32(yb, 0); + const svfloat32_t yb1 = svget4_f32(yb, 1); + const svfloat32_t yb2 = svget4_f32(yb, 2); + const svfloat32_t yb3 = svget4_f32(yb, 3); + svfloat32_t y0 = svuzp1(ya0, yb0); + const svfloat32_t y1 = svuzp1(ya1, yb1); + svfloat32_t y2 = svuzp1(ya2, yb2); + const svfloat32_t y3 = svuzp1(ya3, yb3); + svfloat32_t y4 = svuzp2(ya0, yb0); + const svfloat32_t y5 = svuzp2(ya1, yb1); + svfloat32_t y6 = svuzp2(ya2, yb2); + const svfloat32_t y7 = svuzp2(ya3, yb3); + y0 = ElementOp::op(pg, x0, y0); + y2 = ElementOp::op(pg, x2, y2); + y4 = ElementOp::op(pg, x4, y4); + y6 = ElementOp::op(pg, x6, y6); + y0 = ElementOp::merge(pg, y0, x1, y1); + y2 = ElementOp::merge(pg, y2, x3, y3); + y4 = ElementOp::merge(pg, y4, x5, y5); + y6 = ElementOp::merge(pg, y6, x7, y7); + y0 = svadd_f32_x(pg, y0, y2); + y4 = svadd_f32_x(pg, y4, y6); + y0 = svadd_f32_x(pg, y0, y4); + svst1_f32(pg, dis, y0); + y += lanes8; + dis += lanes; + } + const svbool_t pg0 = svwhilelt_b32_u64(i, ny); + const svbool_t pga = svwhilelt_b32_u64(i * 2, ny * 2); + const svbool_t pgb = svwhilelt_b32_u64(i * 2 + lanes, ny * 2); + const svfloat32x4_t ya = svld4_f32(pga, y); + const svfloat32x4_t yb = svld4_f32(pgb, y + lanes4); + const svfloat32_t ya0 = svget4_f32(ya, 0); + const svfloat32_t ya1 = svget4_f32(ya, 1); + const svfloat32_t ya2 = svget4_f32(ya, 2); + const svfloat32_t ya3 = svget4_f32(ya, 3); + const svfloat32_t yb0 = svget4_f32(yb, 0); + const svfloat32_t yb1 = svget4_f32(yb, 1); + const svfloat32_t yb2 = svget4_f32(yb, 2); + const svfloat32_t yb3 = svget4_f32(yb, 3); + svfloat32_t y0 = svuzp1(ya0, yb0); + const svfloat32_t y1 = svuzp1(ya1, yb1); + svfloat32_t y2 = svuzp1(ya2, yb2); + const svfloat32_t y3 = svuzp1(ya3, yb3); + svfloat32_t y4 = svuzp2(ya0, yb0); + const svfloat32_t y5 = svuzp2(ya1, yb1); + svfloat32_t y6 = svuzp2(ya2, yb2); + const svfloat32_t y7 = svuzp2(ya3, yb3); + y0 = ElementOp::op(pg0, x0, y0); + y2 = 
ElementOp::op(pg0, x2, y2); + y4 = ElementOp::op(pg0, x4, y4); + y6 = ElementOp::op(pg0, x6, y6); + y0 = ElementOp::merge(pg0, y0, x1, y1); + y2 = ElementOp::merge(pg0, y2, x3, y3); + y4 = ElementOp::merge(pg0, y4, x5, y5); + y6 = ElementOp::merge(pg0, y6, x7, y7); + y0 = svadd_f32_x(pg0, y0, y2); + y4 = svadd_f32_x(pg0, y4, y6); + y0 = svadd_f32_x(pg0, y0, y4); + svst1_f32(pg0, dis, y0); + y += lanes8; + dis += lanes; +} + +template +void fvec_op_ny_sve_lanes1( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svld1_f32(pg, x); + size_t i = 0; + for (; i + 3 < ny; i += 4) { + svfloat32_t y0 = svld1_f32(pg, y); + svfloat32_t y1 = svld1_f32(pg, y + lanes); + svfloat32_t y2 = svld1_f32(pg, y + lanes2); + svfloat32_t y3 = svld1_f32(pg, y + lanes3); + y += lanes4; + y0 = ElementOp::op(pg, x0, y0); + y1 = ElementOp::op(pg, x0, y1); + y2 = ElementOp::op(pg, x0, y2); + y3 = ElementOp::op(pg, x0, y3); + dis[i] = svaddv_f32(pg, y0); + dis[i + 1] = svaddv_f32(pg, y1); + dis[i + 2] = svaddv_f32(pg, y2); + dis[i + 3] = svaddv_f32(pg, y3); + } + for (; i < ny; ++i) { + svfloat32_t y0 = svld1_f32(pg, y); + y += lanes; + y0 = ElementOp::op(pg, x0, y0); + dis[i] = svaddv_f32(pg, y0); + } +} + +template +void fvec_op_ny_sve_lanes2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svld1_f32(pg, x); + const svfloat32_t x1 = svld1_f32(pg, x + lanes); + size_t i = 0; + for (; i + 1 < ny; i += 2) { + svfloat32_t y00 = svld1_f32(pg, y); + const svfloat32_t y01 = svld1_f32(pg, y + lanes); + svfloat32_t y10 = svld1_f32(pg, y + lanes2); + const svfloat32_t y11 = 
svld1_f32(pg, y + lanes3); + y += lanes4; + y00 = ElementOp::op(pg, x0, y00); + y10 = ElementOp::op(pg, x0, y10); + y00 = ElementOp::merge(pg, y00, x1, y01); + y10 = ElementOp::merge(pg, y10, x1, y11); + dis[i] = svaddv_f32(pg, y00); + dis[i + 1] = svaddv_f32(pg, y10); + } + if (i < ny) { + svfloat32_t y0 = svld1_f32(pg, y); + const svfloat32_t y1 = svld1_f32(pg, y + lanes); + y0 = ElementOp::op(pg, x0, y0); + y0 = ElementOp::merge(pg, y0, x1, y1); + dis[i] = svaddv_f32(pg, y0); + } +} + +template +void fvec_op_ny_sve_lanes3( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svld1_f32(pg, x); + const svfloat32_t x1 = svld1_f32(pg, x + lanes); + const svfloat32_t x2 = svld1_f32(pg, x + lanes2); + for (size_t i = 0; i < ny; ++i) { + svfloat32_t y0 = svld1_f32(pg, y); + const svfloat32_t y1 = svld1_f32(pg, y + lanes); + svfloat32_t y2 = svld1_f32(pg, y + lanes2); + y += lanes3; + y0 = ElementOp::op(pg, x0, y0); + y0 = ElementOp::merge(pg, y0, x1, y1); + y0 = ElementOp::merge(pg, y0, x2, y2); + dis[i] = svaddv_f32(pg, y0); + } +} + +template +void fvec_op_ny_sve_lanes4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svld1_f32(pg, x); + const svfloat32_t x1 = svld1_f32(pg, x + lanes); + const svfloat32_t x2 = svld1_f32(pg, x + lanes2); + const svfloat32_t x3 = svld1_f32(pg, x + lanes3); + for (size_t i = 0; i < ny; ++i) { + svfloat32_t y0 = svld1_f32(pg, y); + const svfloat32_t y1 = svld1_f32(pg, y + lanes); + svfloat32_t y2 = svld1_f32(pg, y + lanes2); + const svfloat32_t y3 = svld1_f32(pg, y + lanes3); + y += lanes4; + y0 = ElementOp::op(pg, x0, y0); + y2 = ElementOp::op(pg, 
x2, y2); + y0 = ElementOp::merge(pg, y0, x1, y1); + y2 = ElementOp::merge(pg, y2, x3, y3); + y0 = svadd_f32_x(pg, y0, y2); + dis[i] = svaddv_f32(pg, y0); + } +} + +template <> +void fvec_inner_products_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + const size_t lanes = svcntw(); + switch (d) { + case 1: + fvec_op_ny_sve_d1(dis, x, y, ny); + break; + case 2: + fvec_op_ny_sve_d2(dis, x, y, ny); + break; + case 4: + fvec_op_ny_sve_d4(dis, x, y, ny); + break; + case 8: + fvec_op_ny_sve_d8(dis, x, y, ny); + break; + default: + if (d == lanes) + fvec_op_ny_sve_lanes1(dis, x, y, ny); + else if (d == lanes * 2) + fvec_op_ny_sve_lanes2(dis, x, y, ny); + else if (d == lanes * 3) + fvec_op_ny_sve_lanes3(dis, x, y, ny); + else if (d == lanes * 4) + fvec_op_ny_sve_lanes4(dis, x, y, ny); + else + fvec_inner_products_ny(dis, x, y, d, ny); + break; + } +} + +template <> +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny(dis, x, y, d, ny); +} + +template <> +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny) { + return fvec_L2sqr_ny_nearest( + distances_tmp_buffer, x, y, d, ny); +} + +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + return fvec_L2sqr_ny_nearest_y_transposed( + distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); +} + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_autovec-inl.h b/faiss/utils/simd_impl/distances_autovec-inl.h new file mode 100644 index 0000000000..62d13eb38e --- /dev/null +++ b/faiss/utils/simd_impl/distances_autovec-inl.h @@ -0,0 +1,153 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace faiss { + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_norm_L2sqr(const float* x, size_t d) { + // the double in the _ref is suspected to be a typo. Some of the manual + // implementations this replaces used float. + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i != d; ++i) { + res += x[i] * x[i]; + } + + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_L2sqr(const float* x, const float* y, size_t d) { + size_t i; + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (i = 0; i < d; i++) { + const float tmp = x[i] - y[i]; + res += tmp * tmp; + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_inner_product( + const float* x, + const float* y, + size_t d) { + float res = 0.F; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i != d; ++i) { + res += x[i] * y[i]; + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_L1(const float* x, const float* y, size_t d) { + size_t i; + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (i = 0; i < d; i++) { + const float tmp = x[i] - y[i]; + res += fabs(tmp); + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_Linf(const float* x, const float* y, size_t d) { + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; i++) { + res = fmax(res, fabs(x[i] - y[i])); + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +void fvec_inner_product_batch_4( + const float* __restrict x, + const float* __restrict y0, + const float* __restrict y1, + const float* __restrict y2, + const float* __restrict y3, + const size_t d, + float& dis0, + float& dis1, + float& dis2, + float& dis3) { + float d0 = 
0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; ++i) { + d0 += x[i] * y0[i]; + d1 += x[i] * y1[i]; + d2 += x[i] * y2[i]; + d3 += x[i] * y3[i]; + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +void fvec_L2sqr_batch_4( + const float* x, + const float* y0, + const float* y1, + const float* y2, + const float* y3, + const size_t d, + float& dis0, + float& dis1, + float& dis2, + float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; ++i) { + const float q0 = x[i] - y0[i]; + const float q1 = x[i] - y1[i]; + const float q2 = x[i] - y2[i]; + const float q3 = x[i] - y3[i]; + d0 += q0 * q0; + d1 += q1 * q1; + d2 += q2 * q2; + d3 += q3 * q3; + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_avx.cpp b/faiss/utils/simd_impl/distances_avx.cpp new file mode 100644 index 0000000000..c29e64c91f --- /dev/null +++ b/faiss/utils/simd_impl/distances_avx.cpp @@ -0,0 +1,99 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include + +#ifdef __AVX__ + +float fvec_L1(const float* x, const float* y, size_t d) { + __m256 msum1 = _mm256_setzero_ps(); + // signmask used for absolute value + __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL)); + + while (d >= 8) { + __m256 mx = _mm256_loadu_ps(x); + x += 8; + __m256 my = _mm256_loadu_ps(y); + y += 8; + // subtract + const __m256 a_m_b = _mm256_sub_ps(mx, my); + // find sum of absolute value of distances (manhattan distance) + msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b)); + d -= 8; + } + + __m128 msum2 = _mm256_extractf128_ps(msum1, 1); + msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0)); + __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL)); + + if (d >= 4) { + __m128 mx = _mm_loadu_ps(x); + x += 4; + __m128 my = _mm_loadu_ps(y); + y += 4; + const __m128 a_m_b = _mm_sub_ps(mx, my); + msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b)); + d -= 4; + } + + if (d > 0) { + __m128 mx = masked_read(d, x); + __m128 my = masked_read(d, y); + __m128 a_m_b = _mm_sub_ps(mx, my); + msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b)); + } + + msum2 = _mm_hadd_ps(msum2, msum2); + msum2 = _mm_hadd_ps(msum2, msum2); + return _mm_cvtss_f32(msum2); +} + +float fvec_Linf(const float* x, const float* y, size_t d) { + __m256 msum1 = _mm256_setzero_ps(); + // signmask used for absolute value + __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL)); + + while (d >= 8) { + __m256 mx = _mm256_loadu_ps(x); + x += 8; + __m256 my = _mm256_loadu_ps(y); + y += 8; + // subtract + const __m256 a_m_b = _mm256_sub_ps(mx, my); + // find max of absolute value of distances (chebyshev distance) + msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b)); + d -= 8; + } + + __m128 msum2 = _mm256_extractf128_ps(msum1, 1); + msum2 = _mm_max_ps(msum2, _mm256_extractf128_ps(msum1, 0)); + __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL)); + + if (d >= 4) { + __m128 mx 
= _mm_loadu_ps(x); + x += 4; + __m128 my = _mm_loadu_ps(y); + y += 4; + const __m128 a_m_b = _mm_sub_ps(mx, my); + msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); + d -= 4; + } + + if (d > 0) { + __m128 mx = masked_read(d, x); + __m128 my = masked_read(d, y); + __m128 a_m_b = _mm_sub_ps(mx, my); + msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); + } + + msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2); + msum2 = _mm_max_ps(msum2, _mm_shuffle_ps(msum2, msum2, 1)); + return _mm_cvtss_f32(msum2); +} + +#endif diff --git a/faiss/utils/simd_impl/distances_avx2.cpp b/faiss/utils/simd_impl/distances_avx2.cpp new file mode 100644 index 0000000000..acfcbabe17 --- /dev/null +++ b/faiss/utils/simd_impl/distances_avx2.cpp @@ -0,0 +1,1178 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#define AUTOVEC_LEVEL SIMDLevel::AVX2 +#include + +#include +#include + +namespace faiss { + +template <> +void fvec_madd( + const size_t n, + const float* __restrict a, + const float bf, + const float* __restrict b, + float* __restrict c) { + // + const size_t n8 = n / 8; + const size_t n_for_masking = n % 8; + + const __m256 bfmm = _mm256_set1_ps(bf); + + size_t idx = 0; + for (idx = 0; idx < n8 * 8; idx += 8) { + const __m256 ax = _mm256_loadu_ps(a + idx); + const __m256 bx = _mm256_loadu_ps(b + idx); + const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); + _mm256_storeu_ps(c + idx, abmul); + } + + if (n_for_masking > 0) { + __m256i mask; + switch (n_for_masking) { + case 1: + mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1); + break; + case 2: + mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1); + break; + case 3: + mask = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1); + break; + case 4: + mask = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1); + break; + case 5: + mask = _mm256_set_epi32(0, 0, 0, -1, 
-1, -1, -1, -1); + break; + case 6: + mask = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1); + break; + case 7: + mask = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1); + break; + } + + const __m256 ax = _mm256_maskload_ps(a + idx, mask); + const __m256 bx = _mm256_maskload_ps(b + idx, mask); + const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); + _mm256_maskstore_ps(c + idx, mask, abmul); + } +} + +template +void fvec_L2sqr_ny_y_transposed_D( + float* distances, + const float* x, + const float* y, + const float* y_sqlen, + const size_t d_offset, + size_t ny) { + // current index being processed + size_t i = 0; + + // squared length of x + float x_sqlen = 0; + for (size_t j = 0; j < DIM; j++) { + x_sqlen += x[j] * x[j]; + } + + // process 8 vectors per loop. + const size_t ny8 = ny / 8; + + if (ny8 > 0) { + // m[i] = (2 * x[i], ... 2 * x[i]) + __m256 m[DIM]; + for (size_t j = 0; j < DIM; j++) { + m[j] = _mm256_set1_ps(x[j]); + m[j] = _mm256_add_ps(m[j], m[j]); + } + + __m256 x_sqlen_ymm = _mm256_set1_ps(x_sqlen); + + for (; i < ny8 * 8; i += 8) { + // collect dim 0 for 8 D4-vectors. + const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset); + + // compute dot products + // this is x^2 - 2x[0]*y[0] + __m256 dp = _mm256_fnmadd_ps(m[0], v0, x_sqlen_ymm); + + for (size_t j = 1; j < DIM; j++) { + // collect dim j for 8 D4-vectors. + const __m256 vj = _mm256_loadu_ps(y + j * d_offset); + dp = _mm256_fnmadd_ps(m[j], vj, dp); + } + + // we've got x^2 - (2x, y) at this point + + // y^2 - (2x, y) + x^2 + __m256 distances_v = _mm256_add_ps(_mm256_loadu_ps(y_sqlen), dp); + + _mm256_storeu_ps(distances + i, distances_v); + + // scroll y and y_sqlen forward. + y += 8; + y_sqlen += 8; + } + } + + if (i < ny) { + // process leftovers + for (; i < ny; i++) { + float dp = 0; + for (size_t j = 0; j < DIM; j++) { + dp += x[j] * y[j * d_offset]; + } + + // compute y^2 - 2 * (x, y), which is sufficient for looking for the + // lowest distance. 
+ const float distance = y_sqlen[0] - 2 * dp + x_sqlen; + distances[i] = distance; + + y += 1; + y_sqlen += 1; + } + } +} + +template <> +void fvec_L2sqr_ny_transposed( + float* dis, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + // optimized for a few special cases +#define DISPATCH(dval) \ + case dval: \ + return fvec_L2sqr_ny_y_transposed_D( \ + dis, x, y, y_sqlen, d_offset, ny); + + switch (d) { + DISPATCH(1) + DISPATCH(2) + DISPATCH(4) + DISPATCH(8) + default: + return fvec_L2sqr_ny_transposed( + dis, x, y, y_sqlen, d, d_offset, ny); + } +#undef DISPATCH +} + +struct AVX2ElementOpIP : public ElementOpIP { + using ElementOpIP::op; + static __m256 op(__m256 x, __m256 y) { + return _mm256_mul_ps(x, y); + } +}; + +struct AVX2ElementOpL2 : public ElementOpL2 { + using ElementOpL2::op; + + static __m256 op(__m256 x, __m256 y) { + __m256 tmp = _mm256_sub_ps(x, y); + return _mm256_mul_ps(tmp, tmp); + } +}; + +/// helper function for AVX2 +inline float horizontal_sum(const __m256 v) { + // add high and low parts + const __m128 v0 = + _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); + // perform horizontal sum on v0 + return horizontal_sum(v0); +} + +template <> +void fvec_op_ny_D2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D2-vectors per loop. + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); + + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + + for (i = 0; i < ny8 * 8; i += 8) { + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + // load 8x2 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
+ + __m256 v0; + __m256 v1; + + transpose_8x2( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + v0, + v1); + + // compute distances + __m256 distances = _mm256_mul_ps(m0, v0); + distances = _mm256_fmadd_ps(m1, v1, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 16; + } + } + + if (i < ny) { + // process leftovers + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float distance = x0 * y[0] + x1 * y[1]; + y += 2; + dis[i] = distance; + } + } +} + +template <> +void fvec_op_ny_D2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D2-vectors per loop. + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); + + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + + for (i = 0; i < ny8 * 8; i += 8) { + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + // load 8x2 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
+ + __m256 v0; + __m256 v1; + + transpose_8x2( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + v0, + v1); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 16; + } + } + + if (i < ny) { + // process leftovers + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float sub0 = x0 - y[0]; + float sub1 = x1 - y[1]; + float distance = sub0 * sub0 + sub1 * sub1; + + y += 2; + dis[i] = distance; + } + } +} + +template <> +void fvec_op_ny_D4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D4-vectors per loop. + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + + for (i = 0; i < ny8 * 8; i += 8) { + // load 8x4 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
+ + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + + transpose_8x4( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + v0, + v1, + v2, + v3); + + // compute distances + __m256 distances = _mm256_mul_ps(m0, v0); + distances = _mm256_fmadd_ps(m1, v1, distances); + distances = _mm256_fmadd_ps(m2, v2, distances); + distances = _mm256_fmadd_ps(m3, v3, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 32; + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = AVX2ElementOpIP::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D4-vectors per loop. + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + + for (i = 0; i < ny8 * 8; i += 8) { + // load 8x4 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
+ + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + + transpose_8x4( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + v0, + v1, + v2, + v3); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + const __m256 d2 = _mm256_sub_ps(m2, v2); + const __m256 d3 = _mm256_sub_ps(m3, v3); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + distances = _mm256_fmadd_ps(d2, d2, distances); + distances = _mm256_fmadd_ps(d3, d3, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 32; + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = AVX2ElementOpL2::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D8( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D8-vectors per loop. + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + const __m256 m4 = _mm256_set1_ps(x[4]); + const __m256 m5 = _mm256_set1_ps(x[5]); + const __m256 m6 = _mm256_set1_ps(x[6]); + const __m256 m7 = _mm256_set1_ps(x[7]); + + for (i = 0; i < ny8 * 8; i += 8) { + // load 8x8 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
+ + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + __m256 v4; + __m256 v5; + __m256 v6; + __m256 v7; + + transpose_8x8( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + _mm256_loadu_ps(y + 4 * 8), + _mm256_loadu_ps(y + 5 * 8), + _mm256_loadu_ps(y + 6 * 8), + _mm256_loadu_ps(y + 7 * 8), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + // compute distances + __m256 distances = _mm256_mul_ps(m0, v0); + distances = _mm256_fmadd_ps(m1, v1, distances); + distances = _mm256_fmadd_ps(m2, v2, distances); + distances = _mm256_fmadd_ps(m3, v3, distances); + distances = _mm256_fmadd_ps(m4, v4, distances); + distances = _mm256_fmadd_ps(m5, v5, distances); + distances = _mm256_fmadd_ps(m6, v6, distances); + distances = _mm256_fmadd_ps(m7, v7, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 64; + } + } + + if (i < ny) { + // process leftovers + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX2ElementOpIP::op(x0, _mm256_loadu_ps(y)); + y += 8; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D8( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D8-vectors per loop. + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + const __m256 m4 = _mm256_set1_ps(x[4]); + const __m256 m5 = _mm256_set1_ps(x[5]); + const __m256 m6 = _mm256_set1_ps(x[6]); + const __m256 m7 = _mm256_set1_ps(x[7]); + + for (i = 0; i < ny8 * 8; i += 8) { + // load 8x8 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
+ + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + __m256 v4; + __m256 v5; + __m256 v6; + __m256 v7; + + transpose_8x8( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + _mm256_loadu_ps(y + 4 * 8), + _mm256_loadu_ps(y + 5 * 8), + _mm256_loadu_ps(y + 6 * 8), + _mm256_loadu_ps(y + 7 * 8), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + const __m256 d2 = _mm256_sub_ps(m2, v2); + const __m256 d3 = _mm256_sub_ps(m3, v3); + const __m256 d4 = _mm256_sub_ps(m4, v4); + const __m256 d5 = _mm256_sub_ps(m5, v5); + const __m256 d6 = _mm256_sub_ps(m6, v6); + const __m256 d7 = _mm256_sub_ps(m7, v7); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + distances = _mm256_fmadd_ps(d2, d2, distances); + distances = _mm256_fmadd_ps(d3, d3, distances); + distances = _mm256_fmadd_ps(d4, d4, distances); + distances = _mm256_fmadd_ps(d5, d5, distances); + distances = _mm256_fmadd_ps(d6, d6, distances); + distances = _mm256_fmadd_ps(d7, d7, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 64; + } + } + + if (i < ny) { + // process leftovers + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX2ElementOpL2::op(x0, _mm256_loadu_ps(y)); + y += 8; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_inner_products_ny( + float* ip, /* output inner product */ + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_inner_products_ny_ref(ip, x, y, d, ny); +} + +template <> +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny_ref(dis, x, y, d, ny); +} + +template <> +size_t fvec_L2sqr_ny_nearest_D2( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // 
this implementation does not use distances_tmp_buffer. + // current index being processed + size_t i = 0; + + // min distance and the index of the closest vector so far + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + // process 8 D2-vectors per loop. + const size_t ny8 = ny / 8; + if (ny8 > 0) { + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); + + // track min distance and the closest vector independently + // for each of 8 AVX2 components. + __m256 min_distances = _mm256_set1_ps(HUGE_VALF); + __m256i min_indices = _mm256_set1_epi32(0); + + __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + const __m256i indices_increment = _mm256_set1_epi32(8); + + // 1 value per register + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + + for (; i < ny8 * 8; i += 8) { + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + __m256 v0; + __m256 v1; + + transpose_8x2( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + v0, + v1); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + + // compare the new distances to the min distances + // for each of 8 AVX2 components. + __m256 comparison = + _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); + + // update min distances and indices with closest vectors if needed. + min_distances = _mm256_min_ps(distances, min_distances); + min_indices = _mm256_castps_si256(_mm256_blendv_ps( + _mm256_castsi256_ps(current_indices), + _mm256_castsi256_ps(min_indices), + comparison)); + + // update current indices values. Basically, +8 to each of the + // 8 AVX2 components. + current_indices = + _mm256_add_epi32(current_indices, indices_increment); + + // scroll y forward (8 vectors 2 DIM each). 
+ y += 16; + } + + // dump values and find the minimum distance / minimum index + float min_distances_scalar[8]; + uint32_t min_indices_scalar[8]; + _mm256_storeu_ps(min_distances_scalar, min_distances); + _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); + + for (size_t j = 0; j < 8; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + // process leftovers. + // the following code is not optimal, but it is rarely invoked. + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float sub0 = x0 - y[0]; + float sub1 = x1 - y[1]; + float distance = sub0 * sub0 + sub1 * sub1; + + y += 2; + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest_D4( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. + + // current index being processed + size_t i = 0; + + // min distance and the index of the closest vector so far + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + // process 8 D4-vectors per loop. + const size_t ny8 = ny / 8; + + if (ny8 > 0) { + // track min distance and the closest vector independently + // for each of 8 AVX2 components. 
+ __m256 min_distances = _mm256_set1_ps(HUGE_VALF); + __m256i min_indices = _mm256_set1_epi32(0); + + __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + const __m256i indices_increment = _mm256_set1_epi32(8); + + // 1 value per register + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + + for (; i < ny8 * 8; i += 8) { + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + + transpose_8x4( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + v0, + v1, + v2, + v3); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + const __m256 d2 = _mm256_sub_ps(m2, v2); + const __m256 d3 = _mm256_sub_ps(m3, v3); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + distances = _mm256_fmadd_ps(d2, d2, distances); + distances = _mm256_fmadd_ps(d3, d3, distances); + + // compare the new distances to the min distances + // for each of 8 AVX2 components. + __m256 comparison = + _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); + + // update min distances and indices with closest vectors if needed. + min_distances = _mm256_min_ps(distances, min_distances); + min_indices = _mm256_castps_si256(_mm256_blendv_ps( + _mm256_castsi256_ps(current_indices), + _mm256_castsi256_ps(min_indices), + comparison)); + + // update current indices values. Basically, +8 to each of the + // 8 AVX2 components. + current_indices = + _mm256_add_epi32(current_indices, indices_increment); + + // scroll y forward (8 vectors 4 DIM each). 
+ y += 32; + } + + // dump values and find the minimum distance / minimum index + float min_distances_scalar[8]; + uint32_t min_indices_scalar[8]; + _mm256_storeu_ps(min_distances_scalar, min_distances); + _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); + + for (size_t j = 0; j < 8; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); + y += 4; + const float distance = horizontal_sum(accu); + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest_D8( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. + + // current index being processed + size_t i = 0; + + // min distance and the index of the closest vector so far + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + // process 8 D8-vectors per loop. + const size_t ny8 = ny / 8; + if (ny8 > 0) { + // track min distance and the closest vector independently + // for each of 8 AVX2 components. 
+ __m256 min_distances = _mm256_set1_ps(HUGE_VALF); + __m256i min_indices = _mm256_set1_epi32(0); + + __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + const __m256i indices_increment = _mm256_set1_epi32(8); + + // 1 value per register + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + + const __m256 m4 = _mm256_set1_ps(x[4]); + const __m256 m5 = _mm256_set1_ps(x[5]); + const __m256 m6 = _mm256_set1_ps(x[6]); + const __m256 m7 = _mm256_set1_ps(x[7]); + + for (; i < ny8 * 8; i += 8) { + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + __m256 v4; + __m256 v5; + __m256 v6; + __m256 v7; + + transpose_8x8( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + _mm256_loadu_ps(y + 4 * 8), + _mm256_loadu_ps(y + 5 * 8), + _mm256_loadu_ps(y + 6 * 8), + _mm256_loadu_ps(y + 7 * 8), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + const __m256 d2 = _mm256_sub_ps(m2, v2); + const __m256 d3 = _mm256_sub_ps(m3, v3); + const __m256 d4 = _mm256_sub_ps(m4, v4); + const __m256 d5 = _mm256_sub_ps(m5, v5); + const __m256 d6 = _mm256_sub_ps(m6, v6); + const __m256 d7 = _mm256_sub_ps(m7, v7); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + distances = _mm256_fmadd_ps(d2, d2, distances); + distances = _mm256_fmadd_ps(d3, d3, distances); + distances = _mm256_fmadd_ps(d4, d4, distances); + distances = _mm256_fmadd_ps(d5, d5, distances); + distances = _mm256_fmadd_ps(d6, d6, distances); + distances = _mm256_fmadd_ps(d7, d7, distances); + + // compare the new distances to the min distances + // for each of 8 AVX2 components. 
+ __m256 comparison = + _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); + + // update min distances and indices with closest vectors if needed. + min_distances = _mm256_min_ps(distances, min_distances); + min_indices = _mm256_castps_si256(_mm256_blendv_ps( + _mm256_castsi256_ps(current_indices), + _mm256_castsi256_ps(min_indices), + comparison)); + + // update current indices values. Basically, +8 to each of the + // 8 AVX2 components. + current_indices = + _mm256_add_epi32(current_indices, indices_increment); + + // scroll y forward (8 vectors 8 DIM each). + y += 64; + } + + // dump values and find the minimum distance / minimum index + float min_distances_scalar[8]; + uint32_t min_indices_scalar[8]; + _mm256_storeu_ps(min_distances_scalar, min_distances); + _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); + + for (size_t j = 0; j < 8; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + // process leftovers + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX2ElementOpL2::op(x0, _mm256_loadu_ps(y)); + y += 8; + const float distance = horizontal_sum(accu); + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny) { + return fvec_L2sqr_ny_nearest_x86( + distances_tmp_buffer, + x, + y, + d, + ny, + &fvec_L2sqr_ny_nearest_D2, + &fvec_L2sqr_ny_nearest_D4, + &fvec_L2sqr_ny_nearest_D8); +} + +template +size_t fvec_L2sqr_ny_nearest_y_transposed_D( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + const size_t d_offset, + size_t ny) { + // this implementation does not use distances_tmp_buffer. 
+ + // current index being processed + size_t i = 0; + + // min distance and the index of the closest vector so far + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + // process 8 vectors per loop. + const size_t ny8 = ny / 8; + + if (ny8 > 0) { + // track min distance and the closest vector independently + // for each of 8 AVX2 components. + __m256 min_distances = _mm256_set1_ps(HUGE_VALF); + __m256i min_indices = _mm256_set1_epi32(0); + + __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + const __m256i indices_increment = _mm256_set1_epi32(8); + + // m[i] = (2 * x[i], ... 2 * x[i]) + __m256 m[DIM]; + for (size_t j = 0; j < DIM; j++) { + m[j] = _mm256_set1_ps(x[j]); + m[j] = _mm256_add_ps(m[j], m[j]); + } + + for (; i < ny8 * 8; i += 8) { + // collect dim 0 for 8 D4-vectors. + const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset); + // compute dot products + __m256 dp = _mm256_mul_ps(m[0], v0); + + for (size_t j = 1; j < DIM; j++) { + // collect dim j for 8 D4-vectors. + const __m256 vj = _mm256_loadu_ps(y + j * d_offset); + dp = _mm256_fmadd_ps(m[j], vj, dp); + } + + // compute y^2 - (2 * x, y), which is sufficient for looking for the + // lowest distance. + // x^2 is the constant that can be avoided. + const __m256 distances = + _mm256_sub_ps(_mm256_loadu_ps(y_sqlen), dp); + + // compare the new distances to the min distances + // for each of 8 AVX2 components. + const __m256 comparison = + _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); + + // update min distances and indices with closest vectors if needed. + min_distances = + _mm256_blendv_ps(distances, min_distances, comparison); + min_indices = _mm256_castps_si256(_mm256_blendv_ps( + _mm256_castsi256_ps(current_indices), + _mm256_castsi256_ps(min_indices), + comparison)); + + // update current indices values. Basically, +8 to each of the + // 8 AVX2 components. 
+ current_indices = + _mm256_add_epi32(current_indices, indices_increment); + + // scroll y and y_sqlen forward. + y += 8; + y_sqlen += 8; + } + + // dump values and find the minimum distance / minimum index + float min_distances_scalar[8]; + uint32_t min_indices_scalar[8]; + _mm256_storeu_ps(min_distances_scalar, min_distances); + _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); + + for (size_t j = 0; j < 8; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + // process leftovers + for (; i < ny; i++) { + float dp = 0; + for (size_t j = 0; j < DIM; j++) { + dp += x[j] * y[j * d_offset]; + } + + // compute y^2 - 2 * (x, y), which is sufficient for looking for the + // lowest distance. + const float distance = y_sqlen[0] - 2 * dp; + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + + y += 1; + y_sqlen += 1; + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { +// optimized for a few special cases +#define DISPATCH(dval) \ + case dval: \ + return fvec_L2sqr_ny_nearest_y_transposed_D( \ + distances_tmp_buffer, x, y, y_sqlen, d_offset, ny); + + switch (d) { + DISPATCH(1) + DISPATCH(2) + DISPATCH(4) + DISPATCH(8) + default: + return fvec_L2sqr_ny_nearest_y_transposed( + distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); + } +#undef DISPATCH +} + +template <> +int fvec_madd_and_argmin( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + return fvec_madd_and_argmin_sse(n, a, bf, b, c); +} + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_avx512.cpp b/faiss/utils/simd_impl/distances_avx512.cpp new file mode 100644 index 
0000000000..06d5b399f4 --- /dev/null +++ b/faiss/utils/simd_impl/distances_avx512.cpp @@ -0,0 +1,1092 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#define AUTOVEC_LEVEL SIMDLevel::AVX512 +#include +#include +#include + +namespace faiss { + +template <> +void fvec_madd( + const size_t n, + const float* __restrict a, + const float bf, + const float* __restrict b, + float* __restrict c) { + const size_t n16 = n / 16; + const size_t n_for_masking = n % 16; + + const __m512 bfmm = _mm512_set1_ps(bf); + + size_t idx = 0; + for (idx = 0; idx < n16 * 16; idx += 16) { + const __m512 ax = _mm512_loadu_ps(a + idx); + const __m512 bx = _mm512_loadu_ps(b + idx); + const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); + _mm512_storeu_ps(c + idx, abmul); + } + + if (n_for_masking > 0) { + const __mmask16 mask = (1 << n_for_masking) - 1; + + const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx); + const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx); + const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); + _mm512_mask_storeu_ps(c + idx, mask, abmul); + } +} + +template +void fvec_L2sqr_ny_y_transposed_D( + float* distances, + const float* x, + const float* y, + const float* y_sqlen, + const size_t d_offset, + size_t ny) { + // current index being processed + size_t i = 0; + + // squared length of x + float x_sqlen = 0; + for (size_t j = 0; j < DIM; j++) { + x_sqlen += x[j] * x[j]; + } + + // process 16 vectors per loop + const size_t ny16 = ny / 16; + + if (ny16 > 0) { + // m[i] = (2 * x[i], ... 
2 * x[i]) + __m512 m[DIM]; + for (size_t j = 0; j < DIM; j++) { + m[j] = _mm512_set1_ps(x[j]); + m[j] = _mm512_add_ps(m[j], m[j]); // m[j] = 2 * x[j] + } + + __m512 x_sqlen_ymm = _mm512_set1_ps(x_sqlen); + + for (; i < ny16 * 16; i += 16) { + // Load vectors for 16 dimensions + __m512 v[DIM]; + for (size_t j = 0; j < DIM; j++) { + v[j] = _mm512_loadu_ps(y + j * d_offset); + } + + // Compute dot products + __m512 dp = _mm512_fnmadd_ps(m[0], v[0], x_sqlen_ymm); + for (size_t j = 1; j < DIM; j++) { + dp = _mm512_fnmadd_ps(m[j], v[j], dp); + } + + // Compute y^2 - (2 * x, y) + x^2 + __m512 distances_v = _mm512_add_ps(_mm512_loadu_ps(y_sqlen), dp); + + _mm512_storeu_ps(distances + i, distances_v); + + // Scroll y and y_sqlen forward + y += 16; + y_sqlen += 16; + } + } + + if (i < ny) { + // Process leftovers + for (; i < ny; i++) { + float dp = 0; + for (size_t j = 0; j < DIM; j++) { + dp += x[j] * y[j * d_offset]; + } + + // Compute y^2 - 2 * (x, y), which is sufficient for looking for the + // lowest distance. 
+ const float distance = y_sqlen[0] - 2 * dp + x_sqlen; + distances[i] = distance; + + y += 1; + y_sqlen += 1; + } + } +} + +template <> +void fvec_L2sqr_ny_transposed( + float* dis, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + // optimized for a few special cases +#define DISPATCH(dval) \ + case dval: \ + return fvec_L2sqr_ny_y_transposed_D( \ + dis, x, y, y_sqlen, d_offset, ny); + + switch (d) { + DISPATCH(1) + DISPATCH(2) + DISPATCH(4) + DISPATCH(8) + default: + return fvec_L2sqr_ny_transposed( + dis, x, y, y_sqlen, d, d_offset, ny); + } +#undef DISPATCH +} + +struct AVX512ElementOpIP : public ElementOpIP { + using ElementOpIP::op; + static __m512 op(__m512 x, __m512 y) { + return _mm512_mul_ps(x, y); + } + static __m256 op(__m256 x, __m256 y) { + return _mm256_mul_ps(x, y); + } +}; + +struct AVX512ElementOpL2 : public ElementOpL2 { + using ElementOpL2::op; + static __m512 op(__m512 x, __m512 y) { + __m512 tmp = _mm512_sub_ps(x, y); + return _mm512_mul_ps(tmp, tmp); + } + static __m256 op(__m256 x, __m256 y) { + __m256 tmp = _mm256_sub_ps(x, y); + return _mm256_mul_ps(tmp, tmp); + } +}; + +/// helper function for AVX512 +inline float horizontal_sum(const __m512 v) { + // performs better than adding the high and low parts + return _mm512_reduce_add_ps(v); +} + +inline float horizontal_sum(const __m256 v) { + // add high and low parts + const __m128 v0 = + _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); + // perform horizontal sum on v0 + return horizontal_sum(v0); +} + +template <> +void fvec_op_ny_D2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D2-vectors per loop. 
+ _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + + for (i = 0; i < ny16 * 16; i += 16) { + _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); + + // load 16x2 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m512 v0; + __m512 v1; + + transpose_16x2( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + v0, + v1); + + // compute distances (dot product) + __m512 distances = _mm512_mul_ps(m0, v0); + distances = _mm512_fmadd_ps(m1, v1, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 32; // move to the next set of 16x2 elements + } + } + + if (i < ny) { + // process leftovers + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float distance = x0 * y[0] + x1 * y[1]; + y += 2; + dis[i] = distance; + } + } +} + +template <> +void fvec_op_ny_D2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D2-vectors per loop. + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + + for (i = 0; i < ny16 * 16; i += 16) { + _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); + + // load 16x2 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
+ + __m512 v0; + __m512 v1; + + transpose_16x2( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + v0, + v1); + + // compute differences + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + + // compute squares of differences + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 32; // move to the next set of 16x2 elements + } + } + + if (i < ny) { + // process leftovers + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float sub0 = x0 - y[0]; + float sub1 = x1 - y[1]; + float distance = sub0 * sub0 + sub1 * sub1; + + y += 2; + dis[i] = distance; + } + } +} + +template <> +void fvec_op_ny_D4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D4-vectors per loop. + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + + for (i = 0; i < ny16 * 16; i += 16) { + // load 16x4 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
+ + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + + transpose_16x4( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + v0, + v1, + v2, + v3); + + // compute distances + __m512 distances = _mm512_mul_ps(m0, v0); + distances = _mm512_fmadd_ps(m1, v1, distances); + distances = _mm512_fmadd_ps(m2, v2, distances); + distances = _mm512_fmadd_ps(m3, v3, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 64; // move to the next set of 16x4 elements + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = AVX512ElementOpIP::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D4-vectors per loop. + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + + for (i = 0; i < ny16 * 16; i += 16) { + // load 16x4 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
+ + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + + transpose_16x4( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + v0, + v1, + v2, + v3); + + // compute differences + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + const __m512 d2 = _mm512_sub_ps(m2, v2); + const __m512 d3 = _mm512_sub_ps(m3, v3); + + // compute squares of differences + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + distances = _mm512_fmadd_ps(d2, d2, distances); + distances = _mm512_fmadd_ps(d3, d3, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 64; // move to the next set of 16x4 elements + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = AVX512ElementOpL2::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D8( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D16-vectors per loop. + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + const __m512 m4 = _mm512_set1_ps(x[4]); + const __m512 m5 = _mm512_set1_ps(x[5]); + const __m512 m6 = _mm512_set1_ps(x[6]); + const __m512 m7 = _mm512_set1_ps(x[7]); + + for (i = 0; i < ny16 * 16; i += 16) { + // load 16x8 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
+ + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + __m512 v4; + __m512 v5; + __m512 v6; + __m512 v7; + + transpose_16x8( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + _mm512_loadu_ps(y + 4 * 16), + _mm512_loadu_ps(y + 5 * 16), + _mm512_loadu_ps(y + 6 * 16), + _mm512_loadu_ps(y + 7 * 16), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + // compute distances + __m512 distances = _mm512_mul_ps(m0, v0); + distances = _mm512_fmadd_ps(m1, v1, distances); + distances = _mm512_fmadd_ps(m2, v2, distances); + distances = _mm512_fmadd_ps(m3, v3, distances); + distances = _mm512_fmadd_ps(m4, v4, distances); + distances = _mm512_fmadd_ps(m5, v5, distances); + distances = _mm512_fmadd_ps(m6, v6, distances); + distances = _mm512_fmadd_ps(m7, v7, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 128; // 16 floats * 8 rows + } + } + + if (i < ny) { + // process leftovers + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX512ElementOpIP::op(x0, _mm256_loadu_ps(y)); + y += 8; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D8( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D16-vectors per loop. + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + const __m512 m4 = _mm512_set1_ps(x[4]); + const __m512 m5 = _mm512_set1_ps(x[5]); + const __m512 m6 = _mm512_set1_ps(x[6]); + const __m512 m7 = _mm512_set1_ps(x[7]); + + for (i = 0; i < ny16 * 16; i += 16) { + // load 16x8 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
+ + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + __m512 v4; + __m512 v5; + __m512 v6; + __m512 v7; + + transpose_16x8( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + _mm512_loadu_ps(y + 4 * 16), + _mm512_loadu_ps(y + 5 * 16), + _mm512_loadu_ps(y + 6 * 16), + _mm512_loadu_ps(y + 7 * 16), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + // compute differences + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + const __m512 d2 = _mm512_sub_ps(m2, v2); + const __m512 d3 = _mm512_sub_ps(m3, v3); + const __m512 d4 = _mm512_sub_ps(m4, v4); + const __m512 d5 = _mm512_sub_ps(m5, v5); + const __m512 d6 = _mm512_sub_ps(m6, v6); + const __m512 d7 = _mm512_sub_ps(m7, v7); + + // compute squares of differences + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + distances = _mm512_fmadd_ps(d2, d2, distances); + distances = _mm512_fmadd_ps(d3, d3, distances); + distances = _mm512_fmadd_ps(d4, d4, distances); + distances = _mm512_fmadd_ps(d5, d5, distances); + distances = _mm512_fmadd_ps(d6, d6, distances); + distances = _mm512_fmadd_ps(d7, d7, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 128; // 16 floats * 8 rows + } + } + + if (i < ny) { + // process leftovers + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX512ElementOpL2::op(x0, _mm256_loadu_ps(y)); + y += 8; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_inner_products_ny( + float* ip, /* output inner product */ + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_inner_products_ny_ref(ip, x, y, d, ny); +} + +template <> +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny_ref(dis, x, y, d, ny); +} + +template <> +size_t fvec_L2sqr_ny_nearest_D2( + float* distances_tmp_buffer, + const float* x, + const 
float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. + + size_t i = 0; + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + const size_t ny16 = ny / 16; + if (ny16 > 0) { + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + __m512 min_distances = _mm512_set1_ps(HUGE_VALF); + __m512i min_indices = _mm512_set1_epi32(0); + + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_increment = _mm512_set1_epi32(16); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + + for (; i < ny16 * 16; i += 16) { + _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); + + __m512 v0; + __m512 v1; + + transpose_16x2( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + v0, + v1); + + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + + __mmask16 comparison = + _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); + + min_distances = _mm512_min_ps(distances, min_distances); + min_indices = _mm512_mask_blend_epi32( + comparison, min_indices, current_indices); + + current_indices = + _mm512_add_epi32(current_indices, indices_increment); + + y += 32; + } + + alignas(64) float min_distances_scalar[16]; + alignas(64) uint32_t min_indices_scalar[16]; + _mm512_store_ps(min_distances_scalar, min_distances); + _mm512_store_epi32(min_indices_scalar, min_indices); + + for (size_t j = 0; j < 16; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float sub0 = x0 - y[0]; + float sub1 = x1 - y[1]; + float distance = sub0 * sub0 + sub1 * sub1; + + y += 2; 
+ + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest_D4( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. + + size_t i = 0; + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + const size_t ny16 = ny / 16; + + if (ny16 > 0) { + __m512 min_distances = _mm512_set1_ps(HUGE_VALF); + __m512i min_indices = _mm512_set1_epi32(0); + + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_increment = _mm512_set1_epi32(16); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + + for (; i < ny16 * 16; i += 16) { + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + + transpose_16x4( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + v0, + v1, + v2, + v3); + + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + const __m512 d2 = _mm512_sub_ps(m2, v2); + const __m512 d3 = _mm512_sub_ps(m3, v3); + + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + distances = _mm512_fmadd_ps(d2, d2, distances); + distances = _mm512_fmadd_ps(d3, d3, distances); + + __mmask16 comparison = + _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); + + min_distances = _mm512_min_ps(distances, min_distances); + min_indices = _mm512_mask_blend_epi32( + comparison, min_indices, current_indices); + + current_indices = + _mm512_add_epi32(current_indices, indices_increment); + + y += 64; + } + + alignas(64) float min_distances_scalar[16]; + alignas(64) uint32_t min_indices_scalar[16]; + 
_mm512_store_ps(min_distances_scalar, min_distances); + _mm512_store_epi32(min_indices_scalar, min_indices); + + for (size_t j = 0; j < 16; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); + y += 4; + const float distance = horizontal_sum(accu); + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest_D8( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. + + size_t i = 0; + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + const size_t ny16 = ny / 16; + if (ny16 > 0) { + __m512 min_distances = _mm512_set1_ps(HUGE_VALF); + __m512i min_indices = _mm512_set1_epi32(0); + + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_increment = _mm512_set1_epi32(16); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + + const __m512 m4 = _mm512_set1_ps(x[4]); + const __m512 m5 = _mm512_set1_ps(x[5]); + const __m512 m6 = _mm512_set1_ps(x[6]); + const __m512 m7 = _mm512_set1_ps(x[7]); + + for (; i < ny16 * 16; i += 16) { + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + __m512 v4; + __m512 v5; + __m512 v6; + __m512 v7; + + transpose_16x8( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + _mm512_loadu_ps(y + 4 * 16), + _mm512_loadu_ps(y + 5 * 16), + _mm512_loadu_ps(y + 6 * 16), + _mm512_loadu_ps(y + 7 * 16), + v0, + 
v1, + v2, + v3, + v4, + v5, + v6, + v7); + + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + const __m512 d2 = _mm512_sub_ps(m2, v2); + const __m512 d3 = _mm512_sub_ps(m3, v3); + const __m512 d4 = _mm512_sub_ps(m4, v4); + const __m512 d5 = _mm512_sub_ps(m5, v5); + const __m512 d6 = _mm512_sub_ps(m6, v6); + const __m512 d7 = _mm512_sub_ps(m7, v7); + + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + distances = _mm512_fmadd_ps(d2, d2, distances); + distances = _mm512_fmadd_ps(d3, d3, distances); + distances = _mm512_fmadd_ps(d4, d4, distances); + distances = _mm512_fmadd_ps(d5, d5, distances); + distances = _mm512_fmadd_ps(d6, d6, distances); + distances = _mm512_fmadd_ps(d7, d7, distances); + + __mmask16 comparison = + _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); + + min_distances = _mm512_min_ps(distances, min_distances); + min_indices = _mm512_mask_blend_epi32( + comparison, min_indices, current_indices); + + current_indices = + _mm512_add_epi32(current_indices, indices_increment); + + y += 128; + } + + alignas(64) float min_distances_scalar[16]; + alignas(64) uint32_t min_indices_scalar[16]; + _mm512_store_ps(min_distances_scalar, min_distances); + _mm512_store_epi32(min_indices_scalar, min_indices); + + for (size_t j = 0; j < 16; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX512ElementOpL2::op(x0, _mm256_loadu_ps(y)); + y += 8; + const float distance = horizontal_sum(accu); + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t 
ny) { + return fvec_L2sqr_ny_nearest_x86( + distances_tmp_buffer, + x, + y, + d, + ny, + &fvec_L2sqr_ny_nearest_D2, + &fvec_L2sqr_ny_nearest_D4, + &fvec_L2sqr_ny_nearest_D8); +} + +template <> +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + return fvec_L2sqr_ny_nearest_y_transposed( + distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); +} + +// TODO: Following functions are not used in the current codebase. Check AVX2 , +// respective implementation has been used +template +size_t fvec_L2sqr_ny_nearest_y_transposed_D( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + const size_t d_offset, + size_t ny) { + // This implementation does not use distances_tmp_buffer. + + // Current index being processed + size_t i = 0; + + // Min distance and the index of the closest vector so far + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + // Process 16 vectors per loop + const size_t ny16 = ny / 16; + + if (ny16 > 0) { + // Track min distance and the closest vector independently + // for each of 16 AVX-512 components. + __m512 min_distances = _mm512_set1_ps(HUGE_VALF); + __m512i min_indices = _mm512_set1_epi32(0); + + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_increment = _mm512_set1_epi32(16); + + // m[i] = (2 * x[i], ... 
2 * x[i]) + __m512 m[DIM]; + for (size_t j = 0; j < DIM; j++) { + m[j] = _mm512_set1_ps(x[j]); + m[j] = _mm512_add_ps(m[j], m[j]); + } + + for (; i < ny16 * 16; i += 16) { + // Compute dot products + const __m512 v0 = _mm512_loadu_ps(y + 0 * d_offset); + __m512 dp = _mm512_mul_ps(m[0], v0); + for (size_t j = 1; j < DIM; j++) { + const __m512 vj = _mm512_loadu_ps(y + j * d_offset); + dp = _mm512_fmadd_ps(m[j], vj, dp); + } + + // Compute y^2 - (2 * x, y), which is sufficient for looking for the + // lowest distance. + // x^2 is the constant that can be avoided. + const __m512 distances = + _mm512_sub_ps(_mm512_loadu_ps(y_sqlen), dp); + + // Compare the new distances to the min distances + __mmask16 comparison = + _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS); + + // Update min distances and indices with closest vectors if needed + min_distances = + _mm512_mask_blend_ps(comparison, distances, min_distances); + min_indices = _mm512_castps_si512(_mm512_mask_blend_ps( + comparison, + _mm512_castsi512_ps(current_indices), + _mm512_castsi512_ps(min_indices))); + + // Update current indices values. Basically, +16 to each of the 16 + // AVX-512 components. + current_indices = + _mm512_add_epi32(current_indices, indices_increment); + + // Scroll y and y_sqlen forward. 
+ y += 16; + y_sqlen += 16; + } + + // Dump values and find the minimum distance / minimum index + float min_distances_scalar[16]; + uint32_t min_indices_scalar[16]; + _mm512_storeu_ps(min_distances_scalar, min_distances); + _mm512_storeu_si512((__m512i*)(min_indices_scalar), min_indices); + + for (size_t j = 0; j < 16; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + // Process leftovers + for (; i < ny; i++) { + float dp = 0; + for (size_t j = 0; j < DIM; j++) { + dp += x[j] * y[j * d_offset]; + } + + // Compute y^2 - 2 * (x, y), which is sufficient for looking for the + // lowest distance. + const float distance = y_sqlen[0] - 2 * dp; + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + + y += 1; + y_sqlen += 1; + } + } + + return current_min_index; +} + +template <> +int fvec_madd_and_argmin( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + return fvec_madd_and_argmin_sse(n, a, bf, b, c); +} + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_sse-inl.h b/faiss/utils/simd_impl/distances_sse-inl.h new file mode 100644 index 0000000000..a5151750cb --- /dev/null +++ b/faiss/utils/simd_impl/distances_sse-inl.h @@ -0,0 +1,385 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +namespace faiss { + +[[maybe_unused]] static inline void fvec_madd_sse( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + n >>= 2; + __m128 bf4 = _mm_set_ps1(bf); + __m128* a4 = (__m128*)a; + __m128* b4 = (__m128*)b; + __m128* c4 = (__m128*)c; + + while (n--) { + *c4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4)); + b4++; + a4++; + c4++; + } +} + +/// helper function +inline float horizontal_sum(const __m128 v) { + // say, v is [x0, x1, x2, x3] + + // v0 is [x2, x3, ..., ...] + const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); + // v1 is [x0 + x2, x1 + x3, ..., ...] + const __m128 v1 = _mm_add_ps(v, v0); + // v2 is [x1 + x3, ..., .... ,...] + __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); + // v3 is [x0 + x1 + x2 + x3, ..., ..., ...] + const __m128 v3 = _mm_add_ps(v1, v2); + // return v3[0] + return _mm_cvtss_f32(v3); +} + +/// Function that does a component-wise operation between x and y +/// to compute inner products +struct ElementOpIP { + static float op(float x, float y) { + return x * y; + } + + static __m128 op(__m128 x, __m128 y) { + return _mm_mul_ps(x, y); + } +}; + +/// Function that does a component-wise operation between x and y +/// to compute L2 distances. 
ElementOp can then be used in the fvec_op_ny +/// functions below +struct ElementOpL2 { + static float op(float x, float y) { + float tmp = x - y; + return tmp * tmp; + } + + static __m128 op(__m128 x, __m128 y) { + __m128 tmp = _mm_sub_ps(x, y); + return _mm_mul_ps(tmp, tmp); + } +}; + +template +void fvec_op_ny_D1(float* dis, const float* x, const float* y, size_t ny) { + float x0s = x[0]; + __m128 x0 = _mm_set_ps(x0s, x0s, x0s, x0s); + + size_t i; + for (i = 0; i + 3 < ny; i += 4) { + __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = _mm_cvtss_f32(accu); + __m128 tmp = _mm_shuffle_ps(accu, accu, 1); + dis[i + 1] = _mm_cvtss_f32(tmp); + tmp = _mm_shuffle_ps(accu, accu, 2); + dis[i + 2] = _mm_cvtss_f32(tmp); + tmp = _mm_shuffle_ps(accu, accu, 3); + dis[i + 3] = _mm_cvtss_f32(tmp); + } + while (i < ny) { // handle non-multiple-of-4 case + dis[i++] = ElementOp::op(x0s, *y++); + } +} + +template +void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) { + __m128 x0 = _mm_set_ps(x[1], x[0], x[1], x[0]); + + size_t i; + for (i = 0; i + 1 < ny; i += 2) { + __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); + y += 4; + accu = _mm_hadd_ps(accu, accu); + dis[i] = _mm_cvtss_f32(accu); + accu = _mm_shuffle_ps(accu, accu, 3); + dis[i + 1] = _mm_cvtss_f32(accu); + } + if (i < ny) { // handle odd case + dis[i] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]); + } +} + +template +void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) { + __m128 x0 = _mm_loadu_ps(x); + + for (size_t i = 0; i < ny; i++) { + __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = horizontal_sum(accu); + } +} + +template +void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) { + __m128 x0 = _mm_loadu_ps(x); + __m128 x1 = _mm_loadu_ps(x + 4); + + for (size_t i = 0; i < ny; i++) { + __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); + y += 4; + accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y))); 
+ y += 4; + accu = _mm_hadd_ps(accu, accu); + accu = _mm_hadd_ps(accu, accu); + dis[i] = _mm_cvtss_f32(accu); + } +} + +template +void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) { + __m128 x0 = _mm_loadu_ps(x); + __m128 x1 = _mm_loadu_ps(x + 4); + __m128 x2 = _mm_loadu_ps(x + 8); + + for (size_t i = 0; i < ny; i++) { + __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); + y += 4; + accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y))); + y += 4; + accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y))); + y += 4; + dis[i] = horizontal_sum(accu); + } +} + +template +void fvec_inner_products_ny_ref( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { +#define DISPATCH(dval) \ + case dval: \ + fvec_op_ny_D##dval(dis, x, y, ny); \ + return; + + switch (d) { + DISPATCH(1) + DISPATCH(2) + DISPATCH(4) + DISPATCH(8) + DISPATCH(12) + default: + fvec_inner_products_ny(dis, x, y, d, ny); + return; + } +#undef DISPATCH +} + +template +void fvec_L2sqr_ny_ref( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + // optimized for a few special cases + +#define DISPATCH(dval) \ + case dval: \ + fvec_op_ny_D##dval(dis, x, y, ny); \ + return; + + switch (d) { + DISPATCH(1) + DISPATCH(2) + DISPATCH(4) + DISPATCH(8) + DISPATCH(12) + default: + fvec_L2sqr_ny(dis, x, y, d, ny); + return; + } +#undef DISPATCH +} + +template +size_t fvec_L2sqr_ny_nearest_D2( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny); + +template +size_t fvec_L2sqr_ny_nearest_D4( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny); + +template +size_t fvec_L2sqr_ny_nearest_D8( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny); + +template +size_t fvec_L2sqr_ny_nearest_x86( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny, + size_t (*fvec_L2sqr_ny_nearest_D2_func)( + float*, + const float*, + const 
float*, + size_t) = &fvec_L2sqr_ny_nearest_D2, + size_t (*fvec_L2sqr_ny_nearest_D4_func)( + float*, + const float*, + const float*, + size_t) = &fvec_L2sqr_ny_nearest_D4, + size_t (*fvec_L2sqr_ny_nearest_D8_func)( + float*, + const float*, + const float*, + size_t) = &fvec_L2sqr_ny_nearest_D8); + +template +size_t fvec_L2sqr_ny_nearest_x86( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny, + size_t (*fvec_L2sqr_ny_nearest_D2_func)( + float*, + const float*, + const float*, + size_t), + size_t (*fvec_L2sqr_ny_nearest_D4_func)( + float*, + const float*, + const float*, + size_t), + size_t (*fvec_L2sqr_ny_nearest_D8_func)( + float*, + const float*, + const float*, + size_t)) { + switch (d) { + case 2: + return fvec_L2sqr_ny_nearest_D2_func( + distances_tmp_buffer, x, y, ny); + case 4: + return fvec_L2sqr_ny_nearest_D4_func( + distances_tmp_buffer, x, y, ny); + case 8: + return fvec_L2sqr_ny_nearest_D8_func( + distances_tmp_buffer, x, y, ny); + } + + return fvec_L2sqr_ny_nearest( + distances_tmp_buffer, x, y, d, ny); +} + +template +inline size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny); + +static inline int fvec_madd_and_argmin_sse_ref( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + n >>= 2; + __m128 bf4 = _mm_set_ps1(bf); + __m128 vmin4 = _mm_set_ps1(1e20); + __m128i imin4 = _mm_set1_epi32(-1); + __m128i idx4 = _mm_set_epi32(3, 2, 1, 0); + __m128i inc4 = _mm_set1_epi32(4); + __m128* a4 = (__m128*)a; + __m128* b4 = (__m128*)b; + __m128* c4 = (__m128*)c; + + while (n--) { + __m128 vc4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4)); + *c4 = vc4; + __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); + // imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower! 
+ + imin4 = _mm_or_si128( + _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); + vmin4 = _mm_min_ps(vmin4, vc4); + b4++; + a4++; + c4++; + idx4 = _mm_add_epi32(idx4, inc4); + } + + // 4 values -> 2 + { + idx4 = _mm_shuffle_epi32(imin4, 3 << 2 | 2); + __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 3 << 2 | 2); + __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); + imin4 = _mm_or_si128( + _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); + vmin4 = _mm_min_ps(vmin4, vc4); + } + // 2 values -> 1 + { + idx4 = _mm_shuffle_epi32(imin4, 1); + __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 1); + __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); + imin4 = _mm_or_si128( + _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); + // vmin4 = _mm_min_ps (vmin4, vc4); + } + return _mm_cvtsi128_si32(imin4); +} + +static inline int fvec_madd_and_argmin_sse( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0) + return fvec_madd_and_argmin_sse_ref(n, a, bf, b, c); + + return fvec_madd_and_argmin(n, a, bf, b, c); +} + +// reads 0 <= d < 4 floats as __m128 +static inline __m128 masked_read(int d, const float* x) { + assert(0 <= d && d < 4); + ALIGNED(16) float buf[4] = {0, 0, 0, 0}; + switch (d) { + case 3: + buf[2] = x[2]; + [[fallthrough]]; + case 2: + buf[1] = x[1]; + [[fallthrough]]; + case 1: + buf[0] = x[0]; + } + return _mm_load_ps(buf); + // cannot use AVX2 _mm_mask_set1_epi32 +} + +} // namespace faiss diff --git a/faiss/utils/simd_levels.cpp b/faiss/utils/simd_levels.cpp index 887225ee3b..3f1769b289 100644 --- a/faiss/utils/simd_levels.cpp +++ b/faiss/utils/simd_levels.cpp @@ -125,10 +125,9 @@ SIMDLevel SIMDConfig::auto_detect_simd_level() { #if defined(__aarch64__) && defined(__ARM_NEON) && \ defined(COMPILE_SIMD_ARM_NEON) - // ARM NEON is standard on aarch64, so we can assume it's available + // ARM NEON is standard on aarch64 
supported_simd_levels().insert(SIMDLevel::ARM_NEON); level = SIMDLevel::ARM_NEON; - // TODO: Add ARM SVE detection when needed // For now, we default to ARM_NEON as it's universally supported on aarch64 #endif diff --git a/faiss/utils/simd_levels.h b/faiss/utils/simd_levels.h index ad3d0b289d..95b2decc0b 100644 --- a/faiss/utils/simd_levels.h +++ b/faiss/utils/simd_levels.h @@ -61,7 +61,7 @@ struct SIMDConfig { #ifdef COMPILE_SIMD_AVX512 #define DISPATCH_SIMDLevel_AVX512(f, ...) \ - case SIMDLevel::AVX512F: \ + case SIMDLevel::AVX512: \ return f(__VA_ARGS__) #else #define DISPATCH_SIMDLevel_AVX512(f, ...) diff --git a/tests/test_distances_simd.cpp b/tests/test_distances_simd.cpp index d9c8578fd7..dda33c3e72 100644 --- a/tests/test_distances_simd.cpp +++ b/tests/test_distances_simd.cpp @@ -39,104 +39,352 @@ void fvec_L2sqr_ny_ref( } } -// test templated versions of fvec_L2sqr_ny -TEST(TestFvecL2sqrNy, D2) { - // we're using int values in order to get 100% accurate - // results with floats. 
- std::default_random_engine rng(123); - std::uniform_int_distribution u(0, 32); +void remove_simd_level_if_exists( + std::unordered_set& levels, + faiss::SIMDLevel level) { + std::erase_if( + levels, [level](faiss::SIMDLevel elem) { return elem == level; }); +} - for (const auto dim : {2, 4, 8, 12}) { - std::vector x(dim, 0); - for (size_t i = 0; i < x.size(); i++) { - x[i] = u(rng); +class DistancesSIMDTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + original_simd_level = faiss::SIMDConfig::get_level(); + std::iota(dims.begin(), dims.end(), 1); + + ntests = 4; + + simd_level = GetParam(); + faiss::SIMDConfig::set_level(simd_level); + + EXPECT_EQ(faiss::SIMDConfig::get_level(), simd_level); + + rng = std::default_random_engine(123); + uniform = std::uniform_int_distribution(0, 32); + } + + void TearDown() override { + faiss::SIMDConfig::set_level(original_simd_level); + } + + std::tuple, std::vector>> + SetupTestData(int dims, int ny) { + std::vector x(dims); + std::vector> y(ny, std::vector(dims)); + + for (size_t i = 0; i < dims; i++) { + x[i] = uniform(rng); + for (size_t j = 0; j < ny; j++) { + y[j][i] = uniform(rng); + } + } + return std::make_tuple(x, y); + } + + std::vector flatten_2d_vector( + const std::vector>& v) { + std::vector flat_v; + for (const auto& vec : v) { + flat_v.insert(flat_v.end(), vec.begin(), vec.end()); + } + return flat_v; + } + + faiss::SIMDLevel simd_level = faiss::SIMDLevel::NONE; + faiss::SIMDLevel original_simd_level = faiss::SIMDLevel::NONE; + std::default_random_engine rng; + std::uniform_int_distribution uniform; + + std::vector dims = {128}; + int ntests = 1; +}; + +TEST_P(DistancesSIMDTest, LinfDistance_chebyshev_distance) { + for (int i = 0; i < ntests; ++i) { // repeat tests + for (const auto dim : dims) { // test different dimensions + int ny = 1; + auto [x, y] = SetupTestData(dim, ny); + for (int k = 0; k < ny; ++k) { // test different vectors + float distance = faiss::fvec_Linf(x.data(), 
y[k].data(), dim); + float ref_distance = 0; + + for (int j = 0; j < dim; ++j) { + ref_distance = + std::max(ref_distance, std::abs(x[j] - y[k][j])); + } + ASSERT_EQ(distance, ref_distance); + } } + } +} - for (const auto nrows : {1, 2, 5, 10, 15, 20, 25}) { - std::vector y(nrows * dim); - for (size_t i = 0; i < y.size(); i++) { - y[i] = u(rng); +TEST_P(DistancesSIMDTest, inner_product_batch_4) { + for (int i = 0; i < ntests; ++i) { + int dim = 128; + int ny = 4; + auto [x, y] = SetupTestData(dim, ny); + + std::vector true_distances(ny, 0.F); + for (int j = 0; j < ny; ++j) { + for (int k = 0; k < dim; ++k) { + true_distances[j] += x[k] * y[j][k]; } + } - std::vector distances(nrows, 0); - faiss::fvec_L2sqr_ny( - distances.data(), x.data(), y.data(), dim, nrows); + std::vector actual_distances(ny, 0.F); + faiss::fvec_inner_product_batch_4( + x.data(), + y[0].data(), + y[1].data(), + y[2].data(), + y[3].data(), + dim, + actual_distances[0], + actual_distances[1], + actual_distances[2], + actual_distances[3]); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching fvec_inner_product_batch4 results for test = " + << i; + } +} + +TEST_P(DistancesSIMDTest, fvec_L2sqr) { + for (int i = 0; i < ntests; ++i) { + int ny = 1; + for (const auto dim : dims) { + auto [x, y] = SetupTestData(dim, ny); + float true_distance = 0.F; + for (int k = 0; k < dim; ++k) { + const float tmp = x[k] - y[0][k]; + true_distance += tmp * tmp; + } - std::vector distances_ref(nrows, 0); - fvec_L2sqr_ny_ref( - distances_ref.data(), x.data(), y.data(), dim, nrows); + float actual_distance = + faiss::fvec_L2sqr(x.data(), y[0].data(), dim); - ASSERT_EQ(distances, distances_ref) - << "Mismatching results for dim = " << dim - << ", nrows = " << nrows; + ASSERT_EQ(actual_distance, true_distance) + << "Mismatching fvec_L2sqr results for test = " << i; } } } -// fvec_inner_products_ny -TEST(TestFvecInnerProductsNy, D2) { - // we're using int values in order to get 100% accurate - // results with 
floats. - std::default_random_engine rng(123); - std::uniform_int_distribution u(0, 32); +TEST_P(DistancesSIMDTest, L2sqr_batch_4) { + for (int i = 0; i < ntests; ++i) { + int dim = 128; + int ny = 4; + auto [x, y] = SetupTestData(dim, ny); + + std::vector true_distances(ny, 0.F); + for (int j = 0; j < ny; ++j) { + for (int k = 0; k < dim; ++k) { + const float tmp = x[k] - y[j][k]; + true_distances[j] += tmp * tmp; + } + } + + std::vector actual_distances(ny, 0.F); + faiss::fvec_L2sqr_batch_4( + x.data(), + y[0].data(), + y[1].data(), + y[2].data(), + y[3].data(), + dim, + actual_distances[0], + actual_distances[1], + actual_distances[2], + actual_distances[3]); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching fvec_L2sqr_batch_4 results for test = " << i; + } +} +TEST_P(DistancesSIMDTest, fvec_L2sqr_ny) { for (const auto dim : {2, 4, 8, 12}) { - std::vector x(dim, 0); - for (size_t i = 0; i < x.size(); i++) { - x[i] = u(rng); - } + for (const auto ny : {1, 2, 5, 10, 15, 20, 25}) { + auto [x, y] = SetupTestData(dim, ny); + + std::vector actual_distances(ny, 0.F); - for (const auto nrows : {1, 2, 5, 10, 15, 20, 25}) { - std::vector y(nrows * dim); - for (size_t i = 0; i < y.size(); i++) { - y[i] = u(rng); + std::vector flat_y; + for (auto y_ : y) { + flat_y.insert(flat_y.end(), y_.begin(), y_.end()); } - std::vector distances(nrows, 0); + std::vector true_distances(ny, 0.F); + for (int i = 0; i < ny; ++i) { + for (int k = 0; k < dim; ++k) { + const float tmp = x[k] - y[i][k]; + true_distances[i] += tmp * tmp; + } + } + + faiss::fvec_L2sqr_ny( + actual_distances.data(), x.data(), flat_y.data(), dim, ny); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching fvec_L2sqr_ny results for dim = " << dim + << ", ny = " << ny; + } + } +} + +TEST_P(DistancesSIMDTest, fvec_inner_products_ny) { + for (const auto dim : {2, 4, 8, 12}) { + for (const auto ny : {1, 2, 5, 10, 15, 20, 25}) { + auto [x, y] = SetupTestData(dim, ny); + auto flat_y = 
flatten_2d_vector(y); + + std::vector actual_distances(ny, 0.F); faiss::fvec_inner_products_ny( - distances.data(), x.data(), y.data(), dim, nrows); + actual_distances.data(), x.data(), flat_y.data(), dim, ny); - std::vector distances_ref(nrows, 0); - fvec_inner_products_ny_ref( - distances_ref.data(), x.data(), y.data(), dim, nrows); + std::vector true_distances(ny, 0.F); + for (int i = 0; i < ny; ++i) { + for (int k = 0; k < dim; ++k) { + true_distances[i] += x[k] * y[i][k]; + } + } - ASSERT_EQ(distances, distances_ref) - << "Mismatching results for dim = " << dim - << ", nrows = " << nrows; + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching fvec_inner_products_ny results for dim = " + << dim << ", ny = " << ny; } } } -TEST(TestFvecL2sqr, distances_L2_squared_y_transposed) { - // ints instead of floats for 100% accuracy +TEST_P(DistancesSIMDTest, L2SqrNYNearest) { std::default_random_engine rng(123); std::uniform_int_distribution uniform(0, 32); + int dim = 128; + int ny = 11; + + auto [x, y] = SetupTestData(dim, ny); + auto flat_y = flatten_2d_vector(y); + + std::vector true_tmp_buffer_distances(ny, 0.F); + for (int i = 0; i < ny; ++i) { + for (int k = 0; k < dim; ++k) { + const float tmp = x[k] - y[i][k]; + true_tmp_buffer_distances[i] += tmp * tmp; + } + } + + size_t true_nearest_idx = 0; + float min_dis = HUGE_VALF; + + for (size_t i = 0; i < ny; i++) { + if (true_tmp_buffer_distances[i] < min_dis) { + min_dis = true_tmp_buffer_distances[i]; + true_nearest_idx = i; + } + } + + std::vector actual_distances(ny); + auto actual_nearest_index = faiss::fvec_L2sqr_ny_nearest( + actual_distances.data(), x.data(), flat_y.data(), dim, ny); + + EXPECT_EQ(actual_nearest_index, true_nearest_idx); +} + +TEST_P(DistancesSIMDTest, multiple_add) { + // modulo 8 results - 16 is to repeat the while loop in the function + for (const auto dim : {8, 9, 10, 11, 12, 13, 14, 15, 16}) { + auto [x, y] = SetupTestData(dim, 1); + const float bf = uniform(rng); + std::vector 
true_distances(dim); + for (size_t i = 0; i < x.size(); i++) { + true_distances[i] = x[i] + bf * y[0][i]; + } + + std::vector actual_distances(dim); + faiss::fvec_madd( + x.size(), x.data(), bf, y[0].data(), actual_distances.data()); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching fvec_madd results for nrows = " << dim; + } +} + +TEST_P(DistancesSIMDTest, manhattan_distance) { + // modulo 8 results - 16 is to repeat the while loop in the function + for (const auto dim : {8, 9, 10, 11, 12, 13, 14, 15, 16}) { + auto [x, y] = SetupTestData(dim, 1); + float true_distance = 0; + for (size_t i = 0; i < x.size(); i++) { + true_distance += std::abs(x[i] - y[0][i]); + } + + auto actual_distances = faiss::fvec_L1(x.data(), y[0].data(), x.size()); + + ASSERT_EQ(actual_distances, true_distance) + << "Mismatching fvec_Linf results for nrows = " << dim; + } +} + +TEST_P(DistancesSIMDTest, add_value) { + for (const auto dim : {1, 2, 5, 10, 15, 20, 25}) { + auto [x, y] = SetupTestData(dim, 1); + const float b = uniform(rng); // value to add + std::vector true_distances(dim); + for (size_t i = 0; i < x.size(); i++) { + true_distances[i] = x[i] + b; + } + + std::vector actual_distances(dim); + faiss::fvec_add(x.size(), x.data(), b, actual_distances.data()); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching array-value fvec_add results for nrows = " + << dim; + } +} + +TEST_P(DistancesSIMDTest, add_array) { + for (const auto dim : {1, 2, 5, 10, 15, 20, 25}) { + auto [x, y] = SetupTestData(dim, 1); + std::vector true_distances(dim); + for (size_t i = 0; i < x.size(); i++) { + true_distances[i] = x[i] + y[0][i]; + } + + std::vector actual_distances(dim); + faiss::fvec_add( + x.size(), x.data(), y[0].data(), actual_distances.data()); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching array-array fvec_add results for nrows = " + << dim; + } +} + +TEST_P(DistancesSIMDTest, distances_L2_squared_y_transposed) { // modulo 8 results - 16 is to 
repeat the loop in the function int ny = 11; // this value will hit all the codepaths for (const auto d : {1, 2, 3, 4, 5, 6, 7, 8, 16}) { - // initialize inputs - std::vector x(d); + auto [x, y] = SetupTestData(d, ny); float x_sqlen = 0; - for (size_t i = 0; i < x.size(); i++) { - x[i] = uniform(rng); + for (size_t i = 0; i < d; ++i) { x_sqlen += x[i] * x[i]; } - std::vector y(d * ny); + auto flat_y = flatten_2d_vector(y); std::vector y_sqlens(ny, 0); - for (size_t i = 0; i < ny; i++) { - for (size_t j = 0; j < y.size(); j++) { - y[j] = uniform(rng); - y_sqlens[i] += y[j] * y[j]; + for (size_t i = 0; i < ny; ++i) { + for (size_t j = 0; j < d; ++j) { + y_sqlens[i] += flat_y[j] * flat_y[j]; } } // perform function std::vector true_distances(ny, 0); - for (size_t i = 0; i < ny; i++) { + for (size_t i = 0; i < ny; ++i) { float dp = 0; - for (size_t j = 0; j < d; j++) { - dp += x[j] * y[i + j * ny]; + for (size_t j = 0; j < d; ++j) { + dp += x[j] * flat_y[i + j * ny]; } true_distances[i] = x_sqlen + y_sqlens[i] - 2 * dp; } @@ -145,7 +393,7 @@ TEST(TestFvecL2sqr, distances_L2_squared_y_transposed) { faiss::fvec_L2sqr_ny_transposed( distances.data(), x.data(), - y.data(), + flat_y.data(), y_sqlens.data(), d, ny, // no need for special offset to test all lines of code @@ -156,39 +404,34 @@ TEST(TestFvecL2sqr, distances_L2_squared_y_transposed) { } } -TEST(TestFvecL2sqr, nearest_L2_squared_y_transposed) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); - +TEST_P(DistancesSIMDTest, nearest_L2_squared_y_transposed) { // modulo 8 results - 16 is to repeat the loop in the function int ny = 11; // this value will hit all the codepaths - for (const auto d : {1, 2, 3, 4, 5, 6, 7, 8, 16}) { - // initialize inputs - std::vector x(d); - float x_sqlen = 0; - for (size_t i = 0; i < x.size(); i++) { - x[i] = uniform(rng); + for (const auto dim : {1, 2, 3, 4, 5, 6, 7, 8, 16}) { + auto [x, y] = 
SetupTestData(dim, ny); + float x_sqlen = 0.F; + for (size_t i = 0; i < dim; i++) { x_sqlen += x[i] * x[i]; } - std::vector y(d * ny); + + auto flat_y = flatten_2d_vector(y); std::vector y_sqlens(ny, 0); + for (size_t i = 0; i < ny; i++) { - for (size_t j = 0; j < y.size(); j++) { - y[j] = uniform(rng); - y_sqlens[i] += y[j] * y[j]; + for (size_t j = 0; j < dim; j++) { + y_sqlens[i] += y[i][j] * y[i][j]; } } - // get distances std::vector distances(ny, 0); for (size_t i = 0; i < ny; i++) { float dp = 0; - for (size_t j = 0; j < d; j++) { - dp += x[j] * y[i + j * ny]; + for (size_t j = 0; j < dim; j++) { + dp += x[j] * flat_y[i + j * ny]; } distances[i] = x_sqlen + y_sqlens[i] - 2 * dp; } + // find nearest size_t true_nearest_idx = 0; float min_dis = HUGE_VALF; @@ -200,135 +443,42 @@ TEST(TestFvecL2sqr, nearest_L2_squared_y_transposed) { } std::vector buffer(ny); - size_t nearest_idx = faiss::fvec_L2sqr_ny_nearest_y_transposed( + size_t actual_nearest_idx = faiss::fvec_L2sqr_ny_nearest_y_transposed( buffer.data(), x.data(), - y.data(), + flat_y.data(), y_sqlens.data(), - d, + dim, ny, // no need for special offset to test all lines of code ny); - ASSERT_EQ(nearest_idx, true_nearest_idx) + ASSERT_EQ(actual_nearest_idx, true_nearest_idx) << "Mismatching fvec_L2sqr_ny_nearest_y_transposed results for d = " - << d; + << dim; } } -TEST(TestFvecL1, manhattan_distance) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); +std::vector GetSupportedSIMDLevels() { + std::vector supported_levels = {faiss::SIMDLevel::NONE}; - // modulo 8 results - 16 is to repeat the while loop in the function - for (const auto nrows : {8, 9, 10, 11, 12, 13, 14, 15, 16}) { - std::vector x(nrows); - std::vector y(nrows); - float true_distance = 0; - for (size_t i = 0; i < x.size(); i++) { - x[i] = uniform(rng); - y[i] = uniform(rng); - true_distance += std::abs(x[i] - y[i]); + for (int level = 
static_cast(faiss::SIMDLevel::NONE) + 1; + level < static_cast(faiss::SIMDLevel::COUNT); + level++) { + faiss::SIMDLevel simd_level = static_cast(level); + if (faiss::SIMDConfig::is_simd_level_available(simd_level)) { + supported_levels.push_back(simd_level); } - - auto distance = faiss::fvec_L1(x.data(), y.data(), x.size()); - - ASSERT_EQ(distance, true_distance) - << "Mismatching fvec_L1 results for nrows = " << nrows; } -} -TEST(TestFvecLinf, chebyshev_distance) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); + EXPECT_TRUE(supported_levels.size() > 0); - // modulo 8 results - 16 is to repeat the while loop in the function - for (const auto nrows : {8, 9, 10, 11, 12, 13, 14, 15, 16}) { - std::vector x(nrows); - std::vector y(nrows); - float true_distance = 0; - for (size_t i = 0; i < x.size(); i++) { - x[i] = uniform(rng); - y[i] = uniform(rng); - true_distance = std::max(true_distance, std::abs(x[i] - y[i])); - } - - auto distance = faiss::fvec_Linf(x.data(), y.data(), x.size()); - - ASSERT_EQ(distance, true_distance) - << "Mismatching fvec_Linf results for nrows = " << nrows; - } + return std::vector( + supported_levels.begin(), supported_levels.end()); } -TEST(TestFvecMadd, multiple_add) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); - - // modulo 8 results - 16 is to repeat the while loop in the function - for (const auto nrows : {8, 9, 10, 11, 12, 13, 14, 15, 16}) { - std::vector a(nrows); - std::vector b(nrows); - const float bf = uniform(rng); - std::vector true_distances(nrows); - for (size_t i = 0; i < a.size(); i++) { - a[i] = uniform(rng); - b[i] = uniform(rng); - true_distances[i] = a[i] + bf * b[i]; - } - - std::vector distances(nrows); - faiss::fvec_madd(a.size(), a.data(), bf, b.data(), distances.data()); - - ASSERT_EQ(distances, true_distances) - << "Mismatching fvec_madd 
results for nrows = " << nrows; - } +::testing::internal::ParamGenerator SupportedSIMDLevels() { + std::vector levels = GetSupportedSIMDLevels(); + return ::testing::ValuesIn(levels); } -TEST(TestFvecAdd, add_array) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); - - for (const auto nrows : {1, 2, 5, 10, 15, 20, 25}) { - std::vector a(nrows); - std::vector b(nrows); - std::vector true_distances(nrows); - for (size_t i = 0; i < a.size(); i++) { - a[i] = uniform(rng); - b[i] = uniform(rng); - true_distances[i] = a[i] + b[i]; - } - - std::vector distances(nrows); - faiss::fvec_add(a.size(), a.data(), b.data(), distances.data()); - - ASSERT_EQ(distances, true_distances) - << "Mismatching array-array fvec_add results for nrows = " - << nrows; - } -} - -TEST(TestFvecAdd, add_value) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); - - for (const auto nrows : {1, 2, 5, 10, 15, 20, 25}) { - std::vector a(nrows); - const float b = uniform(rng); // value to add - std::vector true_distances(nrows); - for (size_t i = 0; i < a.size(); i++) { - a[i] = uniform(rng); - true_distances[i] = a[i] + b; - } - - std::vector distances(nrows); - faiss::fvec_add(a.size(), a.data(), b, distances.data()); - - ASSERT_EQ(distances, true_distances) - << "Mismatching array-value fvec_add results for nrows = " - << nrows; - } -} +INSTANTIATE_TEST_SUITE_P(SIMDLevels, DistancesSIMDTest, SupportedSIMDLevels()); diff --git a/tests/test_simd_levels.cpp b/tests/test_simd_levels.cpp index 4dac2e9877..64da6e77b9 100644 --- a/tests/test_simd_levels.cpp +++ b/tests/test_simd_levels.cpp @@ -6,8 +6,6 @@ */ #include -#include -#include #ifdef __x86_64__ #include @@ -15,25 +13,9 @@ #include -static jmp_buf jmpbuf; -static void sigill_handler(int sig) { - longjmp(jmpbuf, 1); -} - -bool try_execute(void (*func)()) { - signal(SIGILL, 
sigill_handler); - if (setjmp(jmpbuf) == 0) { - func(); - signal(SIGILL, SIG_DFL); - return true; - } else { - signal(SIGILL, SIG_DFL); - return false; - } -} - #ifdef __x86_64__ -std::vector run_avx2_computation() { +bool run_avx2_computation() { +#if defined(__AVX2__) alignas(32) int result[8]; alignas(32) int input1[8] = {1, 2, 3, 4, 5, 6, 7, 8}; alignas(32) int input2[8] = {8, 7, 6, 5, 4, 3, 2, 1}; @@ -43,10 +25,14 @@ std::vector run_avx2_computation() { __m256i vec_result = _mm256_add_epi32(vec1, vec2); _mm256_store_si256(reinterpret_cast<__m256i*>(result), vec_result); - return {result, result + 8}; + return true; +#else + return false; +#endif // __AVX2__ } -std::vector run_avx512f_computation() { +bool run_avx512f_computation() { +#ifdef __AVX512F__ alignas(64) long long result[8]; alignas(64) long long input1[8] = {1, 2, 3, 4, 5, 6, 7, 8}; alignas(64) long long input2[8] = {8, 7, 6, 5, 4, 3, 2, 1}; @@ -56,11 +42,15 @@ std::vector run_avx512f_computation() { __m512i vec_result = _mm512_add_epi64(vec1, vec2); _mm512_store_si512(reinterpret_cast<__m512i*>(result), vec_result); - return {result, result + 8}; + return true; +#else + return false; +#endif // __AVX512F__ } -std::vector run_avx512cd_computation() { - run_avx512f_computation(); +bool run_avx512cd_computation() { + EXPECT_TRUE(run_avx512f_computation()); +#ifdef __AVX512CD__ __m512i indices = _mm512_set_epi32( 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); @@ -68,38 +58,47 @@ std::vector run_avx512cd_computation() { alignas(64) int mask_array[16]; _mm512_store_epi32(mask_array, conflict_mask); - - return std::vector(); + return true; +#else + return false; +#endif // __AVX512CD__ } -std::vector run_avx512vl_computation() { - run_avx512f_computation(); +bool run_avx512vl_computation() { + EXPECT_TRUE(run_avx512f_computation()); +#ifdef __AVX512VL__ __m256i vec1 = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); __m256i vec2 = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7); __m256i result = 
_mm256_add_epi32(vec1, vec2); alignas(32) int result_array[8]; _mm256_store_si256(reinterpret_cast<__m256i*>(result_array), result); - - return std::vector(result_array, result_array + 8); + return true; +#else + return false; +#endif // __AVX512VL__ } -std::vector run_avx512dq_computation() { - run_avx512f_computation(); +bool run_avx512dq_computation() { + EXPECT_TRUE(run_avx512f_computation()); +#ifdef __AVX512DQ__ __m512i vec1 = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0); __m512i vec2 = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); __m512i result = _mm512_add_epi64(vec1, vec2); alignas(64) long long result_array[8]; _mm512_store_si512(result_array, result); - - return std::vector(result_array, result_array + 8); + return true; +#else + return false; +#endif // __AVX512DQ__ } -std::vector run_avx512bw_computation() { - run_avx512f_computation(); +bool run_avx512bw_computation() { + EXPECT_TRUE(run_avx512f_computation()); +#ifdef __AVX512BW__ std::vector input1(64, 0); __m512i vec1 = _mm512_loadu_si512(reinterpret_cast(input1.data())); @@ -111,22 +110,13 @@ std::vector run_avx512bw_computation() { alignas(64) int8_t result_array[64]; _mm512_storeu_si512(reinterpret_cast<__m512i*>(result_array), result); - return std::vector(result_array, result_array + 64); + return true; +#else + return false; +#endif // __AVX512BW__ } #endif // __x86_64__ -std::pair> try_execute(std::vector (*func)()) { - signal(SIGILL, sigill_handler); - if (setjmp(jmpbuf) == 0) { - auto result = func(); - signal(SIGILL, SIG_DFL); - return std::make_pair(true, result); - } else { - signal(SIGILL, SIG_DFL); - return std::make_pair(false, std::vector()); - } -} - TEST(SIMDConfig, simd_level_auto_detect_architecture_only) { faiss::SIMDLevel detected_level = faiss::SIMDConfig::auto_detect_simd_level(); @@ -140,10 +130,12 @@ TEST(SIMDConfig, simd_level_auto_detect_architecture_only) { detected_level == faiss::SIMDLevel::AVX2 || detected_level == faiss::SIMDLevel::AVX512); #elif defined(__aarch64__) && 
defined(__ARM_NEON) - EXPECT_TRUE(detected_level == faiss::SIMDLevel::ARM_NEON); + // Uncomment following line when dynamic dispatch is enabled for ARM_NEON + // EXPECT_TRUE(detected_level == faiss::SIMDLevel::ARM_NEON); #else EXPECT_EQ(detected_level, faiss::SIMDLevel::NONE); #endif + EXPECT_TRUE(detected_level != faiss::SIMDLevel::COUNT); } #ifdef __x86_64__ @@ -151,10 +143,8 @@ TEST(SIMDConfig, successful_avx2_execution_on_x86arch) { faiss::SIMDConfig simd_config(nullptr); if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX2)) { - auto actual_result = try_execute(run_avx2_computation); - EXPECT_TRUE(actual_result.first); - auto expected_result_vector = std::vector(8, 9); - EXPECT_EQ(actual_result.second, expected_result_vector); + auto actual_result = run_avx2_computation(); + EXPECT_TRUE(actual_result); } } @@ -171,10 +161,8 @@ TEST(SIMDConfig, successful_avx512f_execution_on_x86arch) { faiss::SIMDConfig simd_config(nullptr); if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { - auto actual_result = try_execute(run_avx512f_computation); - EXPECT_TRUE(actual_result.first); - auto expected_result_vector = std::vector(8, 9); - EXPECT_EQ(actual_result.second, expected_result_vector); + auto actual_result = run_avx512f_computation(); + EXPECT_TRUE(actual_result); } } @@ -182,8 +170,8 @@ TEST(SIMDConfig, successful_avx512cd_execution_on_x86arch) { faiss::SIMDConfig simd_config(nullptr); if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { - auto actual = try_execute(run_avx512cd_computation); - EXPECT_TRUE(actual.first); + auto actual = run_avx512cd_computation(); + EXPECT_TRUE(actual); } } @@ -191,9 +179,8 @@ TEST(SIMDConfig, successful_avx512vl_execution_on_x86arch) { faiss::SIMDConfig simd_config(nullptr); if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { - auto actual = try_execute(run_avx512vl_computation); - EXPECT_TRUE(actual.first); - EXPECT_EQ(actual.second, std::vector(8, 7)); + auto actual = 
run_avx512vl_computation(); + EXPECT_TRUE(actual); } } @@ -203,9 +190,8 @@ TEST(SIMDConfig, successful_avx512dq_execution_on_x86arch) { if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { EXPECT_TRUE( simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)); - auto actual = try_execute(run_avx512dq_computation); - EXPECT_TRUE(actual.first); - EXPECT_EQ(actual.second, std::vector(8, 7)); + auto actual = run_avx512dq_computation(); + EXPECT_TRUE(actual); } } @@ -215,21 +201,22 @@ TEST(SIMDConfig, successful_avx512bw_execution_on_x86arch) { if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { EXPECT_TRUE( simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)); - auto actual = try_execute(run_avx512bw_computation); - EXPECT_TRUE(actual.first); - EXPECT_EQ(actual.second, std::vector(64, 7)); + auto actual = run_avx512bw_computation(); + EXPECT_TRUE(actual); + // EXPECT_TRUE(actual.first); + // EXPECT_EQ(actual.second, std::vector(64, 7)); } } #endif // __x86_64__ TEST(SIMDConfig, override_simd_level) { - const char* faiss_env_var_neon = "ARM_NEON"; - faiss::SIMDConfig simd_neon_config(&faiss_env_var_neon); - EXPECT_EQ(simd_neon_config.level, faiss::SIMDLevel::ARM_NEON); + // const char* faiss_env_var_neon = "ARM_NEON"; + // faiss::SIMDConfig simd_neon_config(&faiss_env_var_neon); + // EXPECT_EQ(simd_neon_config.level, faiss::SIMDLevel::ARM_NEON); - EXPECT_EQ(simd_neon_config.supported_simd_levels().size(), 2); - EXPECT_TRUE(simd_neon_config.is_simd_level_available( - faiss::SIMDLevel::ARM_NEON)); + // EXPECT_EQ(simd_neon_config.supported_simd_levels().size(), 2); + // EXPECT_TRUE(simd_neon_config.is_simd_level_available( + // faiss::SIMDLevel::ARM_NEON)); const char* faiss_env_var_avx512 = "AVX512"; faiss::SIMDConfig simd_avx512_config(&faiss_env_var_avx512); @@ -240,12 +227,12 @@ TEST(SIMDConfig, override_simd_level) { } TEST(SIMDConfig, simd_config_get_level_name) { - const char* faiss_env_var_neon = "ARM_NEON"; - 
faiss::SIMDConfig simd_neon_config(&faiss_env_var_neon); - EXPECT_EQ(simd_neon_config.level, faiss::SIMDLevel::ARM_NEON); - EXPECT_TRUE(simd_neon_config.is_simd_level_available( - faiss::SIMDLevel::ARM_NEON)); - EXPECT_EQ(faiss_env_var_neon, simd_neon_config.get_level_name()); + // const char* faiss_env_var_neon = "ARM_NEON"; + // faiss::SIMDConfig simd_neon_config(&faiss_env_var_neon); + // EXPECT_EQ(simd_neon_config.level, faiss::SIMDLevel::ARM_NEON); + // EXPECT_TRUE(simd_neon_config.is_simd_level_available( + // faiss::SIMDLevel::ARM_NEON)); + // EXPECT_EQ(faiss_env_var_neon, simd_neon_config.get_level_name()); const char* faiss_env_var_avx512 = "AVX512"; faiss::SIMDConfig simd_avx512_config(&faiss_env_var_avx512); @@ -259,7 +246,8 @@ TEST(SIMDLevel, get_level_name_from_enum) { EXPECT_EQ("NONE", to_string(faiss::SIMDLevel::NONE).value_or("")); EXPECT_EQ("AVX2", to_string(faiss::SIMDLevel::AVX2).value_or("")); EXPECT_EQ("AVX512", to_string(faiss::SIMDLevel::AVX512).value_or("")); - EXPECT_EQ("ARM_NEON", to_string(faiss::SIMDLevel::ARM_NEON).value_or("")); + // EXPECT_EQ("ARM_NEON", + // to_string(faiss::SIMDLevel::ARM_NEON).value_or("")); int actual_num_simd_levels = static_cast(faiss::SIMDLevel::COUNT); EXPECT_EQ(4, actual_num_simd_levels); @@ -275,6 +263,6 @@ TEST(SIMDLevel, to_simd_level_from_string) { EXPECT_EQ(faiss::SIMDLevel::NONE, faiss::to_simd_level("NONE")); EXPECT_EQ(faiss::SIMDLevel::AVX2, faiss::to_simd_level("AVX2")); EXPECT_EQ(faiss::SIMDLevel::AVX512, faiss::to_simd_level("AVX512")); - EXPECT_EQ(faiss::SIMDLevel::ARM_NEON, faiss::to_simd_level("ARM_NEON")); + // EXPECT_EQ(faiss::SIMDLevel::ARM_NEON, faiss::to_simd_level("ARM_NEON")); EXPECT_FALSE(faiss::to_simd_level("INVALID").has_value()); } diff --git a/tests/test_simd_perf.cpp b/tests/test_simd_perf.cpp new file mode 100644 index 0000000000..9b126e601d --- /dev/null +++ b/tests/test_simd_perf.cpp @@ -0,0 +1,184 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * Performance validation test for SIMD Dynamic Dispatch. + * + * This test verifies that the DD infrastructure is correctly dispatching + * to optimized SIMD implementations by checking that: + * 1. AVX2 is faster than NONE (scalar) implementation + * 2. The difference is significant (at least 1.5x for typical dimensions) + */ + +#include + +#include +#include +#include +#include + +#include +#include + +class SIMDPerfTest : public ::testing::Test { + protected: + void SetUp() override { + original_level = faiss::SIMDConfig::get_level(); + + // Generate random test data + std::mt19937 rng(42); + std::uniform_real_distribution dist(0.0f, 1.0f); + + x.resize(d); + y.resize(n * d); + c.resize(n * d); // output buffer for fvec_madd + + for (size_t i = 0; i < d; i++) { + x[i] = dist(rng); + } + for (size_t i = 0; i < n * d; i++) { + y[i] = dist(rng); + } + } + + void TearDown() override { + faiss::SIMDConfig::set_level(original_level); + } + + // fvec_L2sqr uses auto-vectorization (same source, different compiler flags) + double benchmark_fvec_L2sqr(faiss::SIMDLevel level, int n_runs = 100) { + faiss::SIMDConfig::set_level(level); + + // Warmup + for (int i = 0; i < 10; i++) { + for (size_t j = 0; j < n; j++) { + volatile float result = faiss::fvec_L2sqr(x.data(), y.data() + j * this->d, this->d); + (void)result; + } + } + + auto start = std::chrono::high_resolution_clock::now(); + for (int run = 0; run < n_runs; run++) { + for (size_t j = 0; j < n; j++) { + volatile float result = faiss::fvec_L2sqr(x.data(), y.data() + j * this->d, this->d); + (void)result; + } + } + auto end = std::chrono::high_resolution_clock::now(); + + std::chrono::duration elapsed = end - start; + return elapsed.count(); + } + + // fvec_madd has explicit AVX2 intrinsics - better for testing dispatch + double benchmark_fvec_madd(faiss::SIMDLevel level, 
int n_runs = 100) { + faiss::SIMDConfig::set_level(level); + + // Warmup + for (int i = 0; i < 10; i++) { + for (size_t j = 0; j < n; j++) { + faiss::fvec_madd(this->d, x.data(), 0.5f, y.data() + j * this->d, c.data() + j * this->d); + } + } + + auto start = std::chrono::high_resolution_clock::now(); + for (int run = 0; run < n_runs; run++) { + for (size_t j = 0; j < n; j++) { + faiss::fvec_madd(this->d, x.data(), 0.5f, y.data() + j * this->d, c.data() + j * this->d); + } + } + auto end = std::chrono::high_resolution_clock::now(); + + std::chrono::duration elapsed = end - start; + return elapsed.count(); + } + + faiss::SIMDLevel original_level; + size_t d = 128; // dimension + size_t n = 10000; // number of vectors + std::vector x; + std::vector y; + std::vector c; // output buffer +}; + +TEST_F(SIMDPerfTest, AVX2FasterThanNONE) { + // Skip if AVX2 is not available + if (!faiss::SIMDConfig::is_simd_level_available(faiss::SIMDLevel::AVX2)) { + GTEST_SKIP() << "AVX2 not available on this machine"; + } + + // Test fvec_madd which has explicit AVX2 intrinsics + // (fvec_L2sqr uses auto-vectorization so speedup is less predictable) + + // Benchmark NONE + double none_time = benchmark_fvec_madd(faiss::SIMDLevel::NONE); + printf("fvec_madd NONE: %.2f ms\n", none_time); + + // Benchmark AVX2 + double avx2_time = benchmark_fvec_madd(faiss::SIMDLevel::AVX2); + printf("fvec_madd AVX2: %.2f ms\n", avx2_time); + + // AVX2 should be faster than NONE + double speedup = none_time / avx2_time; + printf("fvec_madd Speedup: %.2fx\n", speedup); + + // We expect at least 1.5x speedup with AVX2 for fvec_madd (explicit intrinsics) + // The actual speedup can vary based on CPU, but should be significant + EXPECT_GT(speedup, 1.5) + << "AVX2 should be significantly faster than NONE for fvec_madd. 
" + << "NONE=" << none_time << "ms, AVX2=" << avx2_time << "ms"; +} + +TEST_F(SIMDPerfTest, AVX512FasterThanAVX2IfAvailable) { + // Skip if AVX512 is not available + if (!faiss::SIMDConfig::is_simd_level_available(faiss::SIMDLevel::AVX512)) { + GTEST_SKIP() << "AVX512 not available on this machine"; + } + + // Benchmark AVX2 + double avx2_time = benchmark_fvec_madd(faiss::SIMDLevel::AVX2); + printf("fvec_madd AVX2: %.2f ms\n", avx2_time); + + // Benchmark AVX512 + double avx512_time = benchmark_fvec_madd(faiss::SIMDLevel::AVX512); + printf("fvec_madd AVX512: %.2f ms\n", avx512_time); + + double ratio = avx512_time / avx2_time; + printf("Ratio (AVX512/AVX2): %.2f\n", ratio); + + // AVX512 should not be significantly slower than AVX2 (allow 25% margin + // for frequency throttling) + EXPECT_LT(ratio, 1.25) + << "AVX512 should not be more than 25% slower than AVX2. " + << "AVX2=" << avx2_time << "ms, AVX512=" << avx512_time << "ms"; +} + +// Additional test: Verify fvec_L2sqr dispatch is at least not slower +// fvec_L2sqr uses auto-vectorization, so AVX2 may only be slightly faster +TEST_F(SIMDPerfTest, L2sqrAutoVecDispatchWorks) { + // Skip if AVX2 is not available + if (!faiss::SIMDConfig::is_simd_level_available(faiss::SIMDLevel::AVX2)) { + GTEST_SKIP() << "AVX2 not available on this machine"; + } + + // Benchmark NONE + double none_time = benchmark_fvec_L2sqr(faiss::SIMDLevel::NONE); + printf("fvec_L2sqr NONE (SSE4 autovec): %.2f ms\n", none_time); + + // Benchmark AVX2 + double avx2_time = benchmark_fvec_L2sqr(faiss::SIMDLevel::AVX2); + printf("fvec_L2sqr AVX2 (AVX2 autovec): %.2f ms\n", avx2_time); + + double speedup = none_time / avx2_time; + printf("fvec_L2sqr Speedup: %.2fx\n", speedup); + + // Auto-vectorization may not show huge gains, but should not be slower + // Allow some variance (0.9x) for measurement noise + EXPECT_GT(speedup, 0.9) + << "AVX2 auto-vectorized code should not be slower than SSE4. 
" + << "NONE=" << none_time << "ms, AVX2=" << avx2_time << "ms"; +} From ae50cf89687c99d8aa4870ee75e270c429e812f8 Mon Sep 17 00:00:00 2001 From: Gergely Szilvasy Date: Fri, 23 Jan 2026 03:31:20 -0800 Subject: [PATCH 3/3] moved IndexIVFPQ and IndexPQ to dynamic dispatch (#4555) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/4555 Pull Request resolved: https://github.com/facebookresearch/faiss/pull/4291 moved IndexIVFPQ and IndexPQ to dynamic dispatch. Since the code was already quite modular (thanks Alex!), this boils down to make independent cpp files for the different SIMD versions. Reviewed By: mnorris11 Differential Revision: D72937709 --- faiss/IndexIVFPQ.cpp | 68 ++- faiss/IndexPQ.cpp | 32 +- .../impl/code_distance/code_distance-avx2.cpp | 490 ++++++++++++++++ faiss/impl/code_distance/code_distance-avx2.h | 534 ------------------ .../code_distance/code_distance-avx512.cpp | 203 +++++++ .../impl/code_distance/code_distance-avx512.h | 248 -------- .../code_distance/code_distance-generic.cpp | 20 + .../code_distance/code_distance-generic.h | 81 --- ...e_distance-sve.h => code_distance-sve.cpp} | 4 +- faiss/impl/code_distance/code_distance.h | 229 +++----- faiss/utils/{ => simd_impl}/simdlib_avx2.h | 0 faiss/utils/{ => simd_impl}/simdlib_avx512.h | 0 .../utils/{ => simd_impl}/simdlib_emulated.h | 0 faiss/utils/{ => simd_impl}/simdlib_neon.h | 0 faiss/utils/{ => simd_impl}/simdlib_ppc64.h | 0 faiss/utils/simdlib.h | 8 +- tests/test_code_distance.cpp | 73 ++- 17 files changed, 896 insertions(+), 1094 deletions(-) create mode 100644 faiss/impl/code_distance/code_distance-avx2.cpp delete mode 100644 faiss/impl/code_distance/code_distance-avx2.h create mode 100644 faiss/impl/code_distance/code_distance-avx512.cpp delete mode 100644 faiss/impl/code_distance/code_distance-avx512.h create mode 100644 faiss/impl/code_distance/code_distance-generic.cpp delete mode 100644 faiss/impl/code_distance/code_distance-generic.h rename 
faiss/impl/code_distance/{code_distance-sve.h => code_distance-sve.cpp} (99%) rename faiss/utils/{ => simd_impl}/simdlib_avx2.h (100%) rename faiss/utils/{ => simd_impl}/simdlib_avx512.h (100%) rename faiss/utils/{ => simd_impl}/simdlib_emulated.h (100%) rename faiss/utils/{ => simd_impl}/simdlib_neon.h (100%) rename faiss/utils/{ => simd_impl}/simdlib_ppc64.h (100%) diff --git a/faiss/IndexIVFPQ.cpp b/faiss/IndexIVFPQ.cpp index f4ae7a177c..aaf9ccd0f5 100644 --- a/faiss/IndexIVFPQ.cpp +++ b/faiss/IndexIVFPQ.cpp @@ -794,8 +794,9 @@ struct WrappedSearchResult { * The scanning functions call their favorite precompute_* * function to precompute the tables they need. *****************************************************/ -template +template struct IVFPQScannerT : QueryTables { + using PQDecoder = typename PQCodeDistance::PQDecoder; const uint8_t* list_codes; const IDType* list_ids; size_t list_size; @@ -871,7 +872,7 @@ struct IVFPQScannerT : QueryTables { float distance_1 = 0; float distance_2 = 0; float distance_3 = 0; - distance_four_codes( + PQCodeDistance::distance_four_codes( pq.M, pq.nbits, sim_table, @@ -894,7 +895,7 @@ struct IVFPQScannerT : QueryTables { if (counter >= 1) { float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -903,7 +904,7 @@ struct IVFPQScannerT : QueryTables { } if (counter >= 2) { float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -912,7 +913,7 @@ struct IVFPQScannerT : QueryTables { } if (counter >= 3) { float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -1078,7 +1079,7 @@ struct IVFPQScannerT : QueryTables { float distance_1 = dis0; float distance_2 = dis0; float distance_3 = dis0; - distance_four_codes( + PQCodeDistance::distance_four_codes( pq.M, pq.nbits, sim_table, @@ -1109,7 +1110,7 @@ struct IVFPQScannerT : QueryTables { n_hamming_pass++; float dis = 
dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -1129,7 +1130,7 @@ struct IVFPQScannerT : QueryTables { n_hamming_pass++; float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -1176,8 +1177,8 @@ struct IVFPQScannerT : QueryTables { * * use_sel: store or ignore the IDSelector */ -template -struct IVFPQScanner : IVFPQScannerT, +template +struct IVFPQScanner : IVFPQScannerT, InvertedListScanner { int precompute_mode; const IDSelector* sel; @@ -1187,7 +1188,7 @@ struct IVFPQScanner : IVFPQScannerT, bool store_pairs, int precompute_mode, const IDSelector* sel) - : IVFPQScannerT(ivfpq, nullptr), + : IVFPQScannerT(ivfpq, nullptr), precompute_mode(precompute_mode), sel(sel) { this->store_pairs = store_pairs; @@ -1207,7 +1208,7 @@ struct IVFPQScanner : IVFPQScannerT, float distance_to_code(const uint8_t* code) const override { assert(precompute_mode == 2); float dis = this->dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( this->pq.M, this->pq.nbits, this->sim_table, code); return dis; } @@ -1239,7 +1240,9 @@ struct IVFPQScanner : IVFPQScannerT, } }; -template +/** follow 3 stages of template dispatching */ + +template InvertedListScanner* get_InvertedListScanner1( const IndexIVFPQ& index, bool store_pairs, @@ -1248,32 +1251,47 @@ InvertedListScanner* get_InvertedListScanner1( return new IVFPQScanner< METRIC_INNER_PRODUCT, CMin, - PQDecoder, + PQCodeDistance, use_sel>(index, store_pairs, 2, sel); } else if (index.metric_type == METRIC_L2) { return new IVFPQScanner< METRIC_L2, CMax, - PQDecoder, + PQCodeDistance, use_sel>(index, store_pairs, 2, sel); } return nullptr; } -template +template InvertedListScanner* get_InvertedListScanner2( const IndexIVFPQ& index, bool store_pairs, const IDSelector* sel) { if (index.pq.nbits == 8) { - return get_InvertedListScanner1( - index, store_pairs, sel); + return get_InvertedListScanner1< + PQCodeDistance, + 
use_sel>(index, store_pairs, sel); } else if (index.pq.nbits == 16) { - return get_InvertedListScanner1( - index, store_pairs, sel); + return get_InvertedListScanner1< + PQCodeDistance, + use_sel>(index, store_pairs, sel); + } else { + return get_InvertedListScanner1< + PQCodeDistance, + use_sel>(index, store_pairs, sel); + } +} + +template +InvertedListScanner* get_InvertedListScanner3( + const IndexIVFPQ& index, + bool store_pairs, + const IDSelector* sel) { + if (sel) { + return get_InvertedListScanner2(index, store_pairs, sel); } else { - return get_InvertedListScanner1( - index, store_pairs, sel); + return get_InvertedListScanner2(index, store_pairs, sel); } } @@ -1283,11 +1301,7 @@ InvertedListScanner* IndexIVFPQ::get_InvertedListScanner( bool store_pairs, const IDSelector* sel, const IVFSearchParameters*) const { - if (sel) { - return get_InvertedListScanner2(*this, store_pairs, sel); - } else { - return get_InvertedListScanner2(*this, store_pairs, sel); - } + DISPATCH_SIMDLevel(get_InvertedListScanner3, *this, store_pairs, sel); return nullptr; } diff --git a/faiss/IndexPQ.cpp b/faiss/IndexPQ.cpp index 255900ced6..a50fa2738e 100644 --- a/faiss/IndexPQ.cpp +++ b/faiss/IndexPQ.cpp @@ -72,7 +72,7 @@ void IndexPQ::train(idx_t n, const float* x) { namespace { -template +template struct PQDistanceComputer : FlatCodesDistanceComputer { size_t d; MetricType metric; @@ -86,7 +86,7 @@ struct PQDistanceComputer : FlatCodesDistanceComputer { float distance_to_code(const uint8_t* code) final { ndis++; - float dis = distance_single_code( + float dis = PQCodeDistance::distance_single_code( pq.M, pq.nbits, precomputed_table.data(), code); return dis; } @@ -95,8 +95,10 @@ struct PQDistanceComputer : FlatCodesDistanceComputer { FAISS_THROW_IF_NOT(sdc); const float* sdci = sdc; float accu = 0; - PQDecoder codei(codes + i * code_size, pq.nbits); - PQDecoder codej(codes + j * code_size, pq.nbits); + typename PQCodeDistance::PQDecoder codei( + codes + i * code_size, pq.nbits); + 
typename PQCodeDistance::PQDecoder codej( + codes + j * code_size, pq.nbits); for (int l = 0; l < pq.M; l++) { accu += sdci[codei.decode() + (codej.decode() << codei.nbits)]; @@ -134,16 +136,24 @@ struct PQDistanceComputer : FlatCodesDistanceComputer { } }; +template +FlatCodesDistanceComputer* get_FlatCodesDistanceComputer1( + const IndexPQ& index) { + int nbits = index.pq.nbits; + if (nbits == 8) { + return new PQDistanceComputer>(index); + } else if (nbits == 16) { + return new PQDistanceComputer>(index); + } else { + return new PQDistanceComputer>( + index); + } +} + } // namespace FlatCodesDistanceComputer* IndexPQ::get_FlatCodesDistanceComputer() const { - if (pq.nbits == 8) { - return new PQDistanceComputer(*this); - } else if (pq.nbits == 16) { - return new PQDistanceComputer(*this); - } else { - return new PQDistanceComputer(*this); - } + DISPATCH_SIMDLevel(get_FlatCodesDistanceComputer1, *this); } /***************************************** diff --git a/faiss/impl/code_distance/code_distance-avx2.cpp b/faiss/impl/code_distance/code_distance-avx2.cpp new file mode 100644 index 0000000000..e1e12daca2 --- /dev/null +++ b/faiss/impl/code_distance/code_distance-avx2.cpp @@ -0,0 +1,490 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifdef COMPILE_SIMD_AVX2 + +#include + +#include +#include + +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=78782 +#if defined(__GNUC__) && __GNUC__ < 9 +#define _mm_loadu_si64(x) (_mm_loadl_epi64((__m128i_u*)x)) +#endif + +namespace { + +inline float horizontal_sum(const __m128 v) { + const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); + const __m128 v1 = _mm_add_ps(v, v0); + __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); + const __m128 v3 = _mm_add_ps(v1, v2); + return _mm_cvtss_f32(v3); +} + +// Computes a horizontal sum over an __m256 register +inline float horizontal_sum(const __m256 v) { + const __m128 v0 = + _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); + return horizontal_sum(v0); +} + +// processes a single code for M=4, ksub=256, nbits=8 +float inline distance_single_code_avx2_pqdecoder8_m4( + // precomputed distances, layout (4, 256) + const float* sim_table, + const uint8_t* code) { + float result = 0; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + const __m128i vksub = _mm_set1_epi32(ksub); + __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); + offsets_0 = _mm_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m128 partialSum; + + // load 4 uint8 values + const __m128i mm1 = _mm_cvtsi32_si128(*((const int32_t*)code)); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m128i idx1 = _mm_cvtepu8_epi32(mm1); + + // add offsets + const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m128 collected = + _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSum = collected; + } + + // horizontal sum for partialSum + result = horizontal_sum(partialSum); + return result; +} + +// processes a single code for M=8, ksub=256, nbits=8 +float inline distance_single_code_avx2_pqdecoder8_m8( + // precomputed distances, 
layout (8, 256) + const float* sim_table, + const uint8_t* code) { + float result = 0; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + const __m256i vksub = _mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSum; + + // load 8 uint8 values + const __m128i mm1 = _mm_loadu_si64((const __m128i_u*)code); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); + + // add offsets + const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = + _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSum = collected; + } + + // horizontal sum for partialSum + result = horizontal_sum(partialSum); + return result; +} + +// processes four codes for M=4, ksub=256, nbits=8 +inline void distance_four_codes_avx2_pqdecoder8_m4( + // precomputed distances, layout (4, 256) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + constexpr intptr_t N = 4; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + // process 8 values + const __m128i vksub = _mm_set1_epi32(ksub); + __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); + offsets_0 = _mm_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m128 partialSums[N]; + + // load 4 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_cvtsi32_si128(*((const int32_t*)code0)); + mm1[1] = _mm_cvtsi32_si128(*((const int32_t*)code1)); + mm1[2] = _mm_cvtsi32_si128(*((const int32_t*)code2)); + mm1[3] = _mm_cvtsi32_si128(*((const 
int32_t*)code3)); + + for (intptr_t j = 0; j < N; j++) { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m128i idx1 = _mm_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); + + // gather 4 values, similar to 4 operations of tab[idx] + __m128 collected = + _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = collected; + } + + // horizontal sum for partialSum + result0 = horizontal_sum(partialSums[0]); + result1 = horizontal_sum(partialSums[1]); + result2 = horizontal_sum(partialSums[2]); + result3 = horizontal_sum(partialSums[3]); +} + +// processes four codes for M=8, ksub=256, nbits=8 +inline void distance_four_codes_avx2_pqdecoder8_m8( + // precomputed distances, layout (8, 256) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + constexpr intptr_t N = 4; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + // process 8 values + const __m256i vksub = _mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSums[N]; + + // load 8 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si64((const __m128i_u*)code0); + mm1[1] = _mm_loadu_si64((const __m128i_u*)code1); + mm1[2] = _mm_loadu_si64((const __m128i_u*)code2); + mm1[3] = _mm_loadu_si64((const __m128i_u*)code3); + + for (intptr_t j = 0; j < N; j++) { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar 
to 8 operations of tab[idx] + __m256 collected = + _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = collected; + } + + // horizontal sum for partialSum + result0 = horizontal_sum(partialSums[0]); + result1 = horizontal_sum(partialSums[1]); + result2 = horizontal_sum(partialSums[2]); + result3 = horizontal_sum(partialSums[3]); +} + +} // namespace + +namespace faiss { + +template <> +struct PQCodeDistance { + float distance_single_code( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + const uint8_t* code) { + if (M == 4) { + return distance_single_code_avx2_pqdecoder8_m4(sim_table, code); + } + if (M == 8) { + return distance_single_code_avx2_pqdecoder8_m8(sim_table, code); + } + + float result = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + + const __m256i vksub = _mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSum = _mm256_setzero_ps(); + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + const __m128i mm1 = + _mm_loadu_si128((const __m128i_u*)(code + m)); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + tab += ksub * 8; + + // collect partial sums + partialSum = _mm256_add_ps(partialSum, collected); + } + + // move high 8 uint8 to low ones + const __m128i mm2 = 
+ _mm_unpackhi_epi64(mm1, _mm_setzero_si128()); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + tab += ksub * 8; + + // collect partial sums + partialSum = _mm256_add_ps(partialSum, collected); + } + } + + // horizontal sum for partialSum + result += horizontal_sum(partialSum); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 decoder(code + m, nbits); + + for (; m < M; m++) { + result += tab[decoder.decode()]; + tab += ksub; + } + } + + return result; + } + + // Combines 4 operations of distance_single_code() + void distance_four_codes( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + if (M == 4) { + distance_four_codes_avx2_pqdecoder8_m4( + sim_table, + code0, + code1, + code2, + code3, + result0, + result1, + result2, + result3); + return; + } + if (M == 8) { + distance_four_codes_avx2_pqdecoder8_m8( + sim_table, + code0, + code1, + code2, + code3, + result0, + result1, + result2, + result3); + return; + } + + result0 = 0; + result1 = 0; + result2 = 0; + result3 = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + constexpr intptr_t N = 4; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + const __m256i vksub = _mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 
3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSums[N]; + for (intptr_t j = 0; j < N; j++) { + partialSums[j] = _mm256_setzero_ps(); + } + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); + mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); + mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); + mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); + + // process first 8 codes + for (intptr_t j = 0; j < N; j++) { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm256_add_ps(partialSums[j], collected); + } + tab += ksub * 8; + + // process next 8 codes + for (intptr_t j = 0; j < N; j++) { + // move high 8 uint8 to low ones + const __m128i mm2 = + _mm_unpackhi_epi64(mm1[j], _mm_setzero_si128()); + + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm256_add_ps(partialSums[j], collected); + } + + tab += ksub * 8; + } + + // horizontal sum for partialSum + result0 += horizontal_sum(partialSums[0]); + result1 += horizontal_sum(partialSums[1]); + result2 += horizontal_sum(partialSums[2]); + result3 += horizontal_sum(partialSums[3]); + } + + // + if (m < M) { + // process 
leftovers + PQDecoder8 decoder0(code0 + m, nbits); + PQDecoder8 decoder1(code1 + m, nbits); + PQDecoder8 decoder2(code2 + m, nbits); + PQDecoder8 decoder3(code3 + m, nbits); + for (; m < M; m++) { + result0 += tab[decoder0.decode()]; + result1 += tab[decoder1.decode()]; + result2 += tab[decoder2.decode()]; + result3 += tab[decoder3.decode()]; + tab += ksub; + } + } + } +}; + +// explicit template instanciations +// template struct PQCodeDistance; + +// these two will automatically use the generic implementation +template struct PQCodeDistance; +template struct PQCodeDistance; + +} // namespace faiss + +#endif // COMPILE_SIMD_AVX2 diff --git a/faiss/impl/code_distance/code_distance-avx2.h b/faiss/impl/code_distance/code_distance-avx2.h deleted file mode 100644 index 53380b6e46..0000000000 --- a/faiss/impl/code_distance/code_distance-avx2.h +++ /dev/null @@ -1,534 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#pragma once - -#ifdef __AVX2__ - -#include - -#include - -#include -#include - -// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=78782 -#if defined(__GNUC__) && __GNUC__ < 9 -#define _mm_loadu_si64(x) (_mm_loadl_epi64((__m128i_u*)x)) -#endif - -namespace { - -inline float horizontal_sum(const __m128 v) { - const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); - const __m128 v1 = _mm_add_ps(v, v0); - __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); - const __m128 v3 = _mm_add_ps(v1, v2); - return _mm_cvtss_f32(v3); -} - -// Computes a horizontal sum over an __m256 register -inline float horizontal_sum(const __m256 v) { - const __m128 v0 = - _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); - return horizontal_sum(v0); -} - -// processes a single code for M=4, ksub=256, nbits=8 -float inline distance_single_code_avx2_pqdecoder8_m4( - // precomputed distances, layout (4, 256) - const float* sim_table, - const uint8_t* code) { - float result = 0; - - const float* tab = sim_table; - constexpr size_t ksub = 1 << 8; - - const __m128i vksub = _mm_set1_epi32(ksub); - __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); - offsets_0 = _mm_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m128 partialSum; - - // load 4 uint8 values - const __m128i mm1 = _mm_cvtsi32_si128(*((const int32_t*)code)); - { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m128i idx1 = _mm_cvtepu8_epi32(mm1); - - // add offsets - const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m128 collected = - _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSum = collected; - } - - // horizontal sum for partialSum - result = horizontal_sum(partialSum); - return result; -} - -// processes a single code for M=8, ksub=256, nbits=8 -float inline distance_single_code_avx2_pqdecoder8_m8( - // 
precomputed distances, layout (8, 256) - const float* sim_table, - const uint8_t* code) { - float result = 0; - - const float* tab = sim_table; - constexpr size_t ksub = 1 << 8; - - const __m256i vksub = _mm256_set1_epi32(ksub); - __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m256 partialSum; - - // load 8 uint8 values - const __m128i mm1 = _mm_loadu_si64((const __m128i_u*)code); - { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); - - // add offsets - const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = - _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSum = collected; - } - - // horizontal sum for partialSum - result = horizontal_sum(partialSum); - return result; -} - -// processes four codes for M=4, ksub=256, nbits=8 -inline void distance_four_codes_avx2_pqdecoder8_m4( - // precomputed distances, layout (4, 256) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - constexpr intptr_t N = 4; - - const float* tab = sim_table; - constexpr size_t ksub = 1 << 8; - - // process 8 values - const __m128i vksub = _mm_set1_epi32(ksub); - __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); - offsets_0 = _mm_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m128 partialSums[N]; - - // load 4 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_cvtsi32_si128(*((const int32_t*)code0)); - mm1[1] = _mm_cvtsi32_si128(*((const int32_t*)code1)); - mm1[2] = _mm_cvtsi32_si128(*((const int32_t*)code2)); - mm1[3] = 
_mm_cvtsi32_si128(*((const int32_t*)code3)); - - for (intptr_t j = 0; j < N; j++) { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m128i idx1 = _mm_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); - - // gather 4 values, similar to 4 operations of tab[idx] - __m128 collected = - _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSums[j] = collected; - } - - // horizontal sum for partialSum - result0 = horizontal_sum(partialSums[0]); - result1 = horizontal_sum(partialSums[1]); - result2 = horizontal_sum(partialSums[2]); - result3 = horizontal_sum(partialSums[3]); -} - -// processes four codes for M=8, ksub=256, nbits=8 -inline void distance_four_codes_avx2_pqdecoder8_m8( - // precomputed distances, layout (8, 256) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - constexpr intptr_t N = 4; - - const float* tab = sim_table; - constexpr size_t ksub = 1 << 8; - - // process 8 values - const __m256i vksub = _mm256_set1_epi32(ksub); - __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m256 partialSums[N]; - - // load 8 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_loadu_si64((const __m128i_u*)code0); - mm1[1] = _mm_loadu_si64((const __m128i_u*)code1); - mm1[2] = _mm_loadu_si64((const __m128i_u*)code2); - mm1[3] = _mm_loadu_si64((const __m128i_u*)code3); - - for (intptr_t j = 0; j < N; j++) { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); - - 
// gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = - _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSums[j] = collected; - } - - // horizontal sum for partialSum - result0 = horizontal_sum(partialSums[0]); - result1 = horizontal_sum(partialSums[1]); - result2 = horizontal_sum(partialSums[2]); - result3 = horizontal_sum(partialSums[3]); -} - -} // namespace - -namespace faiss { - -template -typename std::enable_if::value, float>:: - type inline distance_single_code_avx2( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - const uint8_t* code) { - // default implementation - return distance_single_code_generic(M, nbits, sim_table, code); -} - -template -typename std::enable_if::value, float>:: - type inline distance_single_code_avx2( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - const uint8_t* code) { - if (M == 4) { - return distance_single_code_avx2_pqdecoder8_m4(sim_table, code); - } - if (M == 8) { - return distance_single_code_avx2_pqdecoder8_m8(sim_table, code); - } - - float result = 0; - constexpr size_t ksub = 1 << 8; - - size_t m = 0; - const size_t pqM16 = M / 16; - - const float* tab = sim_table; - - if (pqM16 > 0) { - // process 16 values per loop - - const __m256i vksub = _mm256_set1_epi32(ksub); - __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m256 partialSum = _mm256_setzero_ps(); - - // loop - for (m = 0; m < pqM16 * 16; m += 16) { - // load 16 uint8 values - const __m128i mm1 = _mm_loadu_si128((const __m128i_u*)(code + m)); - { - // convert uint8 values (low part of __m128i) to int32 - // values - 
const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); - - // add offsets - const __m256i indices_to_read_from = - _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = _mm256_i32gather_ps( - tab, indices_to_read_from, sizeof(float)); - tab += ksub * 8; - - // collect partial sums - partialSum = _mm256_add_ps(partialSum, collected); - } - - // move high 8 uint8 to low ones - const __m128i mm2 = _mm_unpackhi_epi64(mm1, _mm_setzero_si128()); - { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); - - // add offsets - const __m256i indices_to_read_from = - _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = _mm256_i32gather_ps( - tab, indices_to_read_from, sizeof(float)); - tab += ksub * 8; - - // collect partial sums - partialSum = _mm256_add_ps(partialSum, collected); - } - } - - // horizontal sum for partialSum - result += horizontal_sum(partialSum); - } - - // - if (m < M) { - // process leftovers - PQDecoder8 decoder(code + m, nbits); - - for (; m < M; m++) { - result += tab[decoder.decode()]; - tab += ksub; - } - } - - return result; -} - -template -typename std::enable_if::value, void>:: - type - distance_four_codes_avx2( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_generic( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} - -// Combines 4 operations of distance_single_code() -template -typename 
std::enable_if::value, void>::type -distance_four_codes_avx2( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - if (M == 4) { - distance_four_codes_avx2_pqdecoder8_m4( - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); - return; - } - if (M == 8) { - distance_four_codes_avx2_pqdecoder8_m8( - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); - return; - } - - result0 = 0; - result1 = 0; - result2 = 0; - result3 = 0; - constexpr size_t ksub = 1 << 8; - - size_t m = 0; - const size_t pqM16 = M / 16; - - constexpr intptr_t N = 4; - - const float* tab = sim_table; - - if (pqM16 > 0) { - // process 16 values per loop - const __m256i vksub = _mm256_set1_epi32(ksub); - __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m256 partialSums[N]; - for (intptr_t j = 0; j < N; j++) { - partialSums[j] = _mm256_setzero_ps(); - } - - // loop - for (m = 0; m < pqM16 * 16; m += 16) { - // load 16 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); - mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); - mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); - mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); - - // process first 8 codes - for (intptr_t j = 0; j < N; j++) { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m256i indices_to_read_from = - _mm256_add_epi32(idx1, 
offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = _mm256_i32gather_ps( - tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSums[j] = _mm256_add_ps(partialSums[j], collected); - } - tab += ksub * 8; - - // process next 8 codes - for (intptr_t j = 0; j < N; j++) { - // move high 8 uint8 to low ones - const __m128i mm2 = - _mm_unpackhi_epi64(mm1[j], _mm_setzero_si128()); - - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); - - // add offsets - const __m256i indices_to_read_from = - _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = _mm256_i32gather_ps( - tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSums[j] = _mm256_add_ps(partialSums[j], collected); - } - - tab += ksub * 8; - } - - // horizontal sum for partialSum - result0 += horizontal_sum(partialSums[0]); - result1 += horizontal_sum(partialSums[1]); - result2 += horizontal_sum(partialSums[2]); - result3 += horizontal_sum(partialSums[3]); - } - - // - if (m < M) { - // process leftovers - PQDecoder8 decoder0(code0 + m, nbits); - PQDecoder8 decoder1(code1 + m, nbits); - PQDecoder8 decoder2(code2 + m, nbits); - PQDecoder8 decoder3(code3 + m, nbits); - for (; m < M; m++) { - result0 += tab[decoder0.decode()]; - result1 += tab[decoder1.decode()]; - result2 += tab[decoder2.decode()]; - result3 += tab[decoder3.decode()]; - tab += ksub; - } - } -} - -} // namespace faiss - -#endif diff --git a/faiss/impl/code_distance/code_distance-avx512.cpp b/faiss/impl/code_distance/code_distance-avx512.cpp new file mode 100644 index 0000000000..aa16b1c4b8 --- /dev/null +++ b/faiss/impl/code_distance/code_distance-avx512.cpp @@ -0,0 +1,203 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifdef COMPILE_SIMD_AVX512 + +#include + +#include + +#include +#include + +// According to experiments, the AVX-512 version may be SLOWER than +// the AVX2 version, which is somewhat unexpected. +// This version is not used for now, but it may be used later. +// +// TODO: test for AMD CPUs. + +namespace faiss { + +template <> +struct PQCodeDistance { + float distance_single_code( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + const uint8_t* code0) { + float result0 = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + constexpr intptr_t N = 1; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + const __m512i vksub = _mm512_set1_epi32(ksub); + __m512i offsets_0 = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m512 partialSums[N]; + for (intptr_t j = 0; j < N; j++) { + partialSums[j] = _mm512_setzero_ps(); + } + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); + + // process first 8 codes + for (intptr_t j = 0; j < N; j++) { + const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m512i indices_to_read_from = + _mm512_add_epi32(idx1, offsets_0); + + // gather 16 values, similar to 16 operations of tab[idx] + __m512 collected = _mm512_i32gather_ps( + indices_to_read_from, tab, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm512_add_ps(partialSums[j], collected); + } + tab += ksub * 16; + } + + // horizontal sum for partialSum + result0 += 
_mm512_reduce_add_ps(partialSums[0]); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 decoder0(code0 + m, nbits); + for (; m < M; m++) { + result0 += tab[decoder0.decode()]; + tab += ksub; + } + } + + return result0; + } + + void distance_four_codes_avx512( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + result0 = 0; + result1 = 0; + result2 = 0; + result3 = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + constexpr intptr_t N = 4; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + const __m512i vksub = _mm512_set1_epi32(ksub); + __m512i offsets_0 = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m512 partialSums[N]; + for (intptr_t j = 0; j < N; j++) { + partialSums[j] = _mm512_setzero_ps(); + } + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); + mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); + mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); + mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); + + // process first 8 codes + for (intptr_t j = 0; j < N; j++) { + const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m512i indices_to_read_from = + _mm512_add_epi32(idx1, offsets_0); + + // gather 16 values, similar to 16 operations of tab[idx] + __m512 collected = _mm512_i32gather_ps( + indices_to_read_from, tab, 
sizeof(float)); + + // collect partial sums + partialSums[j] = _mm512_add_ps(partialSums[j], collected); + } + tab += ksub * 16; + } + + // horizontal sum for partialSum + result0 += _mm512_reduce_add_ps(partialSums[0]); + result1 += _mm512_reduce_add_ps(partialSums[1]); + result2 += _mm512_reduce_add_ps(partialSums[2]); + result3 += _mm512_reduce_add_ps(partialSums[3]); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 decoder0(code0 + m, nbits); + PQDecoder8 decoder1(code1 + m, nbits); + PQDecoder8 decoder2(code2 + m, nbits); + PQDecoder8 decoder3(code3 + m, nbits); + for (; m < M; m++) { + result0 += tab[decoder0.decode()]; + result1 += tab[decoder1.decode()]; + result2 += tab[decoder2.decode()]; + result3 += tab[decoder3.decode()]; + tab += ksub; + } + } + } +}; + +// explicit template instanciations +// template struct PQCodeDistance; + +// these two will automatically use the generic implementation +template struct PQCodeDistance; +template struct PQCodeDistance; + +} // namespace faiss + +#endif // COMPILE_SIMD_AVX512F diff --git a/faiss/impl/code_distance/code_distance-avx512.h b/faiss/impl/code_distance/code_distance-avx512.h deleted file mode 100644 index d05c41c19c..0000000000 --- a/faiss/impl/code_distance/code_distance-avx512.h +++ /dev/null @@ -1,248 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#ifdef __AVX512F__ - -#include - -#include - -#include -#include - -namespace faiss { - -// According to experiments, the AVX-512 version may be SLOWER than -// the AVX2 version, which is somewhat unexpected. -// This version is not used for now, but it may be used later. -// -// TODO: test for AMD CPUs. 
- -template -typename std::enable_if::value, float>:: - type inline distance_single_code_avx512( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - const uint8_t* code) { - // default implementation - return distance_single_code_generic(M, nbits, sim_table, code); -} - -template -typename std::enable_if::value, float>:: - type inline distance_single_code_avx512( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - const uint8_t* code0) { - float result0 = 0; - constexpr size_t ksub = 1 << 8; - - size_t m = 0; - const size_t pqM16 = M / 16; - - constexpr intptr_t N = 1; - - const float* tab = sim_table; - - if (pqM16 > 0) { - // process 16 values per loop - const __m512i vksub = _mm512_set1_epi32(ksub); - __m512i offsets_0 = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m512 partialSums[N]; - for (intptr_t j = 0; j < N; j++) { - partialSums[j] = _mm512_setzero_ps(); - } - - // loop - for (m = 0; m < pqM16 * 16; m += 16) { - // load 16 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); - - // process first 8 codes - for (intptr_t j = 0; j < N; j++) { - const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m512i indices_to_read_from = - _mm512_add_epi32(idx1, offsets_0); - - // gather 16 values, similar to 16 operations of tab[idx] - __m512 collected = _mm512_i32gather_ps( - indices_to_read_from, tab, sizeof(float)); - - // collect partial sums - partialSums[j] = _mm512_add_ps(partialSums[j], collected); - } - tab += ksub * 16; - } - - // horizontal sum for partialSum - result0 += _mm512_reduce_add_ps(partialSums[0]); - } - 
- // - if (m < M) { - // process leftovers - PQDecoder8 decoder0(code0 + m, nbits); - for (; m < M; m++) { - result0 += tab[decoder0.decode()]; - tab += ksub; - } - } - - return result0; -} - -template -typename std::enable_if::value, void>:: - type - distance_four_codes_avx512( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_generic( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} - -// Combines 4 operations of distance_single_code() -template -typename std::enable_if::value, void>::type -distance_four_codes_avx512( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - result0 = 0; - result1 = 0; - result2 = 0; - result3 = 0; - constexpr size_t ksub = 1 << 8; - - size_t m = 0; - const size_t pqM16 = M / 16; - - constexpr intptr_t N = 4; - - const float* tab = sim_table; - - if (pqM16 > 0) { - // process 16 values per loop - const __m512i vksub = _mm512_set1_epi32(ksub); - __m512i offsets_0 = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m512 partialSums[N]; - for (intptr_t j = 0; j < N; j++) { - partialSums[j] = 
_mm512_setzero_ps(); - } - - // loop - for (m = 0; m < pqM16 * 16; m += 16) { - // load 16 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); - mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); - mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); - mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); - - // process first 8 codes - for (intptr_t j = 0; j < N; j++) { - const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m512i indices_to_read_from = - _mm512_add_epi32(idx1, offsets_0); - - // gather 16 values, similar to 16 operations of tab[idx] - __m512 collected = _mm512_i32gather_ps( - indices_to_read_from, tab, sizeof(float)); - - // collect partial sums - partialSums[j] = _mm512_add_ps(partialSums[j], collected); - } - tab += ksub * 16; - } - - // horizontal sum for partialSum - result0 += _mm512_reduce_add_ps(partialSums[0]); - result1 += _mm512_reduce_add_ps(partialSums[1]); - result2 += _mm512_reduce_add_ps(partialSums[2]); - result3 += _mm512_reduce_add_ps(partialSums[3]); - } - - // - if (m < M) { - // process leftovers - PQDecoder8 decoder0(code0 + m, nbits); - PQDecoder8 decoder1(code1 + m, nbits); - PQDecoder8 decoder2(code2 + m, nbits); - PQDecoder8 decoder3(code3 + m, nbits); - for (; m < M; m++) { - result0 += tab[decoder0.decode()]; - result1 += tab[decoder1.decode()]; - result2 += tab[decoder2.decode()]; - result3 += tab[decoder3.decode()]; - tab += ksub; - } - } -} - -} // namespace faiss - -#endif diff --git a/faiss/impl/code_distance/code_distance-generic.cpp b/faiss/impl/code_distance/code_distance-generic.cpp new file mode 100644 index 0000000000..ac9561ed93 --- /dev/null +++ b/faiss/impl/code_distance/code_distance-generic.cpp @@ -0,0 +1,20 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include + +namespace faiss { + +// explicit template instanciations +template struct PQCodeDistance; +template struct PQCodeDistance; +template struct PQCodeDistance; + +} // namespace faiss diff --git a/faiss/impl/code_distance/code_distance-generic.h b/faiss/impl/code_distance/code_distance-generic.h deleted file mode 100644 index c02551c415..0000000000 --- a/faiss/impl/code_distance/code_distance-generic.h +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include -#include - -namespace faiss { - -/// Returns the distance to a single code. -template -inline float distance_single_code_generic( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // the code - const uint8_t* code) { - PQDecoderT decoder(code, nbits); - const size_t ksub = 1 << nbits; - - const float* tab = sim_table; - float result = 0; - - for (size_t m = 0; m < M; m++) { - result += tab[decoder.decode()]; - tab += ksub; - } - - return result; -} - -/// Combines 4 operations of distance_single_code() -/// General-purpose version. 
-template -inline void distance_four_codes_generic( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - PQDecoderT decoder0(code0, nbits); - PQDecoderT decoder1(code1, nbits); - PQDecoderT decoder2(code2, nbits); - PQDecoderT decoder3(code3, nbits); - const size_t ksub = 1 << nbits; - - const float* tab = sim_table; - result0 = 0; - result1 = 0; - result2 = 0; - result3 = 0; - - for (size_t m = 0; m < M; m++) { - result0 += tab[decoder0.decode()]; - result1 += tab[decoder1.decode()]; - result2 += tab[decoder2.decode()]; - result3 += tab[decoder3.decode()]; - tab += ksub; - } -} - -} // namespace faiss diff --git a/faiss/impl/code_distance/code_distance-sve.h b/faiss/impl/code_distance/code_distance-sve.cpp similarity index 99% rename from faiss/impl/code_distance/code_distance-sve.h rename to faiss/impl/code_distance/code_distance-sve.cpp index 82f7746be6..9a941798ff 100644 --- a/faiss/impl/code_distance/code_distance-sve.h +++ b/faiss/impl/code_distance/code_distance-sve.cpp @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. 
*/ -#pragma once - #ifdef __ARM_FEATURE_SVE #include @@ -15,7 +13,7 @@ #include #include -#include +#include namespace faiss { diff --git a/faiss/impl/code_distance/code_distance.h b/faiss/impl/code_distance/code_distance.h index 8f29abda97..585890cb40 100644 --- a/faiss/impl/code_distance/code_distance.h +++ b/faiss/impl/code_distance/code_distance.h @@ -9,6 +9,10 @@ #include +#include + +#include + // This directory contains functions to compute a distance // from a given PQ code to a query vector, given that the // distances to a query vector for pq.M codebooks are precomputed. @@ -24,163 +28,76 @@ // why the names of the functions for custom implementations // have this _generic or _avx2 suffix. -#ifdef __AVX2__ - -#include - namespace faiss { -template -inline float distance_single_code( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // the code - const uint8_t* code) { - return distance_single_code_avx2(M, nbits, sim_table, code); -} - -template -inline void distance_four_codes( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_avx2( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} +// definiton and default implementation +template +struct PQCodeDistance { + using PQDecoder = PQDecoderT; + + /// Returns the distance to a single code. 
+ static float distance_single_code( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // the code + const uint8_t* code) { + PQDecoderT decoder(code, nbits); + const size_t ksub = 1 << nbits; + + const float* tab = sim_table; + float result = 0; + + for (size_t m = 0; m < M; m++) { + result += tab[decoder.decode()]; + tab += ksub; + } + + return result; + } + + /// Combines 4 operations of distance_single_code() + /// General-purpose version. + static void distance_four_codes( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + PQDecoderT decoder0(code0, nbits); + PQDecoderT decoder1(code1, nbits); + PQDecoderT decoder2(code2, nbits); + PQDecoderT decoder3(code3, nbits); + const size_t ksub = 1 << nbits; + + const float* tab = sim_table; + result0 = 0; + result1 = 0; + result2 = 0; + result3 = 0; + + for (size_t m = 0; m < M; m++) { + result0 += tab[decoder0.decode()]; + result1 += tab[decoder1.decode()]; + result2 += tab[decoder2.decode()]; + result3 += tab[decoder3.decode()]; + tab += ksub; + } + } +}; } // namespace faiss - -#elif defined(__ARM_FEATURE_SVE) - -#include - -namespace faiss { - -template -inline float distance_single_code( - // the product quantizer - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // the code - const uint8_t* code) { - return distance_single_code_sve(M, nbits, sim_table, code); -} - -template -inline void 
distance_four_codes( - // the product quantizer - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_sve( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} - -} // namespace faiss - -#else - -#include - -namespace faiss { - -template -inline float distance_single_code( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // the code - const uint8_t* code) { - return distance_single_code_generic(M, nbits, sim_table, code); -} - -template -inline void distance_four_codes( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_generic( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} - -} // namespace faiss - -#endif diff --git a/faiss/utils/simdlib_avx2.h b/faiss/utils/simd_impl/simdlib_avx2.h similarity index 100% rename from faiss/utils/simdlib_avx2.h rename to faiss/utils/simd_impl/simdlib_avx2.h diff --git a/faiss/utils/simdlib_avx512.h b/faiss/utils/simd_impl/simdlib_avx512.h similarity index 100% rename from faiss/utils/simdlib_avx512.h rename to 
faiss/utils/simd_impl/simdlib_avx512.h diff --git a/faiss/utils/simdlib_emulated.h b/faiss/utils/simd_impl/simdlib_emulated.h similarity index 100% rename from faiss/utils/simdlib_emulated.h rename to faiss/utils/simd_impl/simdlib_emulated.h diff --git a/faiss/utils/simdlib_neon.h b/faiss/utils/simd_impl/simdlib_neon.h similarity index 100% rename from faiss/utils/simdlib_neon.h rename to faiss/utils/simd_impl/simdlib_neon.h diff --git a/faiss/utils/simdlib_ppc64.h b/faiss/utils/simd_impl/simdlib_ppc64.h similarity index 100% rename from faiss/utils/simdlib_ppc64.h rename to faiss/utils/simd_impl/simdlib_ppc64.h diff --git a/faiss/utils/simdlib.h b/faiss/utils/simdlib.h index eadfb78ae3..98c38f7a0d 100644 --- a/faiss/utils/simdlib.h +++ b/faiss/utils/simdlib.h @@ -21,20 +21,20 @@ #elif defined(__AVX2__) -#include +#include #elif defined(__aarch64__) -#include +#include #elif defined(__PPC64__) -#include +#include #else // emulated = all operations are implemented as scalars -#include +#include // FIXME: make a SSE version // is this ever going to happen? 
We will probably rather implement AVX512 diff --git a/tests/test_code_distance.cpp b/tests/test_code_distance.cpp index f1a3939388..e4b61baf0f 100644 --- a/tests/test_code_distance.cpp +++ b/tests/test_code_distance.cpp @@ -22,6 +22,7 @@ #include #include #include +#include size_t nMismatches( const std::vector& ref, @@ -80,8 +81,10 @@ void test( for (size_t k = 0; k < 10; k++) { #pragma omp parallel for schedule(guided) for (size_t i = 0; i < n; i++) { - resultsRef[i] = - faiss::distance_single_code_generic( + resultsRef[i] = faiss::PQCodeDistance< + faiss::PQDecoder8, + faiss::SIMDLevel::NONE>:: + distance_single_code( subq, 8, lookup.data(), codes.data() + subq * i); } } @@ -94,8 +97,10 @@ void test( for (size_t k = 0; k < 1000; k++) { #pragma omp parallel for schedule(guided) for (size_t i = 0; i < n; i++) { - resultsNewGeneric1x[i] = - faiss::distance_single_code_generic( + resultsNewGeneric1x[i] = faiss::PQCodeDistance< + faiss::PQDecoder8, + faiss::SIMDLevel::NONE>:: + distance_single_code( subq, 8, lookup.data(), @@ -117,18 +122,21 @@ void test( for (size_t k = 0; k < 1000; k++) { #pragma omp parallel for schedule(guided) for (size_t i = 0; i < n; i += 4) { - faiss::distance_four_codes_generic( - subq, - 8, - lookup.data(), - codes.data() + subq * (i + 0), - codes.data() + subq * (i + 1), - codes.data() + subq * (i + 2), - codes.data() + subq * (i + 3), - resultsNewGeneric4x[i + 0], - resultsNewGeneric4x[i + 1], - resultsNewGeneric4x[i + 2], - resultsNewGeneric4x[i + 3]); + faiss::PQCodeDistance< + faiss::PQDecoder8, + faiss::SIMDLevel::NONE>:: + distance_four_codes( + subq, + 8, + lookup.data(), + codes.data() + subq * (i + 0), + codes.data() + subq * (i + 1), + codes.data() + subq * (i + 2), + codes.data() + subq * (i + 3), + resultsNewGeneric4x[i + 0], + resultsNewGeneric4x[i + 1], + resultsNewGeneric4x[i + 2], + resultsNewGeneric4x[i + 3]); } } @@ -147,8 +155,10 @@ void test( for (size_t k = 0; k < 1000; k++) { #pragma omp parallel for schedule(guided) 
for (size_t i = 0; i < n; i++) { - resultsNewCustom1x[i] = - faiss::distance_single_code( + resultsNewCustom1x[i] = faiss::PQCodeDistance< + faiss::PQDecoder8, + faiss::SIMDLevel::NONE>:: + distance_single_code( subq, 8, lookup.data(), @@ -170,18 +180,21 @@ void test( for (size_t k = 0; k < 1000; k++) { #pragma omp parallel for schedule(guided) for (size_t i = 0; i < n; i += 4) { - faiss::distance_four_codes( - subq, - 8, - lookup.data(), - codes.data() + subq * (i + 0), - codes.data() + subq * (i + 1), - codes.data() + subq * (i + 2), - codes.data() + subq * (i + 3), - resultsNewCustom4x[i + 0], - resultsNewCustom4x[i + 1], - resultsNewCustom4x[i + 2], - resultsNewCustom4x[i + 3]); + faiss::PQCodeDistance< + faiss::PQDecoder8, + faiss::SIMDLevel::NONE>:: + distance_four_codes( + subq, + 8, + lookup.data(), + codes.data() + subq * (i + 0), + codes.data() + subq * (i + 1), + codes.data() + subq * (i + 2), + codes.data() + subq * (i + 3), + resultsNewCustom4x[i + 0], + resultsNewCustom4x[i + 1], + resultsNewCustom4x[i + 2], + resultsNewCustom4x[i + 3]); } }