Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions src/tl_templates/cuda/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,91 @@ namespace cutlass {
TL_DEVICE
bfloat16_t fast_exp(bfloat16_t x) { return ::hexp(x); }
} // namespace cutlass

//
// Type-safe warp shuffle helpers for 16-bit float types
// These wrappers avoid relying on implicit conversions that may be disallowed
// (e.g., converting float -> cutlass::bfloat16_t) by explicitly promoting to
// float for the shuffle and then down-converting.
//
namespace tl {

// ---------------------------------------------------------------------------
// Primary templates.
// Each wrapper hands `val` straight to the CUDA warp-shuffle intrinsic of the
// same name and returns the intrinsic's result as a T. They are used for every
// T that has no explicit specialization further down.
// ---------------------------------------------------------------------------

// Forwards to __shfl_xor_sync(mask, val, laneMask).
template <typename T>
TL_DEVICE T shfl_xor_sync(unsigned mask, T val, int laneMask) {
  T received = __shfl_xor_sync(mask, val, laneMask);
  return received;
}

// Forwards to __shfl_down_sync(mask, val, delta).
template <typename T>
TL_DEVICE T shfl_down_sync(unsigned mask, T val, int delta) {
  T received = __shfl_down_sync(mask, val, delta);
  return received;
}

// Forwards to __shfl_up_sync(mask, val, delta).
template <typename T>
TL_DEVICE T shfl_up_sync(unsigned mask, T val, int delta) {
  T received = __shfl_up_sync(mask, val, delta);
  return received;
}

// Forwards to __shfl_sync(mask, val, srcLane).
template <typename T>
TL_DEVICE T shfl_sync(unsigned mask, T val, int srcLane) {
  T received = __shfl_sync(mask, val, srcLane);
  return received;
}

// ---------------------------------------------------------------------------
// half_t specializations.
// The value is widened to float with an explicit cast, the float is what
// actually goes through the shuffle intrinsic, and the float that comes back
// is used to construct the half_t result explicitly. No implicit
// float <-> half_t conversion is required anywhere.
// ---------------------------------------------------------------------------

template <>
TL_DEVICE half_t shfl_xor_sync(unsigned mask, half_t val, int laneMask) {
  return half_t(__shfl_xor_sync(mask, static_cast<float>(val), laneMask));
}

template <>
TL_DEVICE half_t shfl_down_sync(unsigned mask, half_t val, int delta) {
  return half_t(__shfl_down_sync(mask, static_cast<float>(val), delta));
}

template <>
TL_DEVICE half_t shfl_up_sync(unsigned mask, half_t val, int delta) {
  return half_t(__shfl_up_sync(mask, static_cast<float>(val), delta));
}

template <>
TL_DEVICE half_t shfl_sync(unsigned mask, half_t val, int srcLane) {
  return half_t(__shfl_sync(mask, static_cast<float>(val), srcLane));
}

// ---------------------------------------------------------------------------
// bfloat16_t specializations.
// Same scheme as for half_t: explicit cast to float, shuffle the float,
// explicitly construct a bfloat16_t from the float that was received.
// ---------------------------------------------------------------------------

template <>
TL_DEVICE bfloat16_t shfl_xor_sync(unsigned mask, bfloat16_t val,
                                   int laneMask) {
  return bfloat16_t(__shfl_xor_sync(mask, static_cast<float>(val), laneMask));
}

template <>
TL_DEVICE bfloat16_t shfl_down_sync(unsigned mask, bfloat16_t val, int delta) {
  return bfloat16_t(__shfl_down_sync(mask, static_cast<float>(val), delta));
}

template <>
TL_DEVICE bfloat16_t shfl_up_sync(unsigned mask, bfloat16_t val, int delta) {
  return bfloat16_t(__shfl_up_sync(mask, static_cast<float>(val), delta));
}

template <>
TL_DEVICE bfloat16_t shfl_sync(unsigned mask, bfloat16_t val, int srcLane) {
  return bfloat16_t(__shfl_sync(mask, static_cast<float>(val), srcLane));
}

} // namespace tl
16 changes: 8 additions & 8 deletions src/tl_templates/cuda/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ struct AllReduce {
__syncthreads();
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
} else {
x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
}
if constexpr (offset == scale) {
return x;
Expand All @@ -122,7 +122,7 @@ struct AllReduce {
asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads));
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
} else {
x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
}
if constexpr (offset == scale) {
return x;
Expand Down Expand Up @@ -234,7 +234,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {

#pragma unroll
for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_down_sync(MASK, val, off);
T n = tl::shfl_down_sync(MASK, val, off);
if (lane < SEG - off)
val += n;
}
Expand All @@ -244,10 +244,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
if (real_col < W)
dst[real_row * W + real_col] = val;

T segSum = (T)__shfl_sync(MASK, val, (T)0);
T segSum = tl::shfl_sync(MASK, val, 0);
if (lane == 0)
carry = segSum;
carry = (T)__shfl_sync(MASK, carry, (T)0);
carry = tl::shfl_sync(MASK, carry, 0);
}
} else {
for (int seg = 0; seg * SEG < W; ++seg) {
Expand All @@ -260,7 +260,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {

#pragma unroll
for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_up_sync(MASK, val, off);
T n = tl::shfl_up_sync(MASK, val, off);
if (lane >= off)
val += n;
}
Expand All @@ -270,10 +270,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
if (real_col < W)
dst[real_row * W + real_col] = val;

T segSum = (T)__shfl_sync(MASK, val, SEG - 1);
T segSum = tl::shfl_sync(MASK, val, SEG - 1);
if (lane == SEG - 1)
carry = segSum;
carry = (T)__shfl_sync(MASK, carry, SEG - 1);
carry = tl::shfl_sync(MASK, carry, SEG - 1);
}
}
}
Expand Down
Loading