diff --git a/src/op/reduce.cc b/src/op/reduce.cc
index b95c6cb4c..39b1e2377 100644
--- a/src/op/reduce.cc
+++ b/src/op/reduce.cc
@@ -420,12 +420,23 @@ Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     ICHECK(this->dst.scope() == "shared.dyn" || this->dst.scope() == "shared");
     std::stringstream ss;
     auto threads = T.thread_bounds->extent;
-    ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
-       << (reverse ? "true" : "false") << ">::run";
-    Array<PrimExpr> args = {StringImm(ss.str()), src.access_ptr(1),
-                            dst.access_ptr(3)};
-    for (int i = 0; i < src->shape.size(); i++) {
-      args.push_back(src->shape[i]);
+    Array<PrimExpr> args;
+    int ndim = static_cast<int>(src->shape.size());
+    if (ndim == 1) {
+      ICHECK_EQ(dim, 0) << "Cumulative sum over a 1D buffer only supports dim "
+                           "= 0.";
+      ss << "tl::CumSum1D<" << threads << ", " << (reverse ? "true" : "false")
+         << ">::run";
+      args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
+              src->shape[0]};
+    } else if (ndim == 2) {
+      ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
+         << (reverse ? "true" : "false") << ">::run";
+      args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
+              src->shape[0], src->shape[1]};
+    } else {
+      LOG(FATAL) << "CumSum currently supports only 1D or 2D buffers, got "
+                 << ndim << "D.";
     }
     return Evaluate(Call(dst->dtype, builtin::call_extern(), args));
   } else {
@@ -446,4 +457,4 @@ TIR_REGISTER_TL_OP(CumSumOp, cumsum)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
 } // namespace tl
-} // namespace tvm
\ No newline at end of file
+} // namespace tvm
diff --git a/src/tl_templates/cuda/reduce.h b/src/tl_templates/cuda/reduce.h
index 2783fc536..d3ce47bd0 100644
--- a/src/tl_templates/cuda/reduce.h
+++ b/src/tl_templates/cuda/reduce.h
@@ -68,6 +68,74 @@ struct AllReduce {
   }
 };
 
+template <int threads, bool reverse> struct CumSum1D {
+  static_assert(threads == 1024 or threads == 512 or threads == 256 or
+                threads == 128 or threads == 64 or threads == 32);
+  template <typename T>
+  static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
+                            int N) {
+    if (N <= 0)
+      return;
+
+    constexpr unsigned MASK = 0xffffffff;
+    const int tid = threadIdx.x;
+    const int lane = tid % SEG;
+
+    if (tid >= SEG)
+      return;
+
+    T carry = (T)0;
+
+    if (reverse) {
+      const int num_segments = (N + SEG - 1) / SEG;
+      for (int seg = num_segments - 1; seg >= 0; --seg) {
+        const int idx = seg * SEG + lane;
+        T val = (idx < N) ? src[idx] : (T)0;
+
+#pragma unroll
+        for (int off = 1; off < SEG; off <<= 1) {
+          T n = (T)__shfl_down_sync(MASK, val, off);
+          if (lane < SEG - off)
+            val += n;
+        }
+
+        val += carry;
+
+        if (idx < N)
+          dst[idx] = val;
+
+        T segSum = (T)__shfl_sync(MASK, val, 0);
+        if (lane == 0)
+          carry = segSum;
+        carry = (T)__shfl_sync(MASK, carry, 0);
+      }
+    } else {
+      const int num_segments = (N + SEG - 1) / SEG;
+      for (int seg = 0; seg < num_segments; ++seg) {
+        const int idx = seg * SEG + lane;
+        T val = (idx < N) ? src[idx] : (T)0;
+
+#pragma unroll
+        for (int off = 1; off < SEG; off <<= 1) {
+          T n = (T)__shfl_up_sync(MASK, val, off);
+          if (lane >= off)
+            val += n;
+        }
+
+        val += carry;
+
+        if (idx < N)
+          dst[idx] = val;
+
+        T segSum = (T)__shfl_sync(MASK, val, SEG - 1);
+        if (lane == SEG - 1)
+          carry = segSum;
+        carry = (T)__shfl_sync(MASK, carry, SEG - 1);
+      }
+    }
+  }
+};
+
 template <int threads, int dim, bool reverse> struct CumSum2D {
   static_assert(threads == 1024 or threads == 512 or threads == 256 or
                 threads == 128 or threads == 64 or threads == 32);
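CumSum1D above scans the buffer in SEG-wide segments (the full-warp shuffle mask and the `tid >= SEG` guard imply a single 32-lane warp does the work) and carries each segment's total into the next one. A minimal NumPy model of the forward path, assuming a 32-wide segment; the real kernel replaces `np.cumsum` with the `__shfl_up_sync` scan shown in the diff:

```python
import numpy as np

def segmented_cumsum(src, seg=32):
    # Mirror of CumSum1D's forward path: an inclusive scan inside each
    # seg-wide segment, plus a running carry added to every later segment.
    dst = np.empty_like(src)
    carry = src.dtype.type(0)
    for start in range(0, len(src), seg):
        scanned = np.cumsum(src[start:start + seg]) + carry
        dst[start:start + seg] = scanned
        carry = scanned[-1]  # segment total (carry included) feeds the next segment
    return dst

x = np.random.rand(1000).astype(np.float32)
assert np.allclose(segmented_cumsum(x), np.cumsum(x), rtol=1e-4, atol=1e-4)
```

The reverse path mirrors this: segments are visited from the end and the carry accumulates a suffix sum instead.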
diff --git a/testing/python/language/test_tilelang_language_cumsum.py b/testing/python/language/test_tilelang_language_cumsum.py
index c6e75252e..004640535 100644
--- a/testing/python/language/test_tilelang_language_cumsum.py
+++ b/testing/python/language/test_tilelang_language_cumsum.py
@@ -71,6 +71,75 @@ def ref_program(A):
     torch.testing.assert_close(tilelang_res, ref_res, atol=1e-3, rtol=1e-3)
 
 
+def cumsum_smem_test_1d(N, block_N, reverse=False, dtype="float32"):
+    import tilelang.language as T
+
+    @T.prim_func
+    def cumsum(
+            A: T.Tensor((N,), dtype),
+            B: T.Tensor((N,), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
+            A_shared = T.alloc_shared((block_N,), dtype)
+
+            T.copy(A[bx * block_N], A_shared)
+            T.cumsum(src=A_shared, dim=0, reverse=reverse)
+            T.copy(A_shared, B[bx * block_N])
+
+    return cumsum
+
+
+def cumsum_fragment_test_1d(N, block_N, reverse=False, dtype="float32"):
+    import tilelang.language as T
+
+    @T.prim_func
+    def cumsum(
+            A: T.Tensor((N,), dtype),
+            B: T.Tensor((N,), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
+            A_shared = T.alloc_shared((block_N,), dtype)
+            A_fragment = T.alloc_fragment((block_N,), dtype)
+
+            T.copy(A[bx * block_N], A_shared)
+            T.copy(A_shared, A_fragment)
+            T.cumsum(src=A_fragment, dim=0, reverse=reverse)
+            T.copy(A_fragment, B[bx * block_N])
+
+    return cumsum
+
+
+def run_cumsum_1d(N, block_N, reverse=False, dtype="float32", scope="smem"):
+    if scope == "smem":
+        program = cumsum_smem_test_1d(N, block_N, reverse, dtype)
+    elif scope == "fragment":
+        program = cumsum_fragment_test_1d(N, block_N, reverse, dtype)
+    else:
+        raise ValueError(f"Unknown scope {scope}")
+
+    jit_kernel = tl.compile(program, out_idx=-1)
+    A = torch.randn(N, dtype=getattr(torch, dtype)).cuda()
+
+    def ref_program(A):
+        ref_b = torch.empty_like(A)
+        num_blocks = (N + block_N - 1) // block_N
+        for j in range(num_blocks):
+            start = j * block_N
+            end = min(start + block_N, N)
+            chunk = A[start:end]
+            if reverse:
+                chunk = torch.flip(chunk, dims=[0])
+            chunk = chunk.cumsum(dim=0)
+            if reverse:
+                chunk = torch.flip(chunk, dims=[0])
+            ref_b[start:end] = chunk
+        return ref_b
+
+    tilelang_res = jit_kernel(A)
+    ref_res = ref_program(A)
+    torch.testing.assert_close(tilelang_res, ref_res, atol=1e-3, rtol=1e-3)
+
+
 def test_cumsum_smem():
     # Test different sizes
     run_cumsum(1024, 1024, 128, 128)
@@ -92,5 +161,15 @@ def test_cumsum_fragment():
     run_cumsum(256, 256, 128, 128, dtype="float32", scope="fragment")
 
 
+def test_cumsum_smem_1d():
+    run_cumsum_1d(1024, 128)
+    run_cumsum_1d(1024, 128, reverse=True)
+
+
+def test_cumsum_fragment_1d():
+    run_cumsum_1d(1024, 128, scope="fragment")
+    run_cumsum_1d(1024, 128, reverse=True, scope="fragment")
+
+
 if __name__ == "__main__":
     tilelang.testing.main()
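For `reverse=True`, `ref_program` flips each `block_N`-sized chunk before and after the scan; that flip/cumsum/flip pattern is just the suffix cumulative sum. A standalone torch sketch of the same identity, independent of tilelang:

```python
import torch

x = torch.randn(128)
# Suffix (reverse) cumulative sum via flip -> cumsum -> flip, as ref_program
# computes per chunk when reverse=True.
suffix = torch.flip(torch.cumsum(torch.flip(x, dims=[0]), dim=0), dims=[0])
# Direct definition: element i holds the sum of x[i:].
expected = torch.stack([x[i:].sum() for i in range(x.numel())])
torch.testing.assert_close(suffix, expected, atol=1e-5, rtol=1e-5)
```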
diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py
index a43aa8b18..9c7510e4c 100644
--- a/tilelang/language/reduce.py
+++ b/tilelang/language/reduce.py
@@ -160,6 +160,29 @@ def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reve
     Negative `dim` indices are normalized (Python-style). If `dst` is None, the operation is
     performed in-place into `src`. Raises ValueError when `dim` is out of bounds for `src.shape`.
     When `src.scope() == "local.fragment"`, this delegates to `cumsum_fragment`; otherwise it emits the `tl.cumsum` intrinsic.
+    Examples:
+        A 1D inclusive scan over a shared-memory buffer, written back in place:
+
+        >>> import tilelang.language as T
+        >>> @T.prim_func
+        ... def kernel(A: T.Tensor((128,), "float32"), B: T.Tensor((128,), "float32")):
+        ...     with T.Kernel(1, threads=128):
+        ...         A_shared = T.alloc_shared((128,), "float32")
+        ...         T.copy(A, A_shared)
+        ...         T.cumsum(src=A_shared, dst=A_shared, dim=0)
+        ...         T.copy(A_shared, B)
+
+        A 2D cumulative sum along the last dimension, accumulated in reverse:
+
+        >>> import tilelang.language as T
+        >>> @T.prim_func
+        ... def kernel2d(A: T.Tensor((64, 64), "float16"), B: T.Tensor((64, 64), "float16")):
+        ...     with T.Kernel(1, 1, threads=256):
+        ...         tile = T.alloc_shared((64, 64), "float16")
+        ...         T.copy(A, tile)
+        ...         T.cumsum(src=tile, dim=1, reverse=True)
+        ...         T.copy(tile, B)
+
     Returns:
         tir.Call: A handle to the emitted cumulative-sum operation.
     """
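For completeness, an end-to-end sketch of exercising the new 1D path, following the `tl.compile(out_idx=-1)` pattern used in the test file above (assumes a CUDA device with tilelang and torch installed; the kernel mirrors the first docstring example):

```python
import torch
import tilelang as tl
import tilelang.language as T


@T.prim_func
def kernel(A: T.Tensor((128,), "float32"), B: T.Tensor((128,), "float32")):
    with T.Kernel(1, threads=128):
        A_shared = T.alloc_shared((128,), "float32")
        T.copy(A, A_shared)
        T.cumsum(src=A_shared, dst=A_shared, dim=0)  # in-place 1D inclusive scan
        T.copy(A_shared, B)


jit_kernel = tl.compile(kernel, out_idx=-1)  # out_idx=-1: return B as the output
A = torch.randn(128, device="cuda")
torch.testing.assert_close(jit_kernel(A), torch.cumsum(A, dim=0), atol=1e-3, rtol=1e-3)
```

With a single block covering the whole buffer, the per-block scan coincides with the global `torch.cumsum`.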