diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index b90347a3f..3fe921e7d 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -45,15 +45,19 @@ AtomicAdd::AtomicAdd(Array args, Map annotations) { << "AtomicAdd expects at least 2 arguments (src, dst), got " << args.size(); ObjectPtr node = tvm::ffi::make_object(); - Array rgs[2]; - Buffer bf[2]; - for (int i = 0; i < 2; i++) { - auto region = NormalizeToBufferRegion(args[i]); - rgs[i] = region->region; - bf[i] = region->buffer; + + if (IsBufferLikeExpr(args[0])) { + auto region = NormalizeToBufferRegion(args[0]); + node->src = region->buffer; + node->src_range = region->region; + } else { + node->src_value = args[0]; } - std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); - std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); + + auto region = NormalizeToBufferRegion(args[1]); + node->dst = region->buffer; + node->dst_range = region->region; + // Copy annotations from the Call node node->annotations = annotations; data_ = std::move(node); @@ -144,45 +148,49 @@ AtomicAddNode::ReturnIndicesAndSize(int src_dst) const { */ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { Array loop_vars = MakeIterVars(); - bool is_scalar = loop_vars.empty(); - if (is_scalar) { - return For(Var("i"), 0, 1, ForKind::kSerial, - BufferStore(dst, BufferLoad(src, {0}), {0})); - } + ICHECK(!loop_vars.empty()) << "MakeIterVars in AtomicOp should not return " + "empty vars (at least 1 var)"; for (const auto &iv : loop_vars) analyzer->Bind(iv->var, iv->dom); - ICHECK(loop_vars.size() <= src_range.size()) - << "loop_vars.size() = " << loop_vars.size() - << ", src_range.size() = " << src_range.size() << ", src = " << src->name - << ", dst = " << dst->name; - ICHECK(loop_vars.size() <= dst_range.size()) << "loop_vars.size() = " << loop_vars.size() - << ", dst_range.size() = " << dst_range.size() << ", src = " << src->name - << ", dst = " << dst->name; + << ", dst_range.size() = " << dst_range.size() << ", dst = " << dst->name; - Array src_indices = MakeIndices(loop_vars, 0); Array dst_indices = MakeIndices(loop_vars, 1); - Array new_args; // Optional bounds predicates for src and dst - PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0); PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1); - // Load source value and cast to dst dtype if needed - PrimExpr src_value = BufferLoad(src, src_indices); - if (src->dtype != dst->dtype) - src_value = Cast(dst->dtype, src_value); + // Src arg to be passed to the Call atomic operation + PrimExpr src_value_arg; + + // If src is a Buffer + if (!src_value.defined()) { + ICHECK(loop_vars.size() <= src_range.size()) + << "loop_vars.size() = " << loop_vars.size() + << ", src_range.size() = " << src_range.size() + << ", src = " << src->name << ", dst = " << dst->name; + + Array src_indices = MakeIndices(loop_vars, 0); + PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0); + // Load source value + src_value_arg = BufferLoad(src, src_indices); + } else { + src_value_arg = src_value; + } + // Cast to dst dtype if needed + if (src_value_arg->dtype != dst->dtype) + src_value_arg = Cast(dst->dtype, src_value_arg); // Build a pointer to destination element using tvm_access_ptr PrimExpr dst_ptr = Call(DataType::Handle(), builtin::address_of(), {BufferLoad(dst, dst_indices)}); new_args.push_back(dst_ptr); - new_args.push_back(src_value); + new_args.push_back(src_value_arg); new_args.push_back(GetMemoryOrder()); // erase 
use_tma from annotations diff --git a/src/op/atomic_add.h b/src/op/atomic_add.h index b487fd8e6..ed60f267d 100644 --- a/src/op/atomic_add.h +++ b/src/op/atomic_add.h @@ -37,6 +37,7 @@ class AtomicAddNode : public AtomicOpBaseNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("src", &AtomicAddNode::src) + .def_ro("src_value", &AtomicAddNode::src_value) .def_ro("dst", &AtomicAddNode::dst) .def_ro("src_range", &AtomicAddNode::src_range) .def_ro("dst_range", &AtomicAddNode::dst_range) @@ -47,7 +48,12 @@ class AtomicAddNode : public AtomicOpBaseNode { bool GetUseTMA() const { if (auto val = annotations.Get("use_tma")) { if (auto int_val = val->as()) { - return int_val->value != 0; + if (int_val->value != 0) { + ICHECK(!src_value.defined()) + << "TMA is not supported when using TiledAtomicAdd with PrimExpr " + "as value."; + return true; + } } } return false; diff --git a/src/op/atomic_reduce.cc b/src/op/atomic_reduce.cc index 1bafa6f0a..925572d5f 100644 --- a/src/op/atomic_reduce.cc +++ b/src/op/atomic_reduce.cc @@ -31,15 +31,19 @@ AtomicMax::AtomicMax(Array args, Map annotations) { << "AtomicMax expects at least 2 arguments (src, dst), got " << args.size(); ObjectPtr node = tvm::ffi::make_object(); - Array rgs[2]; - Buffer bf[2]; - for (int i = 0; i < 2; i++) { - auto region = NormalizeToBufferRegion(args[i]); - rgs[i] = region->region; - bf[i] = region->buffer; + + if (IsBufferLikeExpr(args[0])) { + auto region = NormalizeToBufferRegion(args[0]); + node->src = region->buffer; + node->src_range = region->region; + } else { + node->src_value = args[0]; } - std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); - std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); + + auto region = NormalizeToBufferRegion(args[1]); + node->dst = region->buffer; + node->dst_range = region->region; + node->annotations = annotations; data_ = std::move(node); } @@ -63,15 +67,19 @@ AtomicMin::AtomicMin(Array args, Map annotations) { << "AtomicMin expects at least 2 arguments (src, dst), got " << args.size(); ObjectPtr node = tvm::ffi::make_object(); - Array rgs[2]; - Buffer bf[2]; - for (int i = 0; i < 2; i++) { - auto region = NormalizeToBufferRegion(args[i]); - rgs[i] = region->region; - bf[i] = region->buffer; + + if (IsBufferLikeExpr(args[0])) { + auto region = NormalizeToBufferRegion(args[0]); + node->src = region->buffer; + node->src_range = region->region; + } else { + node->src_value = args[0]; } - std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); - std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); + + auto region = NormalizeToBufferRegion(args[1]); + node->dst = region->buffer; + node->dst_range = region->region; + node->annotations = annotations; data_ = std::move(node); } @@ -93,14 +101,23 @@ const Op &AtomicMinNode::GetElemOp() const { return atomic_min_elem_op(); } Array AtomicOpBaseNode::MakeIterVars() const { Array loop_vars; size_t idx = 0; - for (size_t i = 0; i < src_range.size(); i++) { - if (is_one(src_range[i]->extent)) + // Make IterVars according to dst, not src + // Since src may be a scalar Expr + for (size_t i = 0; i < dst_range.size(); i++) { + if (is_one(dst_range[i]->extent)) continue; - Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype); + Var var = Var(std::string{char('i' + idx)}, dst_range[i]->extent->dtype); idx++; loop_vars.push_back( - {Range(0, src_range[i]->extent), var, IterVarType::kDataPar}); + {Range(0, dst_range[i]->extent), var, IterVarType::kDataPar}); + } + + // If is scalar, 
create a dummy loop var + if (loop_vars.empty()) { + Var var = Var("i"); + loop_vars.push_back({Range(0, 1), var, IterVarType::kDataPar}); } + return loop_vars; } @@ -117,9 +134,11 @@ Array AtomicOpBaseNode::MakeIndices(const Array &ivs, idx++; } } - ICHECK(idx == ivs.size()) - << "idx = " << idx << ", ivs.size() = " << ivs.size() - << "src name = " << src->name << ", dst name = " << dst->name; + + // Special case: scalar range, when there is one var and one range(0, 1) + ICHECK(idx == ivs.size() || (idx == 0 && ivs.size() == 1)) + << "Unmatched indices: idx = " << idx << ", ivs.size() = " << ivs.size() + << ", dst name = " << dst->name; return indices; } @@ -156,41 +175,45 @@ PrimExpr AtomicOpBaseNode::MakePredicate(arith::Analyzer *analyzer, For AtomicOpBaseNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { Array loop_vars = MakeIterVars(); - bool is_scalar = loop_vars.empty(); - if (is_scalar) { - return For(Var("i"), 0, 1, ForKind::kSerial, - BufferStore(dst, BufferLoad(src, {0}), {0})); - } + ICHECK(!loop_vars.empty()) << "MakeIterVars in AtomicOp should not return " + "empty vars (at least 1 var)"; for (const auto &iv : loop_vars) analyzer->Bind(iv->var, iv->dom); - ICHECK(loop_vars.size() <= src_range.size()) - << "loop_vars.size() = " << loop_vars.size() - << ", src_range.size() = " << src_range.size() << ", src = " << src->name - << ", dst = " << dst->name; - ICHECK(loop_vars.size() <= dst_range.size()) << "loop_vars.size() = " << loop_vars.size() - << ", dst_range.size() = " << dst_range.size() << ", src = " << src->name - << ", dst = " << dst->name; + << ", dst_range.size() = " << dst_range.size() << ", dst = " << dst->name; - Array src_indices = MakeIndices(loop_vars, 0); Array dst_indices = MakeIndices(loop_vars, 1); - Array new_args; - // Load source value and cast to dst dtype if needed - PrimExpr src_value = BufferLoad(src, src_indices); - if (src->dtype != dst->dtype) - src_value = Cast(dst->dtype, src_value); + // Src arg to be passed to the Call atomic operation + PrimExpr src_value_arg; + + // If src is a Buffer + if (!src_value.defined()) { + ICHECK(loop_vars.size() <= src_range.size()) + << "loop_vars.size() = " << loop_vars.size() + << ", src_range.size() = " << src_range.size() + << ", src = " << src->name << ", dst = " << dst->name; + + Array src_indices = MakeIndices(loop_vars, 0); + // Load source value + src_value_arg = BufferLoad(src, src_indices); + } else { + src_value_arg = src_value; + } + // Cast to dst dtype if needed + if (src_value_arg->dtype != dst->dtype) + src_value_arg = Cast(dst->dtype, src_value_arg); // Build a pointer to destination element using tvm_access_ptr PrimExpr dst_ptr = Call(DataType::Handle(), builtin::address_of(), {BufferLoad(dst, dst_indices)}); new_args.push_back(dst_ptr); - new_args.push_back(src_value); + new_args.push_back(src_value_arg); new_args.push_back(GetMemoryOrder()); // Use the appropriate elem_op based on the derived type (via virtual call) diff --git a/src/op/atomic_reduce.h b/src/op/atomic_reduce.h index bdfb12ca8..57b13e139 100644 --- a/src/op/atomic_reduce.h +++ b/src/op/atomic_reduce.h @@ -23,7 +23,8 @@ using namespace tir; */ class AtomicOpBaseNode : public TileOperatorNode { public: - Buffer src, dst; ///< Source and destination buffers + PrimExpr src_value; ///< Source values, for cases src is not a buffer + Buffer src, dst; ///< Source and destination buffers Array src_range, dst_range; ///< Access ranges for source and destination Map annotations; ///< Annotations for the atomic operation @@ -81,6 +82,7 
@@ class AtomicMaxNode : public AtomicOpBaseNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("src", &AtomicMaxNode::src) + .def_ro("src_value", &AtomicMaxNode::src_value) .def_ro("dst", &AtomicMaxNode::dst) .def_ro("src_range", &AtomicMaxNode::src_range) .def_ro("dst_range", &AtomicMaxNode::dst_range) @@ -113,6 +115,7 @@ class AtomicMinNode : public AtomicOpBaseNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("src", &AtomicMinNode::src) + .def_ro("src_value", &AtomicMinNode::src_value) .def_ro("dst", &AtomicMinNode::dst) .def_ro("src_range", &AtomicMinNode::src_range) .def_ro("dst_range", &AtomicMinNode::dst_range) diff --git a/src/op/utils.cc b/src/op/utils.cc index 042c38a0c..7f8c3c7c6 100644 --- a/src/op/utils.cc +++ b/src/op/utils.cc @@ -4,6 +4,7 @@ */ #include "utils.h" +#include "tvm/tir/expr.h" #include @@ -12,6 +13,16 @@ namespace tl { using namespace tir; +bool IsBufferLikeExpr(const PrimExpr &expr) { + if (expr.as() || expr.as()) { + return true; + } + if (const auto *call = expr.as()) { + return (call->op.same_as(RegionOp::Get())); + } + return false; +} + BufferRegion NormalizeToBufferRegion(const PrimExpr &arg) { // Case 1: Already a BufferRegion if (arg->IsInstance()) { diff --git a/src/op/utils.h b/src/op/utils.h index fcbfee9e2..9fdb3b4af 100644 --- a/src/op/utils.h +++ b/src/op/utils.h @@ -9,6 +9,7 @@ #include "../target/stubs/cuda.h" #include "./operator.h" #include "region.h" +#include "tvm/runtime/base.h" #include #include @@ -25,6 +26,10 @@ template Array ReverseArray(Array array) { return Array{array.rbegin(), array.rend()}; } +// Check if a PrimExpr is a buffer-like (BufferRegion/BufferLoad/tl.region) +// expression. +TVM_DLL bool IsBufferLikeExpr(const PrimExpr &expr); + // Normalize an argument (BufferRegion/BufferLoad/tl.region) // to BufferRegion so ops can uniformly consume regions. // Note: tvm_access_ptr is no longer supported here. diff --git a/src/transform/vectorize_loop.cc b/src/transform/vectorize_loop.cc index e494fc3b1..301ecef4e 100644 --- a/src/transform/vectorize_loop.cc +++ b/src/transform/vectorize_loop.cc @@ -470,6 +470,7 @@ class TLVectorizer : public StmtMutator, return std::move(var); } } + // IfThenElse expr PrimExpr MutateIfThenElseExpr_(const CallNode *op) { PrimExpr cond = this->VisitExpr(op->args[0]); @@ -498,6 +499,7 @@ } } } + // Address of: remove vectorized var from indices to get base address // e.g., T.address_of(buf[base + vec]) -> T.address_of(buf[base]) PrimExpr MutateAddressOfCall_(const CallNode *op) { @@ -524,6 +526,7 @@ return Call(op->dtype, op->op, {new_load}); } + // Reinterpret expr PrimExpr MutateReinterpretExpr_(const CallNode *op) { ICHECK(op->op.same_as(builtin::reinterpret())); @@ -556,8 +559,16 @@ return tvm::ffi::GetRef(op); } int vector_size = static_cast(*lanes_ptr); + auto dst = VisitExpr(op->args[0]); auto src = VisitExpr(op->args[1]); + + // If src is neither a Ramp nor a Broadcast, it must be a scalar value.
+ // Broadcast to vector size if needed + if (src.same_as(op->args[1])) { + src = BroadcastTo(src, vector_size, src.dtype().is_scalable_vector()); + } + // Check if dtype supports this vector size auto dst_buffer_load = ExtractBufferLoadForAtomic(dst); Target target = Target::Current(false); @@ -571,6 +582,7 @@ class TLVectorizer : public StmtMutator, // Return the vectorized atomic op return Call(op->dtype, GetVectorizedAtomicOp(vector_size), {dst, src}); } + // Call PrimExpr VisitExpr_(const CallNode *op) final { if (op->op.same_as(builtin::if_then_else())) { @@ -629,6 +641,7 @@ class TLVectorizer : public StmtMutator, } } } + // BufferLoad PrimExpr VisitExpr_(const BufferLoadNode *op) final { auto load = tvm::ffi::GetRef(op); @@ -646,6 +659,7 @@ class TLVectorizer : public StmtMutator, return std::move(load); } + // Let PrimExpr VisitExpr_(const LetNode *op) final { PrimExpr value = this->VisitExpr(op->value); @@ -677,6 +691,7 @@ class TLVectorizer : public StmtMutator, } } } + // BufferStore Stmt VisitStmt_(const BufferStoreNode *op) final { auto store = tvm::ffi::GetRef(op); @@ -733,6 +748,7 @@ class TLVectorizer : public StmtMutator, return std::move(store); } + // For Stmt VisitStmt_(const ForNode *op) final { if (op->kind == ForKind::kVectorized) { @@ -752,6 +768,7 @@ class TLVectorizer : public StmtMutator, op->thread_binding, op->annotations); } } + // IfThenElse Stmt VisitStmt_(const IfThenElseNode *op) final { ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector()); @@ -771,10 +788,12 @@ class TLVectorizer : public StmtMutator, return IfThenElse(condition, then_case, else_case); } } + // While Stmt VisitStmt_(const WhileNode *op) final { LOG(FATAL) << "A while loop inside a vectorized loop not supported."; } + // LetStmt Stmt VisitStmt_(const LetStmtNode *op) final { PrimExpr value = this->VisitExpr(op->value); diff --git a/testing/python/language/test_tilelang_language_atomic.py b/testing/python/language/test_tilelang_language_atomic.py index 44ee97e49..ebb61f4c3 100644 --- a/testing/python/language/test_tilelang_language_atomic.py +++ b/testing/python/language/test_tilelang_language_atomic.py @@ -4,6 +4,9 @@ import torch +# ======================= Thread-level atomic add ======================= + + @tilelang.jit def atomic_add_program(K, M, N, block_M, block_N, dtype=T.float32): @T.prim_func @@ -21,40 +24,6 @@ def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): def run_atomic_add(K, M, N, block_M, block_N, dtype=T.float32): kernel = atomic_add_program(K, M, N, block_M, block_N, dtype=dtype) - print(kernel.get_kernel_source()) - import torch - - def ref_program(A, B): - for k in range(K): - for i in range(M): - for j in range(N): - B[i, j] += A[k, i, j] - - A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda() - B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() - ref_B = B.clone() - ref_program(A, ref_B) - kernel(A, B) - torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) - - -@tilelang.jit -def tile_atomic_add_program(K, M, N, block_M, block_N, dtype=T.float32): - @T.prim_func - def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): - with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): - A_shared = T.alloc_shared((block_M, block_N), dtype) - - T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) - - T.atomic_add(B[bx * block_M, by * block_N], A_shared) - - return atomic_add - - -def run_tile_atomic_add(K, M, N, block_M, 
block_N, dtype=T.float32): - kernel = tile_atomic_add_program(K, M, N, block_M, block_N, dtype=dtype) - print(kernel.get_kernel_source()) import torch def ref_program(A, B): @@ -68,102 +37,9 @@ def ref_program(A, B): ref_B = B.clone() ref_program(A, ref_B) kernel(A, B) - print(B) - print(ref_B) torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) -@tilelang.jit -def atomic_max_program(K, M, N, block_M, block_N, dtype=T.float32): - @T.prim_func - def atomic_max(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): - with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): - A_shared = T.alloc_shared((block_M, block_N), dtype) - - T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) - - for i, j in T.Parallel(block_M, block_N): - T.atomic_max(B[bx * block_M + i, by * block_N + j], A_shared[i, j]) - - return atomic_max - - -def run_atomic_max(K, M, N, block_M, block_N, dtype=T.float32): - kernel = atomic_max_program(K, M, N, block_M, block_N, dtype=dtype) - import torch - - def ref_program(A, B): - for k in range(K): - for i in range(M): - for j in range(N): - B[i, j] = max(B[i, j], A[k, i, j]) - - A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda() - B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() - ref_B = B.clone() - ref_program(A, ref_B) - kernel(A, B) - torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) - - -@tilelang.jit -def atomic_min_program(K, M, N, block_M, block_N, dtype=T.float32): - @T.prim_func - def atomic_min(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): - with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): - A_shared = T.alloc_shared((block_M, block_N), dtype) - - T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) - - for i, j in T.Parallel(block_M, block_N): - T.atomic_min(B[bx * block_M + i, by * block_N + j], A_shared[i, j]) - - return atomic_min - - -def run_atomic_min(K, M, N, block_M, block_N, dtype=T.float32): - kernel = atomic_min_program(K, M, N, block_M, block_N, dtype=dtype) - import torch - - def ref_program(A, B): - for k in range(K): - for i in range(M): - for j in range(N): - B[i, j] = min(B[i, j], A[k, i, j]) - - A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda() - B = torch.full((M, N), float("inf"), dtype=getattr(torch, dtype)).cuda() - ref_B = B.clone() - ref_program(A, ref_B) - kernel(A, B) - torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) - - -@tilelang.jit -def atomic_load_store_program(M, N, block_M, block_N, dtype=T.float32): - @T.prim_func - def atomic_load_store(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): - with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): - for i, j in T.Parallel(block_M, block_N): - idx_i = bx * block_M + i - idx_j = by * block_N + j - if idx_i < M and idx_j < N: - val = T.atomic_load(A[idx_i, idx_j]) - T.atomic_store(B[idx_i, idx_j], val) - - return atomic_load_store - - -def run_atomic_load_store(M, N, block_M, block_N, dtype=T.float32): - kernel = atomic_load_store_program(M, N, block_M, block_N, dtype=dtype) - import torch - - A = torch.randn(M, N, dtype=getattr(torch, dtype)).cuda() - B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() - kernel(A, B) - torch.testing.assert_close(B, A, atol=1e-3, rtol=1e-3) - - @tilelang.jit def atomic_memory_order_program(K, M, N, block_M, block_N, dtype=T.float32): @T.prim_func @@ -226,37 +102,6 @@ def 
run_atomic_addx2(M, N, block_M, block_N, dtype=T.float16): torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) -def test_atomic_add(): - run_atomic_add(8, 128, 128, 32, 32) - - -def test_atomic_max(): - run_atomic_max(4, 64, 64, 16, 16) - - -def test_atomic_min(): - run_atomic_min(4, 64, 64, 16, 16) - - -@tilelang.testing.requires_cuda -def test_atomic_load_store(): - run_atomic_load_store(64, 64, 16, 16) - - -@tilelang.testing.requires_cuda -def test_atomic_memory_order(): - run_atomic_memory_order(4, 64, 64, 16, 16) - - -@tilelang.testing.requires_cuda -def test_atomic_addx2_half(): - run_atomic_addx2(32, 64, 8, 16, dtype=T.float16) - - -def test_atomic_addx2_float(): - run_atomic_addx2(32, 64, 8, 16, dtype=T.float32) - - @tilelang.jit def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype=T.float32): @T.prim_func @@ -368,6 +213,62 @@ def tma_atomic_add_program(out, explicit_swizzle=False): T.atomic_add(out, out_shared, use_tma=True) +@tilelang.testing.requires_cuda +def test_tma_atomic_add(): + out = torch.zeros((16, 16), dtype=torch.float32, device="cuda") + tma_atomic_add_program(out) + torch.testing.assert_close(out, torch.ones((16, 16), dtype=torch.float32, device="cuda") * 16) + + kernel = tma_atomic_add_program.compile(out=T.Tensor[(16, 16), T.float32]) + assert "tma_store_add" in kernel.get_kernel_source() + assert "desc" in kernel.get_kernel_source() # Ensure using cp.reduce.async.bulk.tensor + + kernel_with_explicit_swizzle = tma_atomic_add_program.compile(out=T.Tensor[(16, 16), T.float32], explicit_swizzle=True) + # Ensure auto swizzled layout is applied + assert kernel.get_kernel_source() == kernel_with_explicit_swizzle.get_kernel_source() + + +def run_atomic_add_auto_vectorized(K, M, N, block_M, block_N, dtype=T.float32): + kernel = atomic_add_program(K, M, N, block_M, block_N, dtype=dtype) + assert "AtomicAddx4" in kernel.get_kernel_source() + + +@tilelang.jit +def atomic_add_complicated_parallel_program(K, M, N, block_M, block_N, dtype=T.float32): + @T.prim_func + def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_N), dtype) + + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) + + for i, j in T.Parallel(block_M, block_N): + value = A_shared[i, j] + T.atomic_add(B[bx * block_M + i, by * block_N + j], value) + + return atomic_add + + +def run_atomic_add_complicated_parallel(K, M, N, block_M, block_N, dtype=T.float32): + kernel = atomic_add_complicated_parallel_program(K, M, N, block_M, block_N, dtype=dtype) + assert "float4 value" in kernel.get_kernel_source() + assert "AtomicAddx4" in kernel.get_kernel_source() + + +@tilelang.testing.requires_cuda +def test_atomic_memory_order(): + run_atomic_memory_order(4, 64, 64, 16, 16) + + +@tilelang.testing.requires_cuda +def test_atomic_addx2_half(): + run_atomic_addx2(32, 64, 8, 16, dtype=T.float16) + + +def test_atomic_addx2_float(): + run_atomic_addx2(32, 64, 8, 16, dtype=T.float32) + + @tilelang.testing.requires_cuda def test_atomic_different_memory_orders(): run_atomic_different_memory_orders(32, 32, 8, 8, dtype=T.float32) @@ -383,11 +284,235 @@ def test_atomic_return_prev(): run_atomic_return_prev(32, 32, 8, 8) +def test_atomic_add(): + run_atomic_add(8, 128, 128, 32, 32) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def 
test_atomic_add_auto_vectorized(): + run_atomic_add_auto_vectorized(8, 128, 128, 32, 32) + + +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_atomic_add_complicated_parallel(): + run_atomic_add_complicated_parallel(8, 128, 128, 32, 32) + + +# ======================= Tile-level atomic add ======================= + + +@tilelang.jit +def tile_atomic_add_program(K, M, N, block_M, block_N, dtype=T.float32): + @T.prim_func + def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_N), dtype) + + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) + + T.atomic_add(B[bx * block_M, by * block_N], A_shared) + + return atomic_add + + +def run_tile_atomic_add(K, M, N, block_M, block_N, dtype=T.float32): + kernel = tile_atomic_add_program(K, M, N, block_M, block_N, dtype=dtype) + import torch + + def ref_program(A, B): + for k in range(K): + for i in range(M): + for j in range(N): + B[i, j] += A[k, i, j] + + A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda() + B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() + ref_B = B.clone() + ref_program(A, ref_B) + kernel(A, B) + torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) + + +@tilelang.jit +def tile_atomic_add_expr_program(M, N, block_M, block_N, dtype=T.float32): + @T.prim_func + def atomic_add(A: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): + T.atomic_add(A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], 1.0) + + return atomic_add + + +def run_tile_atomic_add_expr(M, N, block_M, block_N, dtype=T.float32): + kernel = tile_atomic_add_expr_program(M, N, block_M, block_N, dtype=dtype) + import torch + + def ref_program(A): + for i in range(M): + for j in range(N): + A[i, j] += 1 + + A = torch.zeros(M, N, dtype=torch.float32).cuda() + ref_A = A.clone() + ref_program(ref_A) + kernel(A) + torch.testing.assert_close(A, ref_A, atol=1e-3, rtol=1e-3) + + +@tilelang.jit +def tile_atomic_add_scalar_program(dtype=T.float32): + @T.prim_func + def atomic_add(A: T.Tensor((1), dtype), B: T.Tensor((1), dtype)): + with T.Kernel( + 1, + ) as _: + A_local = T.alloc_local([1], dtype) + T.copy(A, A_local) + T.clear(B) + T.atomic_add(B, A_local) + T.atomic_add(B, 1) + + return atomic_add + + +def run_tile_atomic_add_scalar(dtype=T.float32): + kernel = tile_atomic_add_scalar_program(dtype=dtype) + import torch + + def ref_program(A, B): + B[0] = A[0] + 1 + + A = torch.randn(1, dtype=getattr(torch, dtype)).cuda() + B = torch.zeros(1, dtype=getattr(torch, dtype)).cuda() + ref_B = B.clone() + ref_program(A, ref_B) + kernel(A, B) + torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) + + def test_tile_atomic_add(): run_tile_atomic_add(8, 128, 128, 32, 32) -# ======================= Tile-level atomic max ======================= +def test_tile_atomic_add_expr(): + run_tile_atomic_add_expr(128, 128, 32, 32) + + +def test_tile_atomic_add_scalar(): + run_tile_atomic_add_scalar() + + +# ======================= Thread-level atomic max/min/load store ======================= + + +@tilelang.jit +def atomic_max_program(K, M, N, block_M, block_N, dtype=T.float32): + @T.prim_func + def atomic_max(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): + A_shared = 
T.alloc_shared((block_M, block_N), dtype) + + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) + + for i, j in T.Parallel(block_M, block_N): + T.atomic_max(B[bx * block_M + i, by * block_N + j], A_shared[i, j]) + + return atomic_max + + +def run_atomic_max(K, M, N, block_M, block_N, dtype=T.float32): + kernel = atomic_max_program(K, M, N, block_M, block_N, dtype=dtype) + import torch + + def ref_program(A, B): + for k in range(K): + for i in range(M): + for j in range(N): + B[i, j] = max(B[i, j], A[k, i, j]) + + A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda() + B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() + ref_B = B.clone() + ref_program(A, ref_B) + kernel(A, B) + torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) + + +@tilelang.jit +def atomic_min_program(K, M, N, block_M, block_N, dtype=T.float32): + @T.prim_func + def atomic_min(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_N), dtype) + + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) + + for i, j in T.Parallel(block_M, block_N): + T.atomic_min(B[bx * block_M + i, by * block_N + j], A_shared[i, j]) + + return atomic_min + + +def run_atomic_min(K, M, N, block_M, block_N, dtype=T.float32): + kernel = atomic_min_program(K, M, N, block_M, block_N, dtype=dtype) + import torch + + def ref_program(A, B): + for k in range(K): + for i in range(M): + for j in range(N): + B[i, j] = min(B[i, j], A[k, i, j]) + + A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda() + B = torch.full((M, N), float("inf"), dtype=getattr(torch, dtype)).cuda() + ref_B = B.clone() + ref_program(A, ref_B) + kernel(A, B) + torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) + + +@tilelang.jit +def atomic_load_store_program(M, N, block_M, block_N, dtype=T.float32): + @T.prim_func + def atomic_load_store(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + idx_i = bx * block_M + i + idx_j = by * block_N + j + if idx_i < M and idx_j < N: + val = T.atomic_load(A[idx_i, idx_j]) + T.atomic_store(B[idx_i, idx_j], val) + + return atomic_load_store + + +def run_atomic_load_store(M, N, block_M, block_N, dtype=T.float32): + kernel = atomic_load_store_program(M, N, block_M, block_N, dtype=dtype) + import torch + + A = torch.randn(M, N, dtype=getattr(torch, dtype)).cuda() + B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() + kernel(A, B) + torch.testing.assert_close(B, A, atol=1e-3, rtol=1e-3) + + +def test_atomic_max(): + run_atomic_max(4, 64, 64, 16, 16) + + +def test_atomic_min(): + run_atomic_min(4, 64, 64, 16, 16) + + +@tilelang.testing.requires_cuda +def test_atomic_load_store(): + run_atomic_load_store(64, 64, 16, 16) + + +# ======================= Tile-level atomic max/min ======================= + + @tilelang.jit def tile_atomic_max_program(K, M, N, block_M, block_N, dtype=T.float32): @T.prim_func @@ -404,7 +529,6 @@ def tile_atomic_max(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): def run_tile_atomic_max(K, M, N, block_M, block_N, dtype=T.float32): kernel = tile_atomic_max_program(K, M, N, block_M, block_N, dtype=dtype) - print(kernel.get_kernel_source()) def ref_program(A, B): for k in range(K): @@ -420,11 +544,6 @@ def 
ref_program(A, B): torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) -def test_tile_atomic_max(): - run_tile_atomic_max(8, 128, 128, 32, 32) - - -# ======================= Tile-level atomic min ======================= @tilelang.jit def tile_atomic_min_program(K, M, N, block_M, block_N, dtype=T.float32): @T.prim_func @@ -441,7 +560,6 @@ def tile_atomic_min(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): def run_tile_atomic_min(K, M, N, block_M, block_N, dtype=T.float32): kernel = tile_atomic_min_program(K, M, N, block_M, block_N, dtype=dtype) - print(kernel.get_kernel_source()) def ref_program(A, B): for k in range(K): @@ -457,61 +575,42 @@ def ref_program(A, B): torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) -def test_tile_atomic_min(): - run_tile_atomic_min(8, 128, 128, 32, 32) - - -@tilelang.testing.requires_cuda -def test_tma_atomic_add(): - out = torch.zeros((16, 16), dtype=torch.float32, device="cuda") - tma_atomic_add_program(out) - torch.testing.assert_close(out, torch.ones((16, 16), dtype=torch.float32, device="cuda") * 16) - - kernel = tma_atomic_add_program.compile(out=T.Tensor[(16, 16), T.float32]) - assert "tma_store_add" in kernel.get_kernel_source() - assert "desc" in kernel.get_kernel_source() # Ensure using cp.reduce.async.bulk.tensor - - kernel_with_explicit_swizzle = tma_atomic_add_program.compile(out=T.Tensor[(16, 16), T.float32], explicit_swizzle=True) - # Ensure auto swizzled layout is applied - assert kernel.get_kernel_source() == kernel_with_explicit_swizzle.get_kernel_source() - +@tilelang.jit +def tile_atomic_max_expr_program(M, N, block_M, block_N, dtype=T.float32): + @T.prim_func + def atomic_max(A: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): + T.atomic_max(A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], 0.5) -def run_atomic_add_auto_vectorized(K, M, N, block_M, block_N, dtype=T.float32): - kernel = atomic_add_program(K, M, N, block_M, block_N, dtype=dtype) - assert "AtomicAddx4" in kernel.get_kernel_source() + return atomic_max -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) -def test_atomic_add_auto_vectorized(): - run_atomic_add_auto_vectorized(8, 128, 128, 32, 32) - +def run_tile_atomic_max_expr(M, N, block_M, block_N, dtype=T.float32): + kernel = tile_atomic_max_expr_program(M, N, block_M, block_N, dtype=dtype) + import torch -@tilelang.jit -def atomic_add_complicated_parallel_program(K, M, N, block_M, block_N, dtype=T.float32): - @T.prim_func - def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): - with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): - A_shared = T.alloc_shared((block_M, block_N), dtype) + def ref_program(A): + for i in range(M): + for j in range(N): + A[i, j] = max(A[i, j], 0.5) - T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) + A = torch.randn(M, N, dtype=torch.float32).cuda() + ref_A = A.clone() + ref_program(ref_A) + kernel(A) + torch.testing.assert_close(A, ref_A, atol=1e-3, rtol=1e-3) - for i, j in T.Parallel(block_M, block_N): - value = A_shared[i, j] - T.atomic_add(B[bx * block_M + i, by * block_N + j], value) - return atomic_add +def test_tile_atomic_max(): + run_tile_atomic_max(8, 128, 128, 32, 32) -def run_atomic_add_complicated_parallel(K, M, N, block_M, block_N, dtype=T.float32): - kernel = atomic_add_complicated_parallel_program(K, M, N, block_M, 
block_N, dtype=dtype) - assert "float4 value" in kernel.get_kernel_source() - assert "AtomicAddx4" in kernel.get_kernel_source() +def test_tile_atomic_min(): + run_tile_atomic_min(8, 128, 128, 32, 32) -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) -def test_atomic_add_complicated_parallel(): - run_atomic_add_complicated_parallel(8, 128, 128, 32, 32) +def test_tile_atomic_max_expr(): + run_tile_atomic_max_expr(128, 128, 32, 32) if __name__ == "__main__": diff --git a/tilelang/_typing.py b/tilelang/_typing.py index e834d7916..41803657a 100644 --- a/tilelang/_typing.py +++ b/tilelang/_typing.py @@ -35,3 +35,7 @@ # adapted to string. DType: TypeAlias = Union[dtype, ir.Type, str, type] ShapeType: TypeAlias = Union[list[Union[tir.PrimExpr, int]], tuple[Union[tir.PrimExpr, int], ...]] + +# PrimExpr with adaptation to Python basic data types +# IntImm, FloatImm, Bool: IntImm, Integer: IntImm +PyPrimExpr: TypeAlias = tir.PrimExpr | int | float | bool diff --git a/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/intrinsics/mfma_macro_generator.py index f60b9a924..fa65b0044 100644 --- a/tilelang/intrinsics/mfma_macro_generator.py +++ b/tilelang/intrinsics/mfma_macro_generator.py @@ -10,7 +10,7 @@ from typing import Literal, Callable from tilelang.utils import is_fragment -from tilelang.utils.language import get_buffer_region_from_load +from tilelang.language.utils import get_buffer_region_from_load from .mfma_layout import ( shared_16x4_to_local_64x1_layout_A, shared_4x16_to_local_64x1_layout_B, diff --git a/tilelang/language/atomic.py b/tilelang/language/atomic.py index 9144b36c1..b1521b7d0 100644 --- a/tilelang/language/atomic.py +++ b/tilelang/language/atomic.py @@ -4,8 +4,9 @@ import tilelang.language as T from tvm import ir -from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op +from tvm.tir import PrimExpr, Buffer, op from tilelang.utils.language import to_buffer_region, legalize_pairwise_extents +from tilelang.language.utils import get_extent _MEMORY_ORDER_ID_MAP = { "relaxed": 0, @@ -57,16 +58,6 @@ def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None, re >>> atomic_max(dst_tensor, src_tensor) # Max entire tensors atomically """ - def get_extent(data): - if isinstance(data, Var) and T.has_let_value(data): - data = T.get_let_value(data) - if isinstance(data, Buffer): - return data.shape - elif isinstance(data, BufferRegion): - return [x.extent for x in data.region] - else: - return None - src_extent = get_extent(value) dst_extent = get_extent(dst) @@ -83,15 +74,20 @@ def get_extent(data): memory_order_id, ) + # When both arguments are Buffer, we can check whether they are structural equal. 
if isinstance(dst, Buffer) and isinstance(value, Buffer): ir.assert_structural_equal(dst.shape, value.shape) assert src_extent or dst_extent, "Can't deduce atomicmax extents from args" + + # If src is buffer-like, first convert it to a region + if src_extent: + value = to_buffer_region(value, access_type="r", extents=src_extent) + src_extent = list(src_extent) if src_extent else [1] * len(dst_extent) dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent) src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent) - value = to_buffer_region(value, access_type="r", extents=src_extent) dst = to_buffer_region(dst, access_type="w", extents=dst_extent) if return_prev: @@ -144,16 +140,6 @@ def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None, re >>> atomic_min(dst_tensor, src_tensor) # Min entire tensors atomically """ - def get_extent(data): - if isinstance(data, Var) and T.has_let_value(data): - data = T.get_let_value(data) - if isinstance(data, Buffer): - return data.shape - elif isinstance(data, BufferRegion): - return [x.extent for x in data.region] - else: - return None - src_extent = get_extent(value) dst_extent = get_extent(dst) @@ -170,15 +156,20 @@ def get_extent(data): memory_order_id, ) + # When both arguments are Buffers, we can check that their shapes are structurally equal. if isinstance(dst, Buffer) and isinstance(value, Buffer): ir.assert_structural_equal(dst.shape, value.shape) assert src_extent or dst_extent, "Can't deduce atomicmin extents from args" + + # If src is buffer-like, first convert it to a region + if src_extent: + value = to_buffer_region(value, access_type="r", extents=src_extent) + src_extent = list(src_extent) if src_extent else [1] * len(dst_extent) dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent) src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent) - value = to_buffer_region(value, access_type="r", extents=src_extent) dst = to_buffer_region(dst, access_type="w", extents=dst_extent) if return_prev: @@ -236,29 +227,10 @@ def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None, re >>> atomic_add(global_grad, gradients) """ - def get_extent(data): - """ - Return the inferred extent (shape) of a buffer-like object. - - If `data` is a Var bound to a let value, the let value is resolved before inspection. - Parameters: - data: A Var, Buffer, or BufferRegion to inspect. - - Returns: - The shape/extents as a list-like of PrimExpr (Buffer.shape or list of region item extents), or None if the extent cannot be determined. - """ - if isinstance(data, Var) and T.has_let_value(data): - data = T.get_let_value(data) - if isinstance(data, Buffer): - return data.shape - elif isinstance(data, BufferRegion): - return [x.extent for x in data.region] - else: - return None - src_extent = get_extent(value) dst_extent = get_extent(dst) + # Thread-level atomic add, where neither extent can be inferred if dst_extent is None and src_extent is None: atomic_add_op = op.Op.get("tl.atomic_add_ret_elem_op") if return_prev else op.Op.get("tl.atomic_add_elem_op") return_type = dst.dtype if return_prev else "handle" @@ -275,15 +247,20 @@ def get_extent(data): _MEMORY_ORDER_ID_MAP[memory_order], ) + # When both arguments are Buffers, we can check that their shapes are structurally equal.
if isinstance(dst, Buffer) and isinstance(value, Buffer): ir.assert_structural_equal(dst.shape, value.shape) assert src_extent or dst_extent, "Can't deduce atomicadd extents from args" + + # If src is buffer-like, first convert it to a region + if src_extent: + value = to_buffer_region(value, access_type="r", extents=src_extent) + src_extent = list(src_extent) if src_extent else [1] * len(dst_extent) dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent) src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent) - value = to_buffer_region(value, access_type="r", extents=src_extent) dst = to_buffer_region(dst, access_type="w", extents=dst_extent) # Note: tile-region-based atomic operations don't support return_prev yet diff --git a/tilelang/language/copy_op.py b/tilelang/language/copy_op.py index ab158f14d..fc69b3bbc 100644 --- a/tilelang/language/copy_op.py +++ b/tilelang/language/copy_op.py @@ -3,12 +3,11 @@ from __future__ import annotations from typing import Literal, Any from tilelang._typing import BufferLikeType -from tilelang import language as T from tilelang.utils.language import ( to_buffer_region, - get_buffer_region_from_load, legalize_pairwise_extents, ) +from tilelang.language.utils import get_extent from tvm import ir, tir @@ -62,21 +61,6 @@ def copy( if isinstance(src, tir.Buffer) and isinstance(dst, tir.Buffer): ir.assert_structural_equal(src.shape, dst.shape) - def get_extent(data): - if isinstance(data, tir.Var) and T.has_let_value(data): - data = T.get_let_value(data) - if isinstance(data, tir.Buffer): - return data.shape - elif isinstance(data, tir.BufferRegion): - return [x.extent for x in data.region] - elif isinstance(data, tir.BufferLoad): - region = get_buffer_region_from_load(data) - if region is None: - return None - return [x.extent for x in region.region] - else: - return None - src_extent = get_extent(src) dst_extent = get_extent(dst) # Combine the nested if statements into a single if statement as suggested by SIM102 diff --git a/tilelang/language/utils.py b/tilelang/language/utils.py index 139bbf8a1..2e1df2d23 100644 --- a/tilelang/language/utils.py +++ b/tilelang/language/utils.py @@ -1,7 +1,12 @@ +"""Utils in TileLang operators.""" + +from __future__ import annotations + from tilelang import tvm as tvm -from tvm import tir +from tvm import ir, tir from tvm.tir import PrimExpr, BufferLoad, op from tilelang import language as T +from tilelang._typing import BufferLikeType, ShapeType def region(buffer: BufferLoad, access_type: str, *args: PrimExpr) -> PrimExpr: @@ -90,3 +95,67 @@ def linear_index(*args: PrimExpr) -> PrimExpr: for idx, stride in zip(coords[1:], strides): linear = linear * stride + idx return linear + + +def get_buffer_region_from_load(buffer_load: tir.BufferLoad, extents: list[PrimExpr] | None = None) -> tir.BufferRegion | None: + """ + Get the buffer region from a buffer load. + + A buffer load like C[0:128, 0:32] may be encountered; see the pull request + for buffer-wise ops: https://github.com/apache/tvm/pull/14693. + Convert such a load to a region. + + If the buffer load has ramp indices, we will use the ramp's base and lanes to create the region. + Otherwise, return None since the load cannot be converted to a region.
+ """ + buffer, indices = buffer_load.buffer, buffer_load.indices + regions = [] + found_ramp: bool = False + + if extents is not None: + assert len(extents) == len(indices), "extents should have the same length as indices" + for i, indice in enumerate(indices): + if isinstance(indice, tir.Ramp): + assert extents is None, "extents must not be provided for BufferLoad with Ramp indices" + regions.append(ir.Range.from_min_extent(indice.base, indice.lanes)) + found_ramp = True + elif isinstance(indice, tir.PrimExpr): + if extents is not None: + regions.append(ir.Range.from_min_extent(indice, extents[i])) + found_ramp = True + else: + regions.append(ir.Range.from_min_extent(indice, 1)) + else: + raise ValueError(f"Unsupported type: {type(indice)} for index {i}") + if found_ramp: + return tir.BufferRegion(buffer, regions) + else: + # NOTE(chaofan): Or we can return a region with extent 1? + return None + + +def get_extent(data: BufferLikeType) -> ShapeType | None: + """Return the inferred extent (shape) of a buffer-like object. + + If `data` is a Var bound to a let value, the let value is resolved before inspection. + + Parameters: + data: A Var, Buffer, BufferLoad or BufferRegion to inspect. + + Returns: + The shape/extents as a list-like of PrimExpr (Buffer.shape or list of region item extents), or None if the extent cannot be determined. + """ + + if isinstance(data, tir.Var) and T.has_let_value(data): + data = T.get_let_value(data) + if isinstance(data, tir.Buffer): + return data.shape + elif isinstance(data, tir.BufferRegion): + return [x.extent for x in data.region] + elif isinstance(data, tir.BufferLoad): + region = get_buffer_region_from_load(data) + if region is None: + return None + return [x.extent for x in region.region] + else: + return None diff --git a/tilelang/utils/language.py b/tilelang/utils/language.py index bc90679a1..c0bcb6209 100644 --- a/tilelang/utils/language.py +++ b/tilelang/utils/language.py @@ -2,6 +2,7 @@ from tilelang._typing import BufferLikeType from tvm.tir import Buffer, BufferLoad, BufferRegion, PrimExpr from tilelang.language.utils import region as _make_region_call +from tilelang.language.utils import get_buffer_region_from_load from functools import reduce from tvm import IRModule, DataType from tvm.tir import PrimFunc @@ -173,39 +174,6 @@ def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc: return func -def get_buffer_region_from_load(buffer_load: tir.BufferLoad, extents: list[PrimExpr] | None = None) -> tir.BufferRegion | None: - """ - Get the buffer region from a buffer load.
- - May encounter buffer load like C[0:128, 0:32], ref to pull request - for buffer wise op: https://github.com/apache/tvm/pull/14693 - convert load to region - """ - buffer, indices = buffer_load.buffer, buffer_load.indices - regions = [] - found_ramp: bool = False - - if extents is not None: - assert len(extents) == len(indices), "extents should have the same length as indices" - for i, indice in enumerate(indices): - if isinstance(indice, tir.Ramp): - assert extents is None, "extents should be provided for BufferLoad with Ramp indices" - regions.append(ir.Range.from_min_extent(indice.base, indice.lanes)) - found_ramp = True - elif isinstance(indice, tir.PrimExpr): - if extents is not None: - regions.append(ir.Range.from_min_extent(indice, extents[i])) - found_ramp = True - else: - regions.append(ir.Range.from_min_extent(indice, 1)) - else: - raise ValueError(f"Unsupported type: {type(indice)} for index {i}") - if found_ramp: - return tir.BufferRegion(buffer, regions) - else: - return None - - def to_buffer_region(obj: BufferLikeType, access_type: str = "rw", extents: list[PrimExpr] | None = None) -> PrimExpr | BufferRegion: """ Convert to/from the tl.region representation.