diff --git a/src/op/finalize_reducer.cc b/src/op/finalize_reducer.cc index e9e2fca54..f65be3417 100644 --- a/src/op/finalize_reducer.cc +++ b/src/op/finalize_reducer.cc @@ -53,9 +53,9 @@ FinalizeReducerOp::FinalizeReducerOp(Array args, * - Builds index Vars for each output dimension. * - Reads the layout's ReplicateExtent and: * - if extent == 1, emits a no-op Evaluate(0); - * - otherwise constructs an AllReduce extern call (uses `run_hopper` when the - * compilation target is Hopper) with an optional workspace (allocated via - * T.AddWorkspace when reducing_threads >= 32) and stores the result via + * - otherwise constructs an AllReduce extern call (uses `NamedBarrier` when + * the compilation target is Hopper) with an optional workspace (allocated + * via T.AddWorkspace when reducing_threads >= 32) and stores the result via * BufferStore. * - Wraps the store in parallel outer For loops over each output dimension. * @@ -99,11 +99,11 @@ Stmt FinalizeReducerOpNode::Lower(const LowerArgs &T, int reducing_threads = extent; std::stringstream ss; auto thread_offset = T.thread_bounds->min; - if (TargetIsHopper(T.target) || TargetIsSm100(T.target) || - TargetIsSM120(T.target)) { + if (TargetHasSMVersionGE(T.target, 90)) { auto all_threads = T.thread_bounds->extent; ss << "tl::AllReduce<" << op_str << ", " << reducing_threads << ", " << 1 - << ", " << thread_offset << ", " << all_threads << ">::run_hopper"; + << ", " << thread_offset << ", tl::NamedBarrier<" << all_threads + << ">>::run"; } else { ss << "tl::AllReduce<" << op_str << ", " << reducing_threads << ", " << 1 << ", " << thread_offset << ">::run"; diff --git a/src/op/reduce.cc b/src/op/reduce.cc index 3d5888492..c79f90fbe 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -191,9 +191,9 @@ static Fragment ComputeReducerLayout(const Fragment &src_layout, int dim) { * Lowers a ReduceOpNode operating on fragment-scoped buffers into a sequence of * TIR statements implementing: optional initialization, thread-local reduction * (unrolled inner loops), inter-thread reduction via a runtime AllReduce call - * (Hopper-specific `run_hopper` variant when TargetIsHopper(T.target) is true), - * and an optional accumulation or copy back to the destination buffer when a - * temporary clear buffer is used. + * (Hopper targets use `NamedBarrier` instead of the default + * `SyncThreadsBarrier`), and an optional accumulation or copy back to the + * destination buffer when a temporary clear buffer is used. * * Behavior notes: * - Only supports src and dst in "local.fragment" scope; otherwise it checks @@ -206,7 +206,7 @@ static Fragment ComputeReducerLayout(const Fragment &src_layout, int dim) { * reduction. * - Performs iterator compression for local reduction loops using `analyzer`. * - Detects parallel thread splitting from the normalized iterator sum and - * emits a call to a templated `tl::AllReduce<...>::run` (or `run_hopper`) + * emits a call to a templated `tl::AllReduce<...>::run` * via `builtin::call_extern`. For sufficiently large reducing thread counts * (> 32) a workspace is allocated via T.AddWorkspace and passed to the * AllReduce call. @@ -352,6 +352,16 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto mark = iter_split->source->source.as(); ICHECK(mark) << "Not a normalized iterator: " << iter_split->source; if (mark.value().same_as(src_vars[this->dim]->var)) { + // `scale` is the stride of participating threads in the thread index + // space. When the thread-to-data mapping for the reduce dimension is + // normalized as threadIdx = source * scale + ..., + // * scale == 1 means threads are contiguous (0, 1, 2, ...), + // * scale > 1 means threads are interleaved (0, scale, 2*scale, + // ...). + // Both cases use the recursive XOR-butterfly reduce. + // `extent` is the number of distinct thread positions along the reduce + // dimension, so reducing_threads = extent * scale covers the full + // thread index range that participates in the reduction. auto scale = as_const_int(iter_split->scale); auto extent = as_const_int(iter_split->extent); ICHECK(scale != nullptr && extent != nullptr); @@ -362,12 +372,11 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { std::stringstream ss; auto thread_offset = T.thread_bounds->min; - if (TargetIsHopper(T.target) || TargetIsSm100(T.target) || - TargetIsSM120(T.target)) { + if (TargetHasSMVersionGE(T.target, 90)) { auto all_threads = T.thread_bounds->extent; ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", " << reducing_threads << ", " << (*scale) << ", " << thread_offset - << ", " << all_threads << ">::run_hopper"; + << ", tl::NamedBarrier<" << all_threads << ">>::run"; } else { ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", " << reducing_threads << ", " << (*scale) << ", " << thread_offset @@ -375,9 +384,13 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } Array thread_reduce_args = { StringImm(ss.str()), BufferLoad(clear_buffer, red_indices)}; + // The butterfly reduce path needs one shared-memory slot per + // thread in the block. if (reducing_threads > 32) { - PrimExpr workspace = T.AddWorkspace( - *as_const_int(T.thread_bounds->extent), clear_buffer->dtype); + int workspace_size = + static_cast(*as_const_int(T.thread_bounds->extent)); + PrimExpr workspace = + T.AddWorkspace(workspace_size, clear_buffer->dtype); thread_reduce_args.push_back(workspace); } auto call = Call(clear_buffer->dtype, builtin::call_extern(), diff --git a/src/tl_templates/cuda/reduce.h b/src/tl_templates/cuda/reduce.h index 55b8878b7..e0a09cf7e 100644 --- a/src/tl_templates/cuda/reduce.h +++ b/src/tl_templates/cuda/reduce.h @@ -1,6 +1,8 @@ #pragma once #include "common.h" +#include +#include #ifndef __CUDACC_RTC__ #include @@ -9,6 +11,9 @@ namespace tl { +template +TL_DEVICE T warp_reduce(T value, ReduceOp op); + // Select a wider accumulator type for improved numerical accuracy. // Default: accumulate in the same type. Specialize FP16/BF16 to float. template struct AccType { @@ -57,39 +62,61 @@ struct BitXorOp { } }; +// Barrier policy: wraps __syncthreads(). +// The phase template parameter is ignored (all phases use the same barrier). +struct SyncThreadsBarrier { + template static TL_DEVICE void sync() { __syncthreads(); } +}; + +// Barrier policy: wraps named barrier (bar.sync) with compile-time phase IDs. +// Used on Hopper and later architectures where __syncthreads() cannot be used +// in certain contexts. +template struct NamedBarrier { + template static TL_DEVICE void sync() { + asm volatile("bar.sync %0, %1;" : : "r"(phase), "r"(all_threads)); + } +}; + +// AllReduce performs a cross-thread reduction over a group of `threads` +// threads. +// +// Template parameters: +// Reducer - binary reduction functor (e.g. SumOp, MaxOp). +// threads - number of threads that span the reduce dimension, +// equal to extent * scale. +// scale - stride of participating threads in the thread index space. +// When the thread-to-data mapping is normalized as +// threadIdx = source * scale + ... +// `scale` is the stride between consecutive logical +// participants in the reduce dimension. +// The recursion terminates when threads == scale, meaning +// each reduce group has been collapsed to a single thread. +// Uses a recursive XOR-butterfly pattern: at each level, +// offset >= 32 goes through shared memory + barrier, +// offset < 32 uses warp shuffle (shfl_xor_sync). +// thread_offset - base thread index offset within the block. +// Barrier - barrier policy type (SyncThreadsBarrier or +// NamedBarrier). template + class Barrier = SyncThreadsBarrier> struct AllReduce { - static_assert(threads == 1024 or threads == 512 or threads == 256 or - threads == 128 or threads == 64 or threads == 32 or - threads == 16 or threads == 8 or threads == 4 or threads == 2); static_assert(threads % scale == 0); template static TL_DEVICE T run(T x, T *red_buf = nullptr) { - constexpr int offset = threads / 2; - if constexpr (offset >= 32) { - __syncthreads(); - red_buf[threadIdx.x - thread_offset] = x; - __syncthreads(); - x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]); - } else { - x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset)); - } - if constexpr (offset == scale) { + if constexpr (threads == scale) { + // Recursion base case: each reduce group has exactly one thread left. return x; } else { - return AllReduce::run( - x, red_buf); + return butterfly_reduce(x, red_buf); } } - template - static TL_DEVICE T run_hopper(T x, T *red_buf = nullptr) { +private: + template static TL_DEVICE T butterfly_reduce(T x, T *red_buf) { constexpr int offset = threads / 2; if constexpr (offset >= 32) { - asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(all_threads)); + Barrier::template sync<1>(); red_buf[threadIdx.x - thread_offset] = x; - // TODO(lei): maybe we can merge the two bar.sync into one? - asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads)); + Barrier::template sync<2>(); x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]); } else { x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset)); @@ -97,8 +124,8 @@ struct AllReduce { if constexpr (offset == scale) { return x; } else { - return AllReduce::run_hopper(x, red_buf); + return AllReduce::run( + x, red_buf); } } }; @@ -250,14 +277,64 @@ template struct CumSum2D { } }; +// Reference: +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#reduction template TL_DEVICE T warp_reduce(T value, ReduceOp op) { constexpr uint32_t mask = 0xffffffff; - value = op(value, __shfl_xor_sync(mask, value, 16)); - value = op(value, __shfl_xor_sync(mask, value, 8)); - value = op(value, __shfl_xor_sync(mask, value, 4)); - value = op(value, __shfl_xor_sync(mask, value, 2)); - value = op(value, __shfl_xor_sync(mask, value, 1)); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && \ + (defined(__CUDA_ARCH_FEAT_SM100_ALL) || defined(__CUDA_ARCH_FEAT_SM100_F)) + float value_cast = 0.0f; + if constexpr (std::is_same_v) { + value_cast = __half2float(value); + } else if constexpr (std::is_same_v) { + value_cast = __bfloat162float(value); + } else { + value_cast = static_cast(value); + } + if constexpr (std::is_same_v && !std::is_integral_v) { + float res; + asm("redux.sync.max.f32 %0, %1, %2;" + : "=f"(res) + : "f"(value_cast), "r"(mask)); + return static_cast(res); + } else if constexpr (std::is_same_v && + !std::is_integral_v) { + float res; + asm("redux.sync.min.f32 %0, %1, %2;" + : "=f"(res) + : "f"(value_cast), "r"(mask)); + return static_cast(res); + } +#endif +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + auto run_reduce_sync = [&](T_cast val) { + if constexpr (std::is_same_v) { + return __reduce_add_sync(mask, val); + } else if constexpr (std::is_same_v) { + return __reduce_max_sync(mask, val); + } else if constexpr (std::is_same_v) { + return __reduce_min_sync(mask, val); + } else if constexpr (std::is_same_v) { + return __reduce_and_sync(mask, val); + } else if constexpr (std::is_same_v) { + return __reduce_or_sync(mask, val); + } else if constexpr (std::is_same_v) { + return __reduce_xor_sync(mask, val); + } + }; + + if constexpr (std::is_same_v || std::is_same_v) { + return run_reduce_sync(value); + } else if constexpr (std::is_integral_v) { + return static_cast(run_reduce_sync(static_cast(value))); + } +#endif + value = op(value, tl::shfl_xor_sync(mask, value, 16)); + value = op(value, tl::shfl_xor_sync(mask, value, 8)); + value = op(value, tl::shfl_xor_sync(mask, value, 4)); + value = op(value, tl::shfl_xor_sync(mask, value, 2)); + value = op(value, tl::shfl_xor_sync(mask, value, 1)); return value; } diff --git a/testing/python/language/test_tilelang_language_reduce.py b/testing/python/language/test_tilelang_language_reduce.py index f12c5bc4a..0bb0a088e 100644 --- a/testing/python/language/test_tilelang_language_reduce.py +++ b/testing/python/language/test_tilelang_language_reduce.py @@ -4,6 +4,7 @@ import tilelang.language as T tilelang.testing.set_random_seed() +tilelang.disable_cache() def _make_shared_reduce(M, N, dtype, reduce_cb): @@ -29,7 +30,7 @@ def _run_program(program, ref_program, atol=1e-2, rtol=1e-2): profiler.assert_allclose(ref_program, atol=atol, rtol=rtol) -def reduce_max_test(M, N, dtype=T.float16): +def reduce_test(M, N, dtype=T.float16, op="sum", threads=32): import tilelang.language as T @T.prim_func @@ -37,31 +38,27 @@ def main( A: T.Tensor((M, N), dtype), B: T.Tensor((M,), dtype), ): - with T.Kernel(1) as _: - A_local = T.alloc_fragment((M, N), dtype) - B_local = T.alloc_fragment((M,), dtype) - - T.copy(A, A_local) - T.reduce_max(A_local, B_local, dim=1) - T.copy(B_local, B) - - return main - - -def reduce_sum_test(M, N, dtype=T.float32): - import tilelang.language as T - - @T.prim_func - def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M,), dtype), - ): - with T.Kernel(1) as _: + with T.Kernel(1, threads=threads) as _: A_local = T.alloc_fragment((M, N), dtype) B_local = T.alloc_fragment((M,), dtype) T.copy(A, A_local) - T.reduce_sum(A_local, B_local, dim=1) + if op == "sum": + T.reduce_sum(A_local, B_local, dim=1) + elif op == "max": + T.reduce_max(A_local, B_local, dim=1) + elif op == "min": + T.reduce_min(A_local, B_local, dim=1) + elif op == "abssum": + T.reduce_abssum(A_local, B_local, dim=1) + elif op == "absmax": + T.reduce_absmax(A_local, B_local, dim=1) + elif op == "bitand": + T.reduce_bitand(A_local, B_local, dim=1) + elif op == "bitor": + T.reduce_bitor(A_local, B_local, dim=1) + elif op == "bitxor": + T.reduce_bitxor(A_local, B_local, dim=1) T.copy(B_local, B) return main @@ -87,14 +84,33 @@ def reduce_absmax_ss(M, N, dtype=T.float32): return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_absmax(src, dst, dim=1)) -def run_reduce_sum(M, N, dtype=T.float32, mode="rr"): +def run_reduce(M, N, dtype=T.float32, op="sum", mode="rr", threads=32): if mode == "rr": - program = reduce_sum_test(M, N, dtype) + program = reduce_test(M, N, dtype, op, threads) elif mode == "ss": + assert op == "sum", f"shared reduce only supports sum, got {op}" program = reduce_sum_ss(M, N, dtype) else: - raise NotImplementedError("run_reduce_sum only supports rr and ss") - _run_program(program, lambda A: A.sum(dim=1)) + raise NotImplementedError(f"run_reduce only supports rr and ss, got {mode}") + + import torch + + def ref_fn(A): + if op == "sum": + res = A.sum(dim=1) + elif op == "max": + res = A.max(dim=1).values + elif op == "min": + res = A.min(dim=1).values + elif op == "abssum": + res = A.abs().sum(dim=1) + elif op == "absmax": + res = A.abs().max(dim=1).values + if A.dtype in [torch.uint32, torch.int32, torch.int64]: + return res.to(A.dtype) + return res + + _run_program(program, ref_fn) def run_shared_reduce(program_builder, ref_program, M, N, dtype=T.float32): @@ -103,18 +119,32 @@ def run_shared_reduce(program_builder, ref_program, M, N, dtype=T.float32): def run_reduce_max(M, N, dtype=T.float16): - program = reduce_max_test(M, N, dtype) + program = reduce_test(M, N, dtype, "max") _run_program(program, lambda A: A.max(dim=1).values, atol=1e-2, rtol=1e-2) def test_reduce_sum(): - run_reduce_sum(256, 256) - run_reduce_sum(512, 128) - run_reduce_sum(128, 512) + MN_zip = [(256, 256), (512, 128), (128, 512)] + for dtype in [T.float32, T.int32, T.int64]: + for M, N in MN_zip: + run_reduce(M, N, dtype, "sum") + + +def test_reduce_other_op(): + MN_zip = [(256, 256), (512, 128)] + for op in ["max", "min", "abssum", "absmax"]: + for dtype in [T.float32, T.int32, T.int64]: + for M, N in MN_zip: + run_reduce(M, N, dtype, op) + + +def test_reduce_sum_threads(): + run_reduce(32, 32, T.float32, "sum", mode="rr", threads=16) + run_reduce(16, 16, T.float32, "sum", mode="rr", threads=8) def test_reduce_sum_shared(): - run_reduce_sum(64, 64, mode="ss") + run_reduce(64, 64, op="sum", mode="ss") def test_reduce_max():