Skip to content

Commit ad36900

Browse files
author
Raghuveer Devulapalli
committed
Add a way to build a shared library with dynamic dispatch
1 parent 527248c commit ad36900

11 files changed

+417
-2
lines changed

_clang-format

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ KeepEmptyLinesAtTheStartOfBlocks: true
6363
MacroBlockBegin: ''
6464
MacroBlockEnd: ''
6565
MaxEmptyLinesToKeep: 1
66-
NamespaceIndentation: None
66+
NamespaceIndentation: Inner
6767
PenaltyBreakAssignment: 2
6868
PenaltyBreakBeforeFirstCallParameter: 19
6969
PenaltyBreakComment: 300

lib/meson.build

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
libtargets = []
2+
3+
if cpp.has_argument('-march=skylake-avx512')
4+
libtargets += static_library('libskx',
5+
files(
6+
'x86simdsort-skx.cpp',
7+
),
8+
include_directories : [src],
9+
cpp_args : ['-O3', '-mavx512f', '-mavx512dq', '-mavx512vl'],
10+
)
11+
endif
12+
13+
if cpp.has_argument('-march=icelake-client')
14+
libtargets += static_library('libicl',
15+
files(
16+
'x86simdsort-icl.cpp',
17+
),
18+
include_directories : [src],
19+
cpp_args : ['-O3', '-mavx512f', '-mavx512vbmi2', '-mavx512bw', '-mavx512vl', '-mf16c'],
20+
)
21+
endif
22+
23+
if cancompilefp16
24+
libtargets += static_library('libspr',
25+
files(
26+
'x86simdsort-spr.cpp',
27+
),
28+
include_directories : [src],
29+
cpp_args : ['-O3', '-mavx512f', '-mavx512fp16', '-mavx512vbmi2'],
30+
)
31+
endif

lib/x86simdsort-icl.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// ICL specific routines:
2+
#include "avx512-16bit-qsort.hpp"
3+
#include "x86simdsort-internal.h"
4+
5+
namespace xss {
6+
namespace avx512 {
7+
template <>
8+
void qsort(uint16_t *arr, int64_t size)
9+
{
10+
avx512_qsort(arr, size);
11+
}
12+
template <>
13+
void qselect(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
14+
{
15+
avx512_qselect(arr, k, arrsize, hasnan);
16+
}
17+
template <>
18+
void partial_qsort(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
19+
{
20+
avx512_partial_qsort(arr, k, arrsize, hasnan);
21+
}
22+
template <>
23+
void qsort(int16_t *arr, int64_t size)
24+
{
25+
avx512_qsort(arr, size);
26+
}
27+
template <>
28+
void qselect(int16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
29+
{
30+
avx512_qselect(arr, k, arrsize, hasnan);
31+
}
32+
template <>
33+
void partial_qsort(int16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
34+
{
35+
avx512_partial_qsort(arr, k, arrsize, hasnan);
36+
}
37+
} // namespace avx512
38+
} // namespace xss

lib/x86simdsort-internal.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#ifndef XSS_ALL_METHODS
2+
#define XSS_ALL_METHODS
3+
#include <stdint.h>
4+
#include <vector>
5+
6+
namespace xss {
7+
namespace avx512 {
8+
// quicksort
9+
template <typename T>
10+
void qsort(T *arr, int64_t arrsize);
11+
// quickselect
12+
template <typename T>
13+
void qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
14+
// partial sort
15+
template <typename T>
16+
void partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
17+
// argsort
18+
template <typename T>
19+
std::vector<int64_t> argsort(T *arr, int64_t arrsize);
20+
// argselect
21+
template <typename T>
22+
std::vector<int64_t> argselect(T *arr, int64_t k, int64_t arrsize);
23+
} // namespace avx512
24+
namespace avx2 {
25+
// quicksort
26+
template <typename T>
27+
void qsort(T *arr, int64_t arrsize);
28+
// quickselect
29+
template <typename T>
30+
void qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
31+
// partial sort
32+
template <typename T>
33+
void partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
34+
// argsort
35+
template <typename T>
36+
std::vector<int64_t> argsort(T *arr, int64_t arrsize);
37+
// argselect
38+
template <typename T>
39+
std::vector<int64_t> argselect(T *arr, int64_t k, int64_t arrsize);
40+
} // namespace avx2
41+
namespace scalar {
42+
// quicksort
43+
template <typename T>
44+
void qsort(T *arr, int64_t arrsize);
45+
// quickselect
46+
template <typename T>
47+
void qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
48+
// partial sort
49+
template <typename T>
50+
void partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
51+
// argsort
52+
template <typename T>
53+
std::vector<int64_t> argsort(T *arr, int64_t arrsize);
54+
// argselect
55+
template <typename T>
56+
std::vector<int64_t> argselect(T *arr, int64_t k, int64_t arrsize);
57+
} // namespace scalar
58+
} // namespace xss
59+
#endif

lib/x86simdsort-scalar.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#include <algorithm>
2+
#include <numeric>
3+
namespace xss {
4+
namespace scalar {
5+
/* TODO: handle NAN */
6+
template <typename T>
7+
void qsort(T *arr, int64_t arrsize)
8+
{
9+
std::sort(arr, arr + arrsize);
10+
}
11+
template <typename T>
12+
void qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan)
13+
{
14+
std::nth_element(arr, arr + k, arr + arrsize);
15+
}
16+
template <typename T>
17+
void partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan)
18+
{
19+
std::partial_sort(arr, arr + k, arr + arrsize);
20+
}
21+
template <typename T>
22+
std::vector<int64_t> argsort(T *arr, int64_t arrsize)
23+
{
24+
std::vector<int64_t> arg(arrsize);
25+
std::iota(arg.begin(), arg.end(), 0);
26+
std::sort(arg.begin(),
27+
arg.end(),
28+
[arr](int64_t left, int64_t right) -> bool {
29+
return arr[left] < arr[right];
30+
});
31+
return arg;
32+
}
33+
template <typename T>
34+
std::vector<int64_t> argselect(T *arr, int64_t k, int64_t arrsize)
35+
{
36+
std::vector<int64_t> arg(arrsize);
37+
std::iota(arg.begin(), arg.end(), 0);
38+
std::nth_element(arg.begin(),
39+
arg.begin() + k,
40+
arg.end(),
41+
[arr](int64_t left, int64_t right) -> bool {
42+
return arr[left] < arr[right];
43+
});
44+
return arg;
45+
}
46+
47+
} // namespace scalar
48+
} // namespace xss

lib/x86simdsort-skx.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// SKX specific routines:
2+
#include "avx512-32bit-qsort.hpp"
3+
#include "avx512-64bit-argsort.hpp"
4+
#include "avx512-64bit-qsort.hpp"
5+
#include "x86simdsort-internal.h"
6+
7+
#define DEFINE_ALL_METHODS(type) \
8+
template <> \
9+
void qsort(type *arr, int64_t arrsize) \
10+
{ \
11+
avx512_qsort(arr, arrsize); \
12+
} \
13+
template <> \
14+
void qselect(type *arr, int64_t k, int64_t arrsize, bool hasnan) \
15+
{ \
16+
avx512_qselect(arr, k, arrsize, hasnan); \
17+
} \
18+
template <> \
19+
void partial_qsort(type *arr, int64_t k, int64_t arrsize, bool hasnan) \
20+
{ \
21+
avx512_partial_qsort(arr, k, arrsize, hasnan); \
22+
} \
23+
template <> \
24+
std::vector<int64_t> argsort(type *arr, int64_t arrsize) \
25+
{ \
26+
return avx512_argsort(arr, arrsize); \
27+
} \
28+
template <> \
29+
std::vector<int64_t> argselect(type *arr, int64_t k, int64_t arrsize) \
30+
{ \
31+
return avx512_argselect(arr, k, arrsize); \
32+
}
33+
34+
namespace xss {
35+
namespace avx512 {
36+
DEFINE_ALL_METHODS(uint32_t)
37+
DEFINE_ALL_METHODS(int32_t)
38+
DEFINE_ALL_METHODS(float)
39+
DEFINE_ALL_METHODS(uint64_t)
40+
DEFINE_ALL_METHODS(int64_t)
41+
DEFINE_ALL_METHODS(double)
42+
} // namespace avx512
43+
} // namespace xss

lib/x86simdsort-spr.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// SPR specific routines:
2+
#include "avx512fp16-16bit-qsort.hpp"
3+
#include "x86simdsort-internal.h"
4+
5+
namespace xss {
6+
namespace avx512 {
7+
template <>
8+
void qsort(_Float16 *arr, int64_t size)
9+
{
10+
avx512_qsort(arr, size);
11+
}
12+
template <>
13+
void qselect(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan)
14+
{
15+
avx512_qselect(arr, k, arrsize, hasnan);
16+
}
17+
template <>
18+
void partial_qsort(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan)
19+
{
20+
avx512_partial_qsort(arr, k, arrsize, hasnan);
21+
}
22+
} // namespace avx512
23+
} // namespace xss

lib/x86simdsort.cpp

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
#include "x86simdsort.h"
2+
#include "x86simdsort-internal.h"
3+
#include "x86simdsort-scalar.h"
4+
#include <algorithm>
5+
#include <iostream>
6+
#include <string>
7+
8+
static int check_cpu_feature_support(std::string_view cpufeature)
9+
{
10+
if (cpufeature == "avx512_spr")
11+
return __builtin_cpu_supports("avx512f")
12+
&& __builtin_cpu_supports("avx512fp16")
13+
&& __builtin_cpu_supports("avx512vbmi2");
14+
else if (cpufeature == "avx512_icl")
15+
return __builtin_cpu_supports("avx512f")
16+
&& __builtin_cpu_supports("avx512vbmi2")
17+
&& __builtin_cpu_supports("avx512bw")
18+
&& __builtin_cpu_supports("avx512vl");
19+
else if (cpufeature == "avx512_skx")
20+
return __builtin_cpu_supports("avx512f")
21+
&& __builtin_cpu_supports("avx512dq")
22+
&& __builtin_cpu_supports("avx512vl");
23+
else if (cpufeature == "avx2")
24+
return __builtin_cpu_supports("avx2");
25+
26+
return 0;
27+
}
28+
29+
std::string_view
30+
find_preferred_cpu(std::initializer_list<std::string_view> cpulist)
31+
{
32+
for (auto cpu : cpulist) {
33+
if (check_cpu_feature_support(cpu)) return cpu;
34+
}
35+
return "scalar";
36+
}
37+
38+
constexpr bool
39+
dispatch_requested(std::string_view cpurequested,
40+
std::initializer_list<std::string_view> cpulist)
41+
{
42+
for (auto cpu : cpulist) {
43+
if (cpu.find(cpurequested) != std::string_view::npos) return true;
44+
}
45+
return false;
46+
}
47+
48+
#define CAT_(a, b) a ## b
49+
#define CAT(a, b) CAT_(a, b)
50+
51+
#define DECLARE_INTERNAL_qsort(TYPE) \
52+
static void (*internal_qsort##TYPE)(TYPE *, int64_t) = NULL; \
53+
template <> \
54+
void qsort(TYPE *arr, int64_t arrsize) \
55+
{ \
56+
(*internal_qsort##TYPE)(arr, arrsize); \
57+
}
58+
59+
#define DECLARE_INTERNAL_qselect(TYPE) \
60+
static void (*internal_qselect##TYPE)(TYPE *, int64_t, int64_t, bool) = NULL; \
61+
template <> \
62+
void qselect(TYPE *arr, int64_t k, int64_t arrsize, bool hasnan) \
63+
{ \
64+
(*internal_qselect##TYPE)(arr, k, arrsize, hasnan); \
65+
}
66+
67+
#define DECLARE_INTERNAL_partial_qsort(TYPE) \
68+
static void (*internal_partial_qsort##TYPE)(TYPE *, int64_t, int64_t, bool) = NULL; \
69+
template <> \
70+
void partial_qsort(TYPE *arr, int64_t k, int64_t arrsize, bool hasnan) \
71+
{ \
72+
(*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan); \
73+
}
74+
75+
#define DECLARE_INTERNAL_argsort(TYPE) \
76+
static std::vector<int64_t> (*internal_argsort##TYPE)(TYPE *, int64_t) = NULL; \
77+
template <> \
78+
std::vector<int64_t> argsort(TYPE *arr, int64_t arrsize) \
79+
{ \
80+
return (*internal_argsort##TYPE)(arr, arrsize); \
81+
}
82+
83+
#define DECLARE_INTERNAL_argselect(TYPE) \
84+
static std::vector<int64_t> (*internal_argselect##TYPE)(TYPE *, int64_t, int64_t) = NULL; \
85+
template <> \
86+
std::vector<int64_t> argselect(TYPE *arr, int64_t k, int64_t arrsize) \
87+
{ \
88+
return (*internal_argselect##TYPE)(arr, k, arrsize); \
89+
}
90+
91+
/* runtime dispatch mechanism */
92+
#define DISPATCH(func, TYPE, ...) \
93+
DECLARE_INTERNAL_##func(TYPE) \
94+
static __attribute__((constructor)) void CAT(CAT(resolve_, func), TYPE)(void) \
95+
{ \
96+
CAT(CAT(internal_, func), TYPE) = &xss::scalar::func<TYPE>; \
97+
__builtin_cpu_init(); \
98+
std::string_view preferred_cpu = find_preferred_cpu({__VA_ARGS__}); \
99+
if constexpr (dispatch_requested("avx512", {__VA_ARGS__})) { \
100+
if (preferred_cpu.find("avx512") != std::string_view::npos) { \
101+
CAT(CAT(internal_, func), TYPE) = &xss::avx512::func<TYPE>; \
102+
return; \
103+
} \
104+
} \
105+
else if constexpr (dispatch_requested("avx2", {__VA_ARGS__})) { \
106+
if (preferred_cpu.find("avx2") != std::string_view::npos) { \
107+
CAT(CAT(internal_, func), TYPE) = &xss::avx2::func<TYPE>; \
108+
return; \
109+
} \
110+
} \
111+
}
112+
113+
114+
115+
namespace x86simdsort {
116+
#ifdef __FLT16_MAX__
117+
DISPATCH(qsort, _Float16, "avx512_spr")
118+
DISPATCH(qselect, _Float16, "avx512_spr")
119+
DISPATCH(partial_qsort, _Float16, "avx512_spr")
120+
DISPATCH(argsort, _Float16, "none")
121+
DISPATCH(argselect, _Float16, "none")
122+
#endif
123+
124+
#define DISPATCH_ALL(func, ISA_16BIT, ISA_32BIT, ISA_64BIT) \
125+
DISPATCH(func, uint16_t, ISA_16BIT)\
126+
DISPATCH(func, int16_t, ISA_16BIT)\
127+
DISPATCH(func, float, ISA_32BIT)\
128+
DISPATCH(func, int32_t, ISA_32BIT)\
129+
DISPATCH(func, uint32_t, ISA_32BIT)\
130+
DISPATCH(func, int64_t, ISA_64BIT)\
131+
DISPATCH(func, uint64_t, ISA_64BIT)\
132+
DISPATCH(func, double, ISA_64BIT)\
133+
134+
DISPATCH_ALL(qsort, ("avx512_icl"), ("avx512_skx"), ("avx512_skx"))
135+
DISPATCH_ALL(qselect, ("avx512_icl"), ("avx512_skx"), ("avx512_skx"))
136+
DISPATCH_ALL(partial_qsort, ("avx512_icl"), ("avx512_skx"), ("avx512_skx"))
137+
DISPATCH_ALL(argsort, "none", "avx512_skx", "avx512_skx")
138+
DISPATCH_ALL(argselect, "none", "avx512_skx", "avx512_skx")
139+
140+
} // namespace simdsort

0 commit comments

Comments
 (0)