diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h
index 7ca9f4e1c..3fd59d5ce 100644
--- a/src/tl_templates/cuda/common.h
+++ b/src/tl_templates/cuda/common.h
@@ -379,3 +379,91 @@ namespace cutlass {
 
 TL_DEVICE bfloat16_t fast_exp(bfloat16_t x) { return ::hexp(x); }
 } // namespace cutlass
+
+//
+// Type-safe warp shuffle helpers for 16-bit float types
+// These wrappers avoid relying on implicit conversions that may be disallowed
+// (e.g., converting float -> cutlass::bfloat16_t) by explicitly promoting to
+// float for the shuffle and then down-converting.
+//
+namespace tl {
+
+// Generic passthroughs
+template <typename T>
+TL_DEVICE T shfl_xor_sync(unsigned mask, T val, int laneMask) {
+  return __shfl_xor_sync(mask, val, laneMask);
+}
+
+template <typename T>
+TL_DEVICE T shfl_down_sync(unsigned mask, T val, int delta) {
+  return __shfl_down_sync(mask, val, delta);
+}
+
+template <typename T>
+TL_DEVICE T shfl_up_sync(unsigned mask, T val, int delta) {
+  return __shfl_up_sync(mask, val, delta);
+}
+
+template <typename T> TL_DEVICE T shfl_sync(unsigned mask, T val, int srcLane) {
+  return __shfl_sync(mask, val, srcLane);
+}
+
+// Specializations for cutlass::half_t
+template <>
+TL_DEVICE half_t shfl_xor_sync(unsigned mask, half_t val, int laneMask) {
+  float f = static_cast<float>(val);
+  float r = __shfl_xor_sync(mask, f, laneMask);
+  return half_t(r);
+}
+
+template <>
+TL_DEVICE half_t shfl_down_sync(unsigned mask, half_t val, int delta) {
+  float f = static_cast<float>(val);
+  float r = __shfl_down_sync(mask, f, delta);
+  return half_t(r);
+}
+
+template <>
+TL_DEVICE half_t shfl_up_sync(unsigned mask, half_t val, int delta) {
+  float f = static_cast<float>(val);
+  float r = __shfl_up_sync(mask, f, delta);
+  return half_t(r);
+}
+
+template <> TL_DEVICE half_t shfl_sync(unsigned mask, half_t val, int srcLane) {
+  float f = static_cast<float>(val);
+  float r = __shfl_sync(mask, f, srcLane);
+  return half_t(r);
+}
+
+// Specializations for cutlass::bfloat16_t
+template <>
+TL_DEVICE bfloat16_t shfl_xor_sync(unsigned mask, bfloat16_t val,
+                                   int laneMask) {
+  float f = static_cast<float>(val);
+  float r = __shfl_xor_sync(mask, f, laneMask);
+  return bfloat16_t(r);
+}
+
+template <>
+TL_DEVICE bfloat16_t shfl_down_sync(unsigned mask, bfloat16_t val, int delta) {
+  float f = static_cast<float>(val);
+  float r = __shfl_down_sync(mask, f, delta);
+  return bfloat16_t(r);
+}
+
+template <>
+TL_DEVICE bfloat16_t shfl_up_sync(unsigned mask, bfloat16_t val, int delta) {
+  float f = static_cast<float>(val);
+  float r = __shfl_up_sync(mask, f, delta);
+  return bfloat16_t(r);
+}
+
+template <>
+TL_DEVICE bfloat16_t shfl_sync(unsigned mask, bfloat16_t val, int srcLane) {
+  float f = static_cast<float>(val);
+  float r = __shfl_sync(mask, f, srcLane);
+  return bfloat16_t(r);
+}
+
+} // namespace tl
diff --git a/src/tl_templates/cuda/reduce.h b/src/tl_templates/cuda/reduce.h
index 331da6dc8..aa0cc83e8 100644
--- a/src/tl_templates/cuda/reduce.h
+++ b/src/tl_templates/cuda/reduce.h
@@ -102,7 +102,7 @@ struct AllReduce {
       __syncthreads();
       x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
     } else {
-      x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
+      x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
     }
     if constexpr (offset == scale) {
       return x;
@@ -122,7 +122,7 @@ struct AllReduce {
      asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads));
       x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
     } else {
-      x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
+      x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
     }
     if constexpr (offset == scale) {
       return x;
@@ -234,7 +234,7 @@ template <typename T> struct CumSum2D {
 
 #pragma unroll
       for (int off = 1; off < SEG; off <<= 1) {
-        T n = (T)__shfl_down_sync(MASK, val, off);
+        T n = tl::shfl_down_sync(MASK, val, off);
         if (lane < SEG - off)
           val += n;
       }
@@ -244,10 +244,10 @@ template <typename T> struct CumSum2D {
 
       if (real_col < W)
         dst[real_row * W + real_col] = val;
-      T segSum = (T)__shfl_sync(MASK, val, (T)0);
+      T segSum = tl::shfl_sync(MASK, val, 0);
       if (lane == 0)
         carry = segSum;
-      carry = (T)__shfl_sync(MASK, carry, (T)0);
+      carry = tl::shfl_sync(MASK, carry, 0);
     }
   } else {
     for (int seg = 0; seg * SEG < W; ++seg) {
@@ -260,7 +260,7 @@ template <typename T> struct CumSum2D {
 
 #pragma unroll
       for (int off = 1; off < SEG; off <<= 1) {
-        T n = (T)__shfl_up_sync(MASK, val, off);
+        T n = tl::shfl_up_sync(MASK, val, off);
         if (lane >= off)
           val += n;
       }
@@ -270,10 +270,10 @@ template <typename T> struct CumSum2D {
 
       if (real_col < W)
         dst[real_row * W + real_col] = val;
-      T segSum = (T)__shfl_sync(MASK, val, SEG - 1);
+      T segSum = tl::shfl_sync(MASK, val, SEG - 1);
       if (lane == SEG - 1)
         carry = segSum;
-      carry = (T)__shfl_sync(MASK, carry, SEG - 1);
+      carry = tl::shfl_sync(MASK, carry, SEG - 1);
     }
   }
 }