|
#include "x86simdsort.h"
#include "x86simdsort-internal.h"
#include "x86simdsort-scalar.h"
#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <iostream>
#include <string>
#include <string_view>
#include <vector>
| 7 | + |
| 8 | +static int check_cpu_feature_support(std::string_view cpufeature) |
| 9 | +{ |
| 10 | + if (cpufeature == "avx512_spr") |
| 11 | + return __builtin_cpu_supports("avx512f") |
| 12 | + && __builtin_cpu_supports("avx512fp16") |
| 13 | + && __builtin_cpu_supports("avx512vbmi2"); |
| 14 | + else if (cpufeature == "avx512_icl") |
| 15 | + return __builtin_cpu_supports("avx512f") |
| 16 | + && __builtin_cpu_supports("avx512vbmi2") |
| 17 | + && __builtin_cpu_supports("avx512bw") |
| 18 | + && __builtin_cpu_supports("avx512vl"); |
| 19 | + else if (cpufeature == "avx512_skx") |
| 20 | + return __builtin_cpu_supports("avx512f") |
| 21 | + && __builtin_cpu_supports("avx512dq") |
| 22 | + && __builtin_cpu_supports("avx512vl"); |
| 23 | + else if (cpufeature == "avx2") |
| 24 | + return __builtin_cpu_supports("avx2"); |
| 25 | + |
| 26 | + return 0; |
| 27 | +} |
| 28 | + |
| 29 | +std::string_view |
| 30 | +find_preferred_cpu(std::initializer_list<std::string_view> cpulist) |
| 31 | +{ |
| 32 | + for (auto cpu : cpulist) { |
| 33 | + if (check_cpu_feature_support(cpu)) return cpu; |
| 34 | + } |
| 35 | + return "scalar"; |
| 36 | +} |
| 37 | + |
/* Compile-time predicate: true when any entry of `cpulist` contains
 * `cpurequested` as a substring (e.g. "avx512" matches "avx512_skx").
 * Kept as a plain constexpr loop: callers evaluate it in `if constexpr`,
 * and the standard algorithms are not constexpr under C++17. */
constexpr bool
dispatch_requested(std::string_view cpurequested,
                   std::initializer_list<std::string_view> cpulist)
{
    bool requested = false;
    for (const std::string_view candidate : cpulist) {
        requested = requested
                || candidate.find(cpurequested) != std::string_view::npos;
    }
    return requested;
}
| 47 | + |
/* Two-level token-pasting helper: CAT fully expands its arguments before
 * gluing them together (required when `a` or `b` is itself a macro, as in
 * the DISPATCH resolver names below). */
#define CAT_(a, b) a ## b
#define CAT(a, b) CAT_(a, b)
| 50 | + |
/* Emits a per-TYPE function pointer (selected at load time by the
 * DISPATCH resolver) and the public qsort<TYPE> specialization that
 * forwards to it.  Uses nullptr (C++11 idiom) instead of the C macro
 * NULL; the resolver always assigns the pointer before use. */
#define DECLARE_INTERNAL_qsort(TYPE) \
    static void (*internal_qsort##TYPE)(TYPE *, int64_t) = nullptr; \
    template <> \
    void qsort(TYPE *arr, int64_t arrsize) \
    { \
        (*internal_qsort##TYPE)(arr, arrsize); \
    }
| 58 | + |
/* Emits a per-TYPE function pointer (selected at load time by the
 * DISPATCH resolver) and the public qselect<TYPE> specialization that
 * forwards to it.  Uses nullptr (C++11 idiom) instead of the C macro
 * NULL; the resolver always assigns the pointer before use. */
#define DECLARE_INTERNAL_qselect(TYPE) \
    static void (*internal_qselect##TYPE)(TYPE *, int64_t, int64_t, bool) = nullptr; \
    template <> \
    void qselect(TYPE *arr, int64_t k, int64_t arrsize, bool hasnan) \
    { \
        (*internal_qselect##TYPE)(arr, k, arrsize, hasnan); \
    }
| 66 | + |
/* Emits a per-TYPE function pointer (selected at load time by the
 * DISPATCH resolver) and the public partial_qsort<TYPE> specialization
 * that forwards to it.  Uses nullptr (C++11 idiom) instead of the C
 * macro NULL; the resolver always assigns the pointer before use. */
#define DECLARE_INTERNAL_partial_qsort(TYPE) \
    static void (*internal_partial_qsort##TYPE)(TYPE *, int64_t, int64_t, bool) = nullptr; \
    template <> \
    void partial_qsort(TYPE *arr, int64_t k, int64_t arrsize, bool hasnan) \
    { \
        (*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan); \
    }
| 74 | + |
/* Emits a per-TYPE function pointer (selected at load time by the
 * DISPATCH resolver) and the public argsort<TYPE> specialization that
 * forwards to it, returning the sorted index vector.  Uses nullptr
 * (C++11 idiom) instead of the C macro NULL; the resolver always
 * assigns the pointer before use. */
#define DECLARE_INTERNAL_argsort(TYPE) \
    static std::vector<int64_t> (*internal_argsort##TYPE)(TYPE *, int64_t) = nullptr; \
    template <> \
    std::vector<int64_t> argsort(TYPE *arr, int64_t arrsize) \
    { \
        return (*internal_argsort##TYPE)(arr, arrsize); \
    }
| 82 | + |
/* Emits a per-TYPE function pointer (selected at load time by the
 * DISPATCH resolver) and the public argselect<TYPE> specialization that
 * forwards to it, returning the index vector.  Uses nullptr (C++11
 * idiom) instead of the C macro NULL; the resolver always assigns the
 * pointer before use. */
#define DECLARE_INTERNAL_argselect(TYPE) \
    static std::vector<int64_t> (*internal_argselect##TYPE)(TYPE *, int64_t, int64_t) = nullptr; \
    template <> \
    std::vector<int64_t> argselect(TYPE *arr, int64_t k, int64_t arrsize) \
    { \
        return (*internal_argselect##TYPE)(arr, k, arrsize); \
    }
| 90 | + |
/* runtime dispatch mechanism:
 * DISPATCH(func, TYPE, cpus...) expands the matching DECLARE_INTERNAL_
 * macro, then emits a constructor-attribute resolver that runs before
 * main().  The resolver first points the function pointer at the scalar
 * fallback, then upgrades it to the best ISA variant the running CPU
 * supports (per find_preferred_cpu over the requested cpu list).  The
 * `if constexpr` guards compile out references to xss::avx512 / xss::avx2
 * when that ISA is never requested for this instantiation.
 * FIX: the avx2 branch is no longer `else if constexpr` — with the old
 * form, a cpu list containing both an avx512 target and "avx2" would
 * never reach the avx2 check on a machine lacking avx512, leaving the
 * scalar fallback installed.  The branches now cascade at runtime.
 * (No current call site in this file lists both ISAs, so existing
 * behavior is unchanged.) */
#define DISPATCH(func, TYPE, ...) \
    DECLARE_INTERNAL_##func(TYPE) \
    static __attribute__((constructor)) void CAT(CAT(resolve_, func), TYPE)(void) \
    { \
        CAT(CAT(internal_, func), TYPE) = &xss::scalar::func<TYPE>; \
        __builtin_cpu_init(); \
        std::string_view preferred_cpu = find_preferred_cpu({__VA_ARGS__}); \
        if constexpr (dispatch_requested("avx512", {__VA_ARGS__})) { \
            if (preferred_cpu.find("avx512") != std::string_view::npos) { \
                CAT(CAT(internal_, func), TYPE) = &xss::avx512::func<TYPE>; \
                return; \
            } \
        } \
        if constexpr (dispatch_requested("avx2", {__VA_ARGS__})) { \
            if (preferred_cpu.find("avx2") != std::string_view::npos) { \
                CAT(CAT(internal_, func), TYPE) = &xss::avx2::func<TYPE>; \
                return; \
            } \
        } \
    }
| 112 | + |
| 113 | + |
| 114 | + |
namespace x86simdsort {
#ifdef __FLT16_MAX__
/* _Float16 kernels: qsort/qselect/partial_qsort dispatch only to the
 * "avx512_spr" target (which requires avx512fp16); argsort/argselect
 * pass "none", which matches no CPU and therefore keeps the scalar
 * fallback installed by the resolver. */
DISPATCH(qsort, _Float16, "avx512_spr")
DISPATCH(qselect, _Float16, "avx512_spr")
DISPATCH(partial_qsort, _Float16, "avx512_spr")
DISPATCH(argsort, _Float16, "none")
DISPATCH(argselect, _Float16, "none")
#endif

/* Instantiates `func` for every supported element type, choosing the
 * dispatch target list by element width (16/32/64-bit). */
#define DISPATCH_ALL(func, ISA_16BIT, ISA_32BIT, ISA_64BIT) \
    DISPATCH(func, uint16_t, ISA_16BIT)\
    DISPATCH(func, int16_t, ISA_16BIT)\
    DISPATCH(func, float, ISA_32BIT)\
    DISPATCH(func, int32_t, ISA_32BIT)\
    DISPATCH(func, uint32_t, ISA_32BIT)\
    DISPATCH(func, int64_t, ISA_64BIT)\
    DISPATCH(func, uint64_t, ISA_64BIT)\
    DISPATCH(func, double, ISA_64BIT)\

/* 16-bit kernels need avx512_icl (vbmi2/bw); 32/64-bit need avx512_skx.
 * argsort/argselect have no 16-bit target here ("none" => scalar). */
DISPATCH_ALL(qsort, ("avx512_icl"), ("avx512_skx"), ("avx512_skx"))
DISPATCH_ALL(qselect, ("avx512_icl"), ("avx512_skx"), ("avx512_skx"))
DISPATCH_ALL(partial_qsort, ("avx512_icl"), ("avx512_skx"), ("avx512_skx"))
DISPATCH_ALL(argsort, "none", "avx512_skx", "avx512_skx")
DISPATCH_ALL(argselect, "none", "avx512_skx", "avx512_skx")

} // namespace x86simdsort
0 commit comments