
Commit ad86ffa

Using templates to force unrolling and removing branches from hot loops

1 parent c131f98

2 files changed: +145 -35 lines changed
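The change in a nutshell, as a minimal illustrative sketch (the `unrolled` and `demo` helpers below are hypothetical, not code from this diff): expanding a fold over `std::index_sequence` replaces a runtime loop with one inlined call per index, and because each index arrives typed as a `std::integral_constant`, branches on it can be decided with `if constexpr` and disappear from the emitted code.

```cpp
#include <cstdint>
#include <utility>

// The fold over the comma operator emits one inlined call per index,
// so the "loop" is fully unrolled at compile time.
template<std::int64_t... Is, class F>
constexpr void unrolled(F &&f, std::integer_sequence<std::int64_t, Is...>) {
  (f(std::integral_constant<std::int64_t, Is>{}), ...);
}

void demo(float *out) {
  unrolled([&](auto i) {
    // i's value lives in its type, so this branch costs nothing at runtime
    if constexpr (i % 2 == 0) {
      out[i] = 1.0f;
    } else {
      out[i] = -1.0f;
    }
  }, std::make_integer_sequence<std::int64_t, 8>{});
}
```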

include/common/utils.h

Lines changed: 101 additions & 1 deletion
```diff
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <array>
+#include <cstdint>
 #include <tuple>
 #include <type_traits>
 #include <utility>
```
```diff
@@ -100,7 +101,7 @@ template<typename Tuple> auto extract_vals(const Tuple &t) {
 }
 
 template<typename Tuple, std::size_t... I>
-auto extract_seqs_impl(const Tuple &t, std::index_sequence<I...>) {
+auto extract_seqs_impl(const Tuple &, std::index_sequence<I...>) {
   using T = std::remove_reference_t<Tuple>;
   return std::make_tuple(typename std::tuple_element_t<I, T>::seq_type{}...);
 }
```
```diff
@@ -145,5 +146,104 @@ decltype(auto) dispatch(Func &&func, ParamTuple &&params, Args &&...args) {
   if constexpr (!std::is_void_v<result_t>) return caller.result;
 }
 
+// ----------------------------------------------
+// compile-time range length for [Start, Stop) with step Inc
+// ----------------------------------------------
+template<std::int64_t Start, std::int64_t Stop, std::int64_t Inc>
+inline constexpr std::int64_t compute_range_count = [] {
+  static_assert(Inc != 0, "Inc must not be zero");
+  if constexpr (Inc > 0) {
+    const std::int64_t d = Stop - Start;
+    return d > 0 ? ((d + Inc - 1) / Inc) : 0;
+  } else {
+    const std::int64_t d = Start - Stop;
+    const std::int64_t s = -Inc;
+    return d > 0 ? ((d + s - 1) / s) : 0;
+  }
+}();
+
+// ----------------------------------------------
+// low-level emitters (C++17, no templated lambdas)
+// ----------------------------------------------
+template<std::int64_t Start, std::int64_t Inc, class F, std::size_t... Is>
+constexpr void static_loop_impl_block(F &&f, std::index_sequence<Is...>) {
+  (f(std::integral_constant<std::int64_t, Start + static_cast<std::int64_t>(Is) * Inc>{}),
+   ...);
+}
+
+template<std::int64_t Start, std::int64_t Inc, std::int64_t K, class F, std::size_t... Bs>
+constexpr void static_loop_emit_all_blocks(F &&f, std::index_sequence<Bs...>) {
+  static_assert(K > 0, "UNROLL K must be positive");
+  (static_loop_impl_block<Start + static_cast<std::int64_t>(Bs) * K * Inc, Inc>(
+       f, std::make_index_sequence<static_cast<std::size_t>(K)>{}),
+   ...);
+}
+
+// ----------------------------------------------
+// static_loop
+// ----------------------------------------------
+template<std::int64_t Start, std::int64_t Stop, std::int64_t Inc = 1,
+         std::int64_t UNROLL = compute_range_count<Start, Stop, Inc>, class F>
+constexpr void static_loop(F &&f) {
+  static_assert(Inc != 0, "Inc must not be zero");
+
+  constexpr std::int64_t Count = compute_range_count<Start, Stop, Inc>;
+  if constexpr (Count == 0) {
+    // do nothing
+  } else {
+    constexpr std::int64_t k = (UNROLL > 0 ? UNROLL : Count);
+    static_assert(k > 0, "internal: k must be positive");
+    constexpr std::int64_t Blocks = Count / k;
+    constexpr std::int64_t Tail = Count % k;
+
+    if constexpr (Blocks > 0) {
+      static_loop_emit_all_blocks<Start, Inc, k>(
+          std::forward<F>(f),
+          std::make_index_sequence<static_cast<std::size_t>(Blocks)>{});
+    }
+    if constexpr (Tail > 0) {
+      constexpr std::int64_t TailStart = Start + Blocks * k * Inc;
+      static_loop_impl_block<TailStart, Inc>(
+          std::forward<F>(f), std::make_index_sequence<static_cast<std::size_t>(Tail)>{});
+    }
+  }
+}
+
+// convenience: Stop only => Start=0, Inc=1
+template<std::int64_t Stop, class F> constexpr void static_loop(F &&f) {
+  static_loop<0, Stop, 1, compute_range_count<0, Stop, 1>>(std::forward<F>(f));
+}
+
+// ----------------------------------------------
+// static_for wrappers expecting f.template operator()<I>()
+// keeps C++17 by adapting to integral_constant form above
+// ----------------------------------------------
+namespace detail {
+template<class F> struct as_template_index {
+  F *pf;
+  constexpr explicit as_template_index(F &f) : pf(&f) {}
+  template<std::int64_t I>
+  constexpr void operator()(std::integral_constant<std::int64_t, I>) const {
+    pf->template operator()<I>();
+  }
+};
+template<class T>
+using is_integral_cx = std::integral_constant<bool, std::is_integral_v<T>>;
+} // namespace detail
+
+template<auto Start, auto End, class F> constexpr void static_for(F &&f) {
+  static_assert(detail::is_integral_cx<decltype(Start)>::value &&
+                    detail::is_integral_cx<decltype(End)>::value,
+                "Start/End must be integral constant expressions");
+  static_assert(End >= Start, "End must be >= Start");
+
+  constexpr auto S = static_cast<std::int64_t>(Start);
+  constexpr auto E = static_cast<std::int64_t>(End);
+  static_loop<S, E, 1>(detail::as_template_index<F>{f});
+}
+
+template<auto End, class F> constexpr void static_for(F &&f) {
+  static_for<0, End>(std::forward<F>(f));
+}
 } // namespace common
 } // namespace finufft
```
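A usage sketch for the new helpers (`scale`, `Body`, and `bump` are hypothetical callers, and the include path assumes the repo's `include/` directory is on the include path). `compute_range_count` mirrors the trip count of `for (i = Start; i < Stop; i += Inc)`, negative steps included, and `static_loop` hands the body each index as an `integral_constant`:

```cpp
#include <cstdint>
#include "common/utils.h" // the header patched above

using finufft::common::compute_range_count;
using finufft::common::static_for;
using finufft::common::static_loop;

// trip counts match the equivalent runtime for-loops
static_assert(compute_range_count<0, 10, 3> == 4);  // 0, 3, 6, 9
static_assert(compute_range_count<10, 0, -4> == 3); // 10, 6, 2
static_assert(compute_range_count<5, 5, 1> == 0);   // empty range

void scale(float *v) {
  // unrolls to i = 0, 2, 4, 6; i is an integral_constant, so the body
  // can also use it in constexpr contexts
  static_loop<0, 8, 2>([&](auto i) { v[i] *= 2.0f; });
}

// static_for expects f.template operator()<I>(); in C++17 that means a
// struct with a templated call operator (templated lambdas are C++20)
struct Body {
  float *v;
  template<std::int64_t I> void operator()() const { v[I] += 1.0f; }
};

void bump(float *v) { static_for<0, 4>(Body{v}); }
```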

src/spreadinterp.cpp

Lines changed: 44 additions & 34 deletions
```diff
@@ -95,8 +95,10 @@ template<class T, uint8_t N> constexpr auto GetPaddedSIMDWidth() {
   // that minimizes the number of iterations
   return xsimd::make_sized_batch<T, find_optimal_simd_width<T, N>()>::type::size;
 }
+
 template<class T, uint8_t N>
 using PaddedSIMD = typename xsimd::make_sized_batch<T, GetPaddedSIMDWidth<T, N>()>::type;
+
 template<class T, uint8_t ns> constexpr auto get_padding() {
   // helper function to get the padding for the given number of elements
   // ns is known at compile time, rounds ns to the next multiple of the SIMD width
```
```diff
@@ -130,8 +132,10 @@ template<class T> uint8_t get_padding(uint8_t ns) {
   // that's why is hardcoded here
   return get_padding_helper<T, 2 * MAX_NSPREAD>(ns);
 }
+
 template<class T, uint8_t N>
 using BestSIMD = typename decltype(BestSIMDHelper<T, N, xsimd::batch<T>::size>())::type;
+
 template<class T, class V, size_t... Is>
 constexpr T generate_sequence_impl(V a, V b, index_sequence<Is...>) noexcept {
   // utility function to generate a sequence of a, b interleaved as function arguments
```
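Both padding helpers round `ns` up to the next multiple of the SIMD width; the same power-of-two mask trick reappears as `padded_ns` in the hunk below. A standalone sketch (`round_up` is a hypothetical name, and the mask form assumes the width is a power of two):

```cpp
#include <cstdint>

// round n up to the next multiple of width; width must be a power of two
// for the mask form to be exact
constexpr std::uint32_t round_up(std::uint32_t n, std::uint32_t width) {
  return (n + width - 1) & ~(width - 1);
}

static_assert(round_up(7, 4) == 8);
static_assert(round_up(8, 4) == 8); // already aligned: unchanged
static_assert(round_up(9, 8) == 16);
```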
```diff
@@ -278,43 +282,46 @@ static void evaluate_kernel_vector(T *ker, T *args,
 }
 
 template<typename T, uint8_t w, uint8_t upsampfact,
-         class simd_type =
-             xsimd::make_sized_batch_t<T, find_optimal_simd_width<T, w>()>> // aka ns
+         class simd_type = xsimd::make_sized_batch_t<T, find_optimal_simd_width<T, w>()>>
 static FINUFFT_ALWAYS_INLINE void eval_kernel_vec_Horner(T *FINUFFT_RESTRICT ker, T x,
                                                          const finufft_spread_opts &opts
-                                                         [[maybe_unused]]) noexcept
-/* Fill ker[] with Horner piecewise poly approx to [-w/2,w/2] ES kernel eval at
-x_j = x + j, for j=0,..,w-1. Thus x in [-w/2,-w/2+1]. w is aka ns.
-This is the current evaluation method, since it's faster (except i7 w=16).
-Two upsampfacs implemented. Params must match ref formula. Barnett 4/24/18 */
+                                                         [[maybe_unused]]) noexcept {
+  /* Fill ker[] with Horner piecewise poly approx to [-w/2,w/2] ES kernel eval at
+     x_j = x + j, for j=0,..,w-1. Thus x in [-w/2,-w/2+1]. w is aka ns.
+     This is the current evaluation method, since it's faster (except i7 w=16).
+     Two upsampfacs implemented. Params must match ref formula. Barnett 4/24/18 */
+  // scale so local grid offset z in [-1,1]
+  const T z = std::fma(T(2.0), x, T(w - 1));
+  using arch_t = typename simd_type::arch_type;
+  static constexpr auto alignment = arch_t::alignment();
+  static constexpr auto simd_size = simd_type::size;
+  static constexpr auto padded_ns = (w + simd_size - 1) & ~(simd_size - 1);
 
-{
-  // scale so local grid offset z in[-1,1]
-  const T z = std::fma(T(2.0), x, T(w - 1));
-  using arch_t = typename simd_type::arch_type;
-  static constexpr auto alignment = arch_t::alignment();
-  static constexpr auto simd_size = simd_type::size;
-  static constexpr auto padded_ns = (w + simd_size - 1) & ~(simd_size - 1);
   static constexpr auto horner_coeffs = []() constexpr noexcept {
     if constexpr (upsampfact == 200) {
       return get_horner_coeffs_200<T, w>();
-    } else if constexpr (upsampfact == 125) {
+    } else { // upsampfact == 125
+      // this way we can have a static assert that things went wrong
+      static_assert(upsampfact == 125, "Unsupported upsampfact");
       return get_horner_coeffs_125<T, w>();
     }
   }();
-  static constexpr auto nc = horner_coeffs.size();
-  static constexpr auto use_ker_sym = (simd_size < w);
+  static constexpr auto nc = horner_coeffs.size();
 
   alignas(alignment) static constexpr auto padded_coeffs =
       pad_2D_array_with_zeros<T, nc, w, padded_ns>(horner_coeffs);
 
   // use kernel symmetry trick if w > simd_size
+  static constexpr bool use_ker_sym = (simd_size < w);
+
+  const simd_type zv{z};
+
   if constexpr (use_ker_sym) {
     static constexpr uint8_t tail = w % simd_size;
     static constexpr uint8_t if_odd_degree = ((nc + 1) % 2);
-    static constexpr uint8_t offset_start = tail ? w - tail : w - simd_size;
     static constexpr uint8_t end_idx = (w + (tail > 0)) / 2;
-    const simd_type zv{z};
+    static constexpr uint8_t offset_start = tail ? (w - tail) : (w - simd_size);
+
     const auto z2v = zv * zv;
 
     // some xsimd constant for shuffle or inverse
```
```diff
@@ -328,30 +335,32 @@ Two upsampfacs implemented. Params must match ref formula. Barnett 4/24/18 */
       }
     }();
 
-    // process simd vecs
     simd_type k_prev, k_sym{0};
-    for (uint8_t i{0}, offset = offset_start; i < end_idx;
-         i += simd_size, offset -= simd_size) {
-      auto k_odd = [i]() constexpr noexcept {
+
+    static_loop<0, end_idx, simd_size>([&]([[maybe_unused]] const auto i) {
+      constexpr auto offset = static_cast<uint8_t>(offset_start - i);
+      auto k_odd = [i]() constexpr noexcept {
         if constexpr (if_odd_degree) {
           return simd_type::load_aligned(padded_coeffs[0].data() + i);
         } else {
           return simd_type{0};
         }
       }();
       auto k_even = simd_type::load_aligned(padded_coeffs[if_odd_degree].data() + i);
-      for (uint8_t j{1 + if_odd_degree}; j < nc; j += 2) {
+
+      static_loop<1 + if_odd_degree, static_cast<std::int64_t>(nc), 2>([&](const auto j) {
         const auto cji_odd = simd_type::load_aligned(padded_coeffs[j].data() + i);
         const auto cji_even = simd_type::load_aligned(padded_coeffs[j + 1].data() + i);
         k_odd = xsimd::fma(k_odd, z2v, cji_odd);
         k_even = xsimd::fma(k_even, z2v, cji_even);
-      }
+      });
+
       // left part
       xsimd::fma(k_odd, zv, k_even).store_aligned(ker + i);
+
       // right part symmetric to the left part
-      if (offset >= end_idx) {
+      if constexpr (offset >= end_idx) {
         if constexpr (tail) {
-          // to use aligned store, we need shuffle the previous k_sym and current k_sym
           k_prev = k_sym;
           k_sym = xsimd::fnma(k_odd, zv, k_even);
           xsimd::shuffle(k_sym, k_prev, shuffle_batch).store_aligned(ker + offset);
```
```diff
@@ -360,20 +369,21 @@ Two upsampfacs implemented. Params must match ref formula. Barnett 4/24/18 */
               .store_aligned(ker + offset);
         }
       }
-    }
+    });
+
   } else {
-    const simd_type zv(z);
-    for (uint8_t i = 0; i < w; i += simd_size) {
+    static_loop<0, w, simd_size>([&](const auto i) {
       auto k = simd_type::load_aligned(padded_coeffs[0].data() + i);
-      for (uint8_t j = 1; j < nc; ++j) {
+
+      static_loop<1, static_cast<std::int64_t>(nc), 1>([&](const auto j) {
         const auto cji = simd_type::load_aligned(padded_coeffs[j].data() + i);
         k = xsimd::fma(k, zv, cji);
-      }
+      });
+
       k.store_aligned(ker + i);
-    }
+    });
   }
 }
-
 template<typename T, uint8_t ns>
 static void interp_line_wrap(T *FINUFFT_RESTRICT target, const T *du, const T *ker,
                              const BIGINT i1, const UBIGINT N1) {
```

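The payoff of `static_loop` in the hunks above: the body receives `i` as a `std::integral_constant`, so `offset` can be declared `constexpr` even though `i` is a function parameter, and the former runtime test `if (offset >= end_idx)` becomes `if constexpr`, removing the branch from every unrolled iteration. A reduced sketch of why that compiles (`for_each_index` and `fill` are hypothetical stand-ins for `static_loop` and the kernel evaluator):

```cpp
#include <cstdint>
#include <utility>

// stand-in for static_loop: one inlined call per index
template<class F, std::size_t... Is>
constexpr void for_each_index(F &&f, std::index_sequence<Is...>) {
  (f(std::integral_constant<std::int64_t, Is>{}), ...);
}

void fill(float *ker) { // ker must hold at least 7 floats
  constexpr std::int64_t end_idx = 4;
  for_each_index(
      [&](auto i) {
        // i carries its value in its type, so offset is a constant
        // expression even though i is a runtime parameter
        constexpr auto offset = 6 - i; // cf. offset_start - i in the diff
        ker[i] = 1.0f;
        if constexpr (offset >= end_idx) { // resolved at compile time
          ker[offset] = 1.0f;
        }
      },
      std::make_index_sequence<4>{});
}
```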