diff --git a/src/op/reduce.cc b/src/op/reduce.cc index 05dad48fc..b6dbe8651 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -16,6 +16,7 @@ #include "../transform/loop_partition.h" #include "region.h" #include "tir/transforms/ir_utils.h" +#include "tvm/tir/stmt.h" namespace tvm { namespace tl { @@ -57,12 +58,65 @@ static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg, RegionOp region(call->args, vmap); return BufferRegion(region->GetBuffer(), region->GetRanges()); } + // builtin.tvm_access_ptr(...) — map var to Buffer and take full region + if (call->op.same_as(builtin::tvm_access_ptr())) { + Var var = Downcast<Var>(call->args[1]); + Buffer buf = vmap[var]; + Array<Range> ranges; + for (PrimExpr extent : buf->shape) { + ranges.push_back(Range(IntImm(extent->dtype, 0), extent)); + } + return BufferRegion(buf, ranges); + } } LOG(FATAL) << "Unsupported argument for BufferRegion in reduce: " << arg; throw; // Unreachable } +// Build a tvm_access_ptr(handle) to the start of the 2D tile within a +// BufferRegion. Offset is computed from all but the last two dimensions; extent +// is the product of the last two extents. rw_mask: 1=read, 2=write, +// 3=readwrite. +static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region, + int rw_mask) { + Buffer buf = region->buffer; + int ndim = static_cast<int>(buf->shape.size()); + ICHECK(ndim == 1 || ndim == 2) << "Cumsum expects buffers with 1 or 2 dims"; + + PrimExpr offset, extent; + if (ndim == 1) { + // Simple 1D region: offset and extent come from the single axis. 
+ auto axis = region->region[0]; + offset = axis->min; + extent = axis->extent; + } else { + // Compute row-major strides for ndim >= 2 + std::vector<PrimExpr> strides(ndim); + PrimExpr one = make_const(buf->shape[0].dtype(), 1); + PrimExpr cur = one; + for (int i = ndim - 1; i >= 0; --i) { + strides[i] = cur; + cur = cur * buf->shape[i]; + } + // Offset: sum_{i in [0..ndim-3]} min_i * stride_i + offset = make_const(buf->shape[0].dtype(), 0); + for (int i = 0; i < ndim - 2; ++i) { + offset = offset + region->region[i]->min * strides[i]; + } + + // Extent: last two extents product (elements) + extent = + region->region[ndim - 2]->extent * region->region[ndim - 1]->extent; + } + + // ptype and return handle + PrimExpr ptype = tir::TypeAnnotation(buf->dtype); + Array<PrimExpr> acc_args{ptype, buf->data, offset, extent, + IntImm(DataType::Int(32), rw_mask)}; + return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); +} + ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) { ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>(); // Accept BufferRegion/BufferLoad/tl.region for src/dst @@ -231,6 +285,7 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto dst_scope = this->dst.scope(); if (src_scope == "local.fragment" && dst_scope == "local.fragment") { + Buffer src_buffer = get_buffer(this->src); Buffer dst_buffer = get_buffer(this->dst); Fragment src_layout = T.layout_map[this->src].as<Fragment>().value(); @@ -518,6 +573,16 @@ TIR_REGISTER_TL_OP(ReduceOp, reduce) .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +// Normalize "Buffer" to BufferRegion. Use the shape of the buffer as the +// ranges. 
+static BufferRegion ConvertBufferToBufferRegion(const Buffer &buf) { + Array<Range> ranges; + for (PrimExpr extent : buf->shape) { + ranges.push_back(Range(IntImm(extent->dtype, 0), extent)); + } + return BufferRegion(buf, ranges); +} + CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) { /// CumSum constructor arguments: /// - src: input buffer @@ -526,11 +591,19 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) { /// - reverse: whether to cumsum in reverse order CHECK_EQ(args.size(), 4); ObjectPtr<CumSumOpNode> node = tvm::ffi::make_object<CumSumOpNode>(); - node->src = vmap[GetVarFromAccessPtr(args[0])]; - node->dst = vmap[GetVarFromAccessPtr(args[1])]; + // node->src = vmap[GetVarFromAccessPtr(args[0])]; + // node->dst = vmap[GetVarFromAccessPtr(args[1])]; + node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap); + node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap); + node->src = node->srcRegion_->buffer; + node->dst = node->dstRegion_->buffer; node->dim = args[2].as<IntImm>().value()->value; node->reverse = args[3].as<Bool>().value(); - CHECK_LT(node->dim, static_cast<int>(node->src->shape.size())); + CHECK_LT(node->dim, static_cast<int>(node->src->shape.size())) + << "The dim of cumsum should be less than the number of dimensions. Got " "dim=" + << node->dim << ", but src has " << node->src->shape.size() << " dims."; + data_ = std::move(node); } @@ -546,18 +619,22 @@ Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto threads = T.thread_bounds->extent; Array<PrimExpr> args; int ndim = static_cast<int>(src->shape.size()); + + // Build access pointers from regions locally + PrimExpr srcPtr = MakeAccessPtrFromRegion(srcRegion_, 1); + PrimExpr dstPtr = MakeAccessPtrFromRegion(dstRegion_, 2); + if (ndim == 1) { ICHECK_EQ(dim, 0) << "Cumulative sum over a 1D buffer only supports dim " "= 0."; ss << "tl::CumSum1D<" << threads << ", " << (reverse ? 
"true" : "false") << ">::run"; - args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3), - src->shape[0]}; + args = {StringImm(ss.str()), srcPtr, dstPtr, src->shape[0]}; } else if (ndim == 2) { ss << "tl::CumSum2D<" << threads << ", " << dim << ", " << (reverse ? "true" : "false") << ">::run"; - args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3), - src->shape[0], src->shape[1]}; + args = {StringImm(ss.str()), srcPtr, dstPtr, src->shape[0], + src->shape[1]}; } else { LOG(FATAL) << "CumSum currently supports only 1D or 2D buffers, got " << ndim << "D."; diff --git a/src/op/reduce.h b/src/op/reduce.h index 3b124a4d3..eb0599ebd 100644 --- a/src/op/reduce.h +++ b/src/op/reduce.h @@ -133,8 +133,10 @@ class ReduceOp : public TileOperator { class CumSumOpNode : public TileOperatorNode { public: tir::Buffer src, dst; ///< Source and destination buffers - int dim; ///< Dimension along which to compute cumulative sum - bool reverse; ///< Whether to compute in reverse order + // Optional: keep the original regions used to construct this op + BufferRegion srcRegion_, dstRegion_; + int dim; ///< Dimension along which to compute cumulative sum + bool reverse; ///< Whether to compute in reverse order TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.CumSumOp", CumSumOpNode, TileOperatorNode); @@ -143,6 +145,8 @@ class CumSumOpNode : public TileOperatorNode { refl::ObjectDef<CumSumOpNode>() .def_ro("src", &CumSumOpNode::src) .def_ro("dst", &CumSumOpNode::dst) + .def_ro("srcRegion", &CumSumOpNode::srcRegion_) + .def_ro("dstRegion", &CumSumOpNode::dstRegion_) .def_ro("dim", &CumSumOpNode::dim) .def_ro("reverse", &CumSumOpNode::reverse); } diff --git a/testing/python/issue/test_tilelang_issue_1001.py b/testing/python/issue/test_tilelang_issue_1001.py new file mode 100644 index 000000000..77d8cc1f1 --- /dev/null +++ b/testing/python/issue/test_tilelang_issue_1001.py @@ -0,0 +1,33 @@ +import torch +import tilelang +import tilelang.testing +from tilelang import language as T + + 
+@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + },) +def _cumsum_view_infer_layout(hidden): + num_tokens = T.dynamic('num_tokens') + + @T.prim_func + def buggy_kernel(x: T.Tensor[(num_tokens, hidden), 'float']): + with T.Kernel(num_tokens, threads=128) as pid: + smem = T.alloc_shared((hidden,), dtype='float') + T.copy(x[pid, :], smem) + T.cumsum(T.view(smem, (1, hidden)), dim=1) + + return buggy_kernel + + +def test_cumsum_view_infer_layout(): + hidden = 128 + x = torch.randn(1, hidden, device='cuda', dtype=torch.float) + kernel = _cumsum_view_infer_layout(hidden) + kernel(x) + + +if __name__ == '__main__': + tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_atomic_add.py b/testing/python/language/test_tilelang_language_atomic_add.py index 132e002a9..1287c9ec5 100644 --- a/testing/python/language/test_tilelang_language_atomic_add.py +++ b/testing/python/language/test_tilelang_language_atomic_add.py @@ -260,7 +260,7 @@ def test_atomic_addx2(): run_atomic_addx2(32, 64, 8, 16) -@tilelang.jit(debug_root_path="./testing/python/language") +@tilelang.jit def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float"): @T.prim_func diff --git a/tilelang/analysis/__init__.py b/tilelang/analysis/__init__.py index b72fc2ba3..6e5ee5d6c 100644 --- a/tilelang/analysis/__init__.py +++ b/tilelang/analysis/__init__.py @@ -1,3 +1,4 @@ """Tilelang IR analysis & visitors.""" +from .ast_printer import ASTPrinter # noqa: F401 from .nested_loop_checker import NestedLoopChecker # noqa: F401 diff --git a/tilelang/analysis/ast_printer.py b/tilelang/analysis/ast_printer.py new file mode 100644 index 000000000..c54ec5cf9 --- /dev/null +++ b/tilelang/analysis/ast_printer.py @@ -0,0 +1,23 @@ +from tvm import tir +from tvm.tir import PrimFunc +from tvm.tir.transform import prim_func_pass +from tvm.tir.stmt_functor import ir_transform + + +def 
ASTPrinter(): + """ + Print the AST of a given tilelang module for debugging. + """ + + def pre_visit(statement: tir.Stmt) -> None: + """ + Pre-order visitor to print all visited statements. + """ + + print(f"Visiting statement: {type(statement)}") + + def pass_fn(func: PrimFunc, mod, ctx) -> PrimFunc: + new_body = ir_transform(func.body, pre_visit, None) + return func.with_body(new_body) + + return prim_func_pass(pass_fn, opt_level=0) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 35c16a438..f686ba1fb 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -74,6 +74,9 @@ def PreLowerSemanticCheck(mod: IRModule) -> None: Note: This is a validation-only pipeline of passes and does not modify or return the module. """ + # Debug + # tilelang.analysis.ASTPrinter()(mod) + # Check if there are any invalid nested loops. tilelang.analysis.NestedLoopChecker()(mod) diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py index 09289559d..6e3f1b689 100644 --- a/tilelang/language/reduce.py +++ b/tilelang/language/reduce.py @@ -246,8 +246,8 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) - tir.call_intrin( "handle", tir.op.Op.get("tl.cumsum"), - cumsum_smem.access_ptr("r"), - cumsum_smem.access_ptr("w"), + buffer_to_tile_region(cumsum_smem, "r"), + buffer_to_tile_region(cumsum_smem, "w"), dim, reverse, ) @@ -300,8 +300,8 @@ def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse return tir.call_intrin( "handle", tir.op.Op.get("tl.cumsum"), - src.access_ptr("r"), - dst.access_ptr("w"), + buffer_to_tile_region(src, "r"), + buffer_to_tile_region(dst, "w"), dim, reverse, )