Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/op/finalize_reducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args,
* - Builds index Vars for each output dimension.
* - Reads the layout's ReplicateExtent and:
* - if extent == 1, emits a no-op Evaluate(0);
* - otherwise constructs an AllReduce extern call (uses `run_hopper` when the
* compilation target is Hopper) with an optional workspace (allocated via
* T.AddWorkspace when reducing_threads >= 32) and stores the result via
* - otherwise constructs an AllReduce extern call (uses `NamedBarrier` when
* the compilation target is Hopper) with an optional workspace (allocated
* via T.AddWorkspace when reducing_threads >= 32) and stores the result via
* BufferStore.
* - Wraps the store in parallel outer For loops over each output dimension.
*
Expand Down Expand Up @@ -99,11 +99,11 @@ Stmt FinalizeReducerOpNode::Lower(const LowerArgs &T,
int reducing_threads = extent;
std::stringstream ss;
auto thread_offset = T.thread_bounds->min;
if (TargetIsHopper(T.target) || TargetIsSm100(T.target) ||
TargetIsSM120(T.target)) {
if (TargetHasSMVersionGE(T.target, 90)) {
auto all_threads = T.thread_bounds->extent;
ss << "tl::AllReduce<" << op_str << ", " << reducing_threads << ", " << 1
<< ", " << thread_offset << ", " << all_threads << ">::run_hopper";
<< ", " << thread_offset << ", tl::NamedBarrier<" << all_threads
<< ">>::run";
} else {
ss << "tl::AllReduce<" << op_str << ", " << reducing_threads << ", " << 1
<< ", " << thread_offset << ">::run";
Expand Down
31 changes: 22 additions & 9 deletions src/op/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,9 @@ static Fragment ComputeReducerLayout(const Fragment &src_layout, int dim) {
* Lowers a ReduceOpNode operating on fragment-scoped buffers into a sequence of
* TIR statements implementing: optional initialization, thread-local reduction
* (unrolled inner loops), inter-thread reduction via a runtime AllReduce call
* (Hopper-specific `run_hopper` variant when TargetIsHopper(T.target) is true),
* and an optional accumulation or copy back to the destination buffer when a
* temporary clear buffer is used.
* (Hopper targets use `NamedBarrier` instead of the default
* `SyncThreadsBarrier`), and an optional accumulation or copy back to the
* destination buffer when a temporary clear buffer is used.
*
* Behavior notes:
* - Only supports src and dst in "local.fragment" scope; otherwise it checks
Expand All @@ -206,7 +206,7 @@ static Fragment ComputeReducerLayout(const Fragment &src_layout, int dim) {
* reduction.
* - Performs iterator compression for local reduction loops using `analyzer`.
* - Detects parallel thread splitting from the normalized iterator sum and
* emits a call to a templated `tl::AllReduce<...>::run` (or `run_hopper`)
* emits a call to a templated `tl::AllReduce<...>::run`
* via `builtin::call_extern`. For sufficiently large reducing thread counts
* (> 32) a workspace is allocated via T.AddWorkspace and passed to the
* AllReduce call.
Expand Down Expand Up @@ -352,6 +352,16 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto mark = iter_split->source->source.as<Var>();
ICHECK(mark) << "Not a normalized iterator: " << iter_split->source;
if (mark.value().same_as(src_vars[this->dim]->var)) {
// `scale` is the stride of participating threads in the thread index
// space. When the thread-to-data mapping for the reduce dimension is
// normalized as threadIdx = source * scale + ...,
// * scale == 1 means threads are contiguous (0, 1, 2, ...),
// * scale > 1 means threads are interleaved (0, scale, 2*scale,
// ...).
// Both cases use the recursive XOR-butterfly reduce.
// `extent` is the number of distinct thread positions along the reduce
// dimension, so reducing_threads = extent * scale covers the full
// thread index range that participates in the reduction.
auto scale = as_const_int(iter_split->scale);
auto extent = as_const_int(iter_split->extent);
ICHECK(scale != nullptr && extent != nullptr);
Expand All @@ -362,22 +372,25 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
std::stringstream ss;

auto thread_offset = T.thread_bounds->min;
if (TargetIsHopper(T.target) || TargetIsSm100(T.target) ||
TargetIsSM120(T.target)) {
if (TargetHasSMVersionGE(T.target, 90)) {
auto all_threads = T.thread_bounds->extent;
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
<< reducing_threads << ", " << (*scale) << ", " << thread_offset
<< ", " << all_threads << ">::run_hopper";
<< ", tl::NamedBarrier<" << all_threads << ">>::run";
} else {
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
<< reducing_threads << ", " << (*scale) << ", " << thread_offset
<< ">::run";
}
Array<PrimExpr> thread_reduce_args = {
StringImm(ss.str()), BufferLoad(clear_buffer, red_indices)};
// The butterfly reduce path needs one shared-memory slot per
// thread in the block.
if (reducing_threads > 32) {
PrimExpr workspace = T.AddWorkspace(
*as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
int workspace_size =
static_cast<int>(*as_const_int(T.thread_bounds->extent));
PrimExpr workspace =
T.AddWorkspace(workspace_size, clear_buffer->dtype);
thread_reduce_args.push_back(workspace);
}
auto call = Call(clear_buffer->dtype, builtin::call_extern(),
Expand Down
133 changes: 105 additions & 28 deletions src/tl_templates/cuda/reduce.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#pragma once

#include "common.h"
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#ifndef __CUDACC_RTC__
#include <cstdint>
Expand All @@ -9,6 +11,9 @@

namespace tl {

template <typename T, typename ReduceOp>
TL_DEVICE T warp_reduce(T value, ReduceOp op);

// Select a wider accumulator type for improved numerical accuracy.
// Default: accumulate in the same type. Specialize FP16/BF16 to float.
template <typename T> struct AccType {
Expand Down Expand Up @@ -57,48 +62,70 @@ struct BitXorOp {
}
};

// Barrier policy: wraps __syncthreads().
// The phase template parameter is ignored (all phases use the same barrier).
struct SyncThreadsBarrier {
  // Block-wide barrier via the __syncthreads() intrinsic. The `phase`
  // parameter is accepted only so this policy matches the NamedBarrier
  // interface; it is ignored here — every phase maps to the same barrier.
  template <int phase = 0> static TL_DEVICE void sync() {
    __syncthreads();
  }
};

// Barrier policy: wraps named barrier (bar.sync) with compile-time phase IDs.
// Used on Hopper and later architectures where __syncthreads() cannot be used
// in certain contexts.
template <int all_threads> struct NamedBarrier {
  // Synchronize `all_threads` threads on the named barrier identified by
  // `phase`. Note the default phase is 1, not 0 — presumably barrier 0 is
  // left for __syncthreads(); TODO confirm against PTX barrier conventions.
  template <int phase = 1> static TL_DEVICE void sync() {
    // PTX bar.sync %id, %count: arrive-and-wait on named barrier %id with
    // %count participating threads (here the compile-time `all_threads`).
    asm volatile("bar.sync %0, %1;" : : "r"(phase), "r"(all_threads));
  }
};

// AllReduce performs a cross-thread reduction over a group of `threads`
// threads.
//
// Template parameters:
// Reducer - binary reduction functor (e.g. SumOp, MaxOp).
// threads - number of threads that span the reduce dimension,
// equal to extent * scale.
// scale - stride of participating threads in the thread index space.
// When the thread-to-data mapping is normalized as
// threadIdx = source * scale + ...
// `scale` is the stride between consecutive logical
// participants in the reduce dimension.
// The recursion terminates when threads == scale, meaning
// each reduce group has been collapsed to a single thread.
// Uses a recursive XOR-butterfly pattern: at each level,
// offset >= 32 goes through shared memory + barrier,
// offset < 32 uses warp shuffle (shfl_xor_sync).
// thread_offset - base thread index offset within the block.
// Barrier - barrier policy type (SyncThreadsBarrier or
// NamedBarrier<N>).
template <class Reducer, int threads, int scale, int thread_offset = 0,
int all_threads = threads>
class Barrier = SyncThreadsBarrier>
struct AllReduce {
static_assert(threads == 1024 or threads == 512 or threads == 256 or
threads == 128 or threads == 64 or threads == 32 or
threads == 16 or threads == 8 or threads == 4 or threads == 2);
static_assert(threads % scale == 0);
template <typename T> static TL_DEVICE T run(T x, T *red_buf = nullptr) {
constexpr int offset = threads / 2;
if constexpr (offset >= 32) {
__syncthreads();
red_buf[threadIdx.x - thread_offset] = x;
__syncthreads();
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
} else {
x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
}
if constexpr (offset == scale) {
if constexpr (threads == scale) {
// Recursion base case: each reduce group has exactly one thread left.
return x;
} else {
return AllReduce<Reducer, offset, scale, thread_offset, all_threads>::run(
x, red_buf);
return butterfly_reduce(x, red_buf);
}
}

template <typename T>
static TL_DEVICE T run_hopper(T x, T *red_buf = nullptr) {
private:
template <typename T> static TL_DEVICE T butterfly_reduce(T x, T *red_buf) {
constexpr int offset = threads / 2;
if constexpr (offset >= 32) {
asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(all_threads));
Barrier::template sync<1>();
red_buf[threadIdx.x - thread_offset] = x;
// TODO(lei): maybe we can merge the two bar.sync into one?
asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads));
Barrier::template sync<2>();
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
} else {
x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
}
if constexpr (offset == scale) {
return x;
} else {
return AllReduce<Reducer, offset, scale, thread_offset,
all_threads>::run_hopper(x, red_buf);
return AllReduce<Reducer, offset, scale, thread_offset, Barrier>::run(
x, red_buf);
}
}
};
Expand Down Expand Up @@ -250,14 +277,64 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
}
};

// Reference:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#reduction
template <typename T, typename ReduceOp>
TL_DEVICE T warp_reduce(T value, ReduceOp op) {
constexpr uint32_t mask = 0xffffffff;
value = op(value, __shfl_xor_sync(mask, value, 16));
value = op(value, __shfl_xor_sync(mask, value, 8));
value = op(value, __shfl_xor_sync(mask, value, 4));
value = op(value, __shfl_xor_sync(mask, value, 2));
value = op(value, __shfl_xor_sync(mask, value, 1));
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && \
(defined(__CUDA_ARCH_FEAT_SM100_ALL) || defined(__CUDA_ARCH_FEAT_SM100_F))
float value_cast = 0.0f;
if constexpr (std::is_same_v<T, half_t>) {
value_cast = __half2float(value);
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
value_cast = __bfloat162float(value);
} else {
value_cast = static_cast<float>(value);
}
if constexpr (std::is_same_v<ReduceOp, MaxOp> && !std::is_integral_v<T>) {
float res;
asm("redux.sync.max.f32 %0, %1, %2;"
: "=f"(res)
: "f"(value_cast), "r"(mask));
return static_cast<T>(res);
} else if constexpr (std::is_same_v<ReduceOp, MinOp> &&
!std::is_integral_v<T>) {
float res;
asm("redux.sync.min.f32 %0, %1, %2;"
: "=f"(res)
: "f"(value_cast), "r"(mask));
return static_cast<T>(res);
}
#endif
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
auto run_reduce_sync = [&]<typename T_cast>(T_cast val) {
if constexpr (std::is_same_v<ReduceOp, SumOp>) {
return __reduce_add_sync(mask, val);
} else if constexpr (std::is_same_v<ReduceOp, MaxOp>) {
return __reduce_max_sync(mask, val);
} else if constexpr (std::is_same_v<ReduceOp, MinOp>) {
return __reduce_min_sync(mask, val);
} else if constexpr (std::is_same_v<ReduceOp, BitAndOp>) {
return __reduce_and_sync(mask, val);
} else if constexpr (std::is_same_v<ReduceOp, BitOrOp>) {
return __reduce_or_sync(mask, val);
} else if constexpr (std::is_same_v<ReduceOp, BitXorOp>) {
return __reduce_xor_sync(mask, val);
}
};

if constexpr (std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t>) {
return run_reduce_sync(value);
} else if constexpr (std::is_integral_v<T>) {
return static_cast<T>(run_reduce_sync(static_cast<int32_t>(value)));
}
#endif
value = op(value, tl::shfl_xor_sync(mask, value, 16));
value = op(value, tl::shfl_xor_sync(mask, value, 8));
value = op(value, tl::shfl_xor_sync(mask, value, 4));
value = op(value, tl::shfl_xor_sync(mask, value, 2));
value = op(value, tl::shfl_xor_sync(mask, value, 1));
return value;
}

Expand Down
Loading
Loading