Skip to content

Commit f767eb5

Browse files
committed
More C++17 build fixes, misc changes
1 parent 532c2ad commit f767eb5

File tree

2 files changed

+7 -5 lines changed

2 files changed

+7 -5 lines changed

applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,7 @@ class FMHAFwdEpilogue {
91 91
make_layout(select<1,0>(SGTileShapeO{}),
92 92
Stride<E<1>, E<0>>{})
93 93
));
94-
using ReduceFragARow = decltype(reduce<1>(ReduceFragA{}, sycl::plus{}));
94+
using ReduceFragARow = decltype(reduce<1>(ReduceFragA{}, sycl::plus<void>{}));
95 95

96 96
static auto default_tiled_copy_O_helper() {
97 97
if constexpr (ReduceK{} == _1{})

applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -118,12 +118,12 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
118 118
make_identity_tensor(select<0,1>(TiledMMA{}.tile_mnk()))));
119 119

120 120
using FragS = FragC<TiledMMAQK>;
121-
using FragSRow = decltype(reduce<1>(FragS{}, sycl::plus{}));
121+
using FragSRow = decltype(reduce<1>(FragS{}, sycl::plus<void>{}));
122 122
using ElementS = typename TiledMMAQK::ValTypeD;
123 123

124 124
using SingleFragA = FragC<TiledMMAPV>; // (atom val,q',v')
125 125
using FragA = expand_sg_fragment_t<SingleFragA, 1, VTiles>; // (atom val,q',v',VV)
126-
using FragARow = decltype(reduce<1>(FragA{}, sycl::plus{}));
126+
using FragARow = decltype(reduce<1>(FragA{}, sycl::plus<void>{}));
127 127
using ElementA = typename TiledMMAPV::ValTypeD;
128 128

129 129
static constexpr bool CausalMask = CausalMask_;
@@ -293,9 +293,11 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
293 293
if (check_remainder_k && K == blk_k1 - 1) {
294 294
FragSRow k_rem_mask;
295 295
int k = get<0>(tKgK(0,0,0,K,0)) + get_sub_group().get_local_id()[0];
296+
CUTLASS_PRAGMA_UNROLL
296 297
for (int i = 0; i < k_rem_mask.size(); i++, k += intel::sg_size) {
297 298
k_rem_mask(i) = (k < shape<0>(K_2D)) ? ElementS(sycl::nan(0u)) : ElementS(-INFINITY);
298 299
}
300+
CUTLASS_PRAGMA_UNROLL
299 301
for (int i = 0; i < tSrS.size(); i++) {
300 302
tSrS(i) = sycl::fmin(tSrS(i), broadcast<1>(k_rem_mask, tSrS, i));
301 303
}
@@ -309,7 +311,7 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
309 311
#if 0
310 312
reorder(tSrS, tArP);
311 313
#else
312-
for (int i = 0; i < tArP.size(); i++)
314+
for (int i = 0; i < tArP.size(); i++) // SYCL compiler currently is not correctly handling the above reorder.
313 315
tArP(i) = static_cast<typename TiledMMAPV::ValTypeA>(tSrS(i));
314 316
#endif
315 317

@@ -370,7 +372,7 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
370 372
}
371 373

372 374
/* Update sums */
373-
auto tS_bsum = reduce<1>(tS, sycl::plus{});
375+
auto tS_bsum = reduce<1>(tS, sycl::plus<void>{});
374 376
for (int i = 0; i < tS_sum.size(); i++)
375 377
tS_sum(i) += tS_bsum(i);
376 378
}

0 commit comments

Comments
 (0)