Skip to content

Commit 54d4bd6

Browse files
authored
[Bugfix] Support 16bits shfl_sync (#1169)
* Add type-safe warp shuffle helpers for 16-bit float types in common.h - Introduced generic passthrough functions for warp shuffle operations: `shfl_xor_sync`, `shfl_down_sync`, `shfl_up_sync`, and `shfl_sync`. - Added specializations for `cutlass::half_t` and `cutlass::bfloat16_t` to ensure type safety during shuffle operations. - Updated `reduce.h` to utilize the new shuffle functions, enhancing code clarity and maintainability. * lint fix
1 parent 7a80b6d commit 54d4bd6

File tree

2 files changed

+96
-8
lines changed

2 files changed

+96
-8
lines changed

src/tl_templates/cuda/common.h

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,3 +379,91 @@ namespace cutlass {
379379
TL_DEVICE
380380
bfloat16_t fast_exp(bfloat16_t x) { return ::hexp(x); }
381381
} // namespace cutlass
382+
383+
//
384+
// Type-safe warp shuffle helpers for 16-bit float types
385+
// These wrappers avoid relying on implicit conversions that may be disallowed
386+
// (e.g., converting float -> cutlass::bfloat16_t) by explicitly promoting to
387+
// float for the shuffle and then down-converting.
388+
//
389+
namespace tl {
390+
391+
// Generic passthroughs
392+
template <typename T>
393+
TL_DEVICE T shfl_xor_sync(unsigned mask, T val, int laneMask) {
394+
return __shfl_xor_sync(mask, val, laneMask);
395+
}
396+
397+
template <typename T>
398+
TL_DEVICE T shfl_down_sync(unsigned mask, T val, int delta) {
399+
return __shfl_down_sync(mask, val, delta);
400+
}
401+
402+
template <typename T>
403+
TL_DEVICE T shfl_up_sync(unsigned mask, T val, int delta) {
404+
return __shfl_up_sync(mask, val, delta);
405+
}
406+
407+
template <typename T> TL_DEVICE T shfl_sync(unsigned mask, T val, int srcLane) {
408+
return __shfl_sync(mask, val, srcLane);
409+
}
410+
411+
// Specializations for cutlass::half_t
412+
template <>
413+
TL_DEVICE half_t shfl_xor_sync(unsigned mask, half_t val, int laneMask) {
414+
float f = static_cast<float>(val);
415+
float r = __shfl_xor_sync(mask, f, laneMask);
416+
return half_t(r);
417+
}
418+
419+
template <>
420+
TL_DEVICE half_t shfl_down_sync(unsigned mask, half_t val, int delta) {
421+
float f = static_cast<float>(val);
422+
float r = __shfl_down_sync(mask, f, delta);
423+
return half_t(r);
424+
}
425+
426+
template <>
427+
TL_DEVICE half_t shfl_up_sync(unsigned mask, half_t val, int delta) {
428+
float f = static_cast<float>(val);
429+
float r = __shfl_up_sync(mask, f, delta);
430+
return half_t(r);
431+
}
432+
433+
template <> TL_DEVICE half_t shfl_sync(unsigned mask, half_t val, int srcLane) {
434+
float f = static_cast<float>(val);
435+
float r = __shfl_sync(mask, f, srcLane);
436+
return half_t(r);
437+
}
438+
439+
// Specializations for cutlass::bfloat16_t
440+
template <>
441+
TL_DEVICE bfloat16_t shfl_xor_sync(unsigned mask, bfloat16_t val,
442+
int laneMask) {
443+
float f = static_cast<float>(val);
444+
float r = __shfl_xor_sync(mask, f, laneMask);
445+
return bfloat16_t(r);
446+
}
447+
448+
template <>
449+
TL_DEVICE bfloat16_t shfl_down_sync(unsigned mask, bfloat16_t val, int delta) {
450+
float f = static_cast<float>(val);
451+
float r = __shfl_down_sync(mask, f, delta);
452+
return bfloat16_t(r);
453+
}
454+
455+
template <>
456+
TL_DEVICE bfloat16_t shfl_up_sync(unsigned mask, bfloat16_t val, int delta) {
457+
float f = static_cast<float>(val);
458+
float r = __shfl_up_sync(mask, f, delta);
459+
return bfloat16_t(r);
460+
}
461+
462+
template <>
463+
TL_DEVICE bfloat16_t shfl_sync(unsigned mask, bfloat16_t val, int srcLane) {
464+
float f = static_cast<float>(val);
465+
float r = __shfl_sync(mask, f, srcLane);
466+
return bfloat16_t(r);
467+
}
468+
469+
} // namespace tl

src/tl_templates/cuda/reduce.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ struct AllReduce {
102102
__syncthreads();
103103
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
104104
} else {
105-
x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
105+
x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
106106
}
107107
if constexpr (offset == scale) {
108108
return x;
@@ -122,7 +122,7 @@ struct AllReduce {
122122
asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads));
123123
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
124124
} else {
125-
x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
125+
x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
126126
}
127127
if constexpr (offset == scale) {
128128
return x;
@@ -234,7 +234,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
234234

235235
#pragma unroll
236236
for (int off = 1; off < SEG; off <<= 1) {
237-
T n = (T)__shfl_down_sync(MASK, val, off);
237+
T n = tl::shfl_down_sync(MASK, val, off);
238238
if (lane < SEG - off)
239239
val += n;
240240
}
@@ -244,10 +244,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
244244
if (real_col < W)
245245
dst[real_row * W + real_col] = val;
246246

247-
T segSum = (T)__shfl_sync(MASK, val, (T)0);
247+
T segSum = tl::shfl_sync(MASK, val, 0);
248248
if (lane == 0)
249249
carry = segSum;
250-
carry = (T)__shfl_sync(MASK, carry, (T)0);
250+
carry = tl::shfl_sync(MASK, carry, 0);
251251
}
252252
} else {
253253
for (int seg = 0; seg * SEG < W; ++seg) {
@@ -260,7 +260,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
260260

261261
#pragma unroll
262262
for (int off = 1; off < SEG; off <<= 1) {
263-
T n = (T)__shfl_up_sync(MASK, val, off);
263+
T n = tl::shfl_up_sync(MASK, val, off);
264264
if (lane >= off)
265265
val += n;
266266
}
@@ -270,10 +270,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
270270
if (real_col < W)
271271
dst[real_row * W + real_col] = val;
272272

273-
T segSum = (T)__shfl_sync(MASK, val, SEG - 1);
273+
T segSum = tl::shfl_sync(MASK, val, SEG - 1);
274274
if (lane == SEG - 1)
275275
carry = segSum;
276-
carry = (T)__shfl_sync(MASK, carry, SEG - 1);
276+
carry = tl::shfl_sync(MASK, carry, SEG - 1);
277277
}
278278
}
279279
}

0 commit comments

Comments
 (0)