diff --git a/examples/gemm_sm100/gemm_tcgen5mma_ws.py b/examples/gemm_sm100/gemm_tcgen5mma_ws.py
new file mode 100644
index 000000000..fd4f7ac32
--- /dev/null
+++ b/examples/gemm_sm100/gemm_tcgen5mma_ws.py
@@ -0,0 +1,91 @@
+# Non-persistent, 1-SM GEMM
+
+import torch
+import tilelang
+import tilelang.language as T
+from tilelang.profiler import do_bench
+
+
+@tilelang.jit
+def gemm(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages, use_tma_store=True):
+    M, N, K = T.const("M, N, K")
+
+    k_iters = T.ceildiv(K, block_K)
+
+    A: T.Tensor[[M, K], in_dtype]
+    B: T.Tensor[[K, N], in_dtype]
+    C = T.empty((M, N), out_dtype)
+
+    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+        A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype)
+        B_shared = T.alloc_shared((num_stages, block_K, block_N), in_dtype)
+        C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
+        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+        C_shared = T.alloc_shared((block_M, block_N), out_dtype)
+        C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype)
+        loaded = T.alloc_barrier([32] * num_stages)
+        consumed = T.alloc_barrier([1] * num_stages)
+        tmem_full = T.alloc_barrier([1])
+
+        tx = T.get_thread_binding()
+
+        T.use_swizzle(8)
+
+        if tx < 32:  # warp 0: issue tma
+            for k in T.serial(k_iters):
+                T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1)
+                T.copy(A[by * block_M : (by + 1) * block_M, k * block_K : (k + 1) * block_K], A_shared[k % num_stages, :, :])
+                T.copy(B[k * block_K : (k + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[k % num_stages, :, :])
+                T.mbarrier_arrive(loaded[k % num_stages])
+        elif tx < 64:  # warp 1: issue tcgen5
+            for k in T.serial(k_iters):
+                T.mbarrier_wait_parity(loaded[k % num_stages], (k // num_stages) & 1)
+                T.gemm(
+                    A_shared[k % num_stages, :, :],
+                    B_shared[k % num_stages, :, :],
+                    C_tmem,
+                    mbar=consumed[k % num_stages],
+                    wg_wait=-1,
+                    clear_accum=k == 0,
+                )
+            T.tcgen05_mma_arrive(tmem_full)
+
+        # Wait for all tcgen5 to finish
+        T.mbarrier_wait_parity(tmem_full, 0)
+
+        T.sync_threads()  # TileLang won't generate this if not annotated
+        T.copy(C_tmem, C_local)
+        if use_tma_store:
+            T.copy(C_local, C_shared)
+            T.copy(C_shared, C[by * block_M, bx * block_N])
+        else:
+            T.copy(C_local, C_local_cast)
+            T.copy(C_local_cast, C[by * block_M, bx * block_N])
+    return C
+
+
+def main():
+    M, N, K = 8192, 8192, 8192
+    block_M, block_N, block_K = 128, 256, 64
+    in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float
+    num_stages = 4
+
+    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
+    c = gemm(a, b, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages)
+    print(gemm.get_kernel_source(a, b, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages))
+
+    ref_c = (a.to(torch.float) @ b.to(torch.float)).to(torch.bfloat16)
+    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
+    print("All checks passed. ✅")
+
+    tl_latency = do_bench(lambda: gemm(a, b, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages), backend="cupti")
+    torch_latency = do_bench(lambda: a @ b, backend="cupti")
+    print(f"Tilelang latency: {tl_latency} ms")
+    print(f"Flops: {2 * M * N * K / (tl_latency / 1e3) / 1e12} TFLOPS")
+    print(f"Torch latency: {torch_latency} ms")
+    print(f"Flops: {2 * M * N * K / (torch_latency / 1e3) / 1e12} TFLOPS")
+
+
+if __name__ == "__main__":
+    main()
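The warp-specialized pipeline in this example hinges on mbarrier phase parity: stage `k % num_stages` of `A_shared`/`B_shared` is reused every `num_stages` iterations, and the parity passed to `T.mbarrier_wait_parity` flips on every wrap so the TMA warp and the MMA warp never race on a slot. Below is a rough host-side model of that bookkeeping; the `Mbarrier` class is a made-up stand-in, not a TileLang API, and it assumes the simplification that a parity-wait passes once the barrier phase no longer equals the supplied parity.

# Toy model of the loaded/consumed parity scheme (plain Python, no GPU needed).
class Mbarrier:  # hypothetical helper, not part of TileLang
    def __init__(self):
        self.phase = 0  # barriers start in phase 0

    def arrive(self):
        self.phase ^= 1  # completing an arrival flips the phase

    def wait_passes(self, parity: int) -> bool:
        # Simplifying assumption: the wait completes once the phase differs
        # from the parity value the waiter handed in.
        return self.phase != parity


num_stages, k_iters = 4, 16
loaded = [Mbarrier() for _ in range(num_stages)]
for k in range(k_iters):
    slot = k % num_stages
    parity = (k // num_stages) & 1                 # what the MMA warp waits on
    assert not loaded[slot].wait_passes(parity)    # consumer blocks until...
    loaded[slot].arrive()                          # ...the TMA warp arrives
    assert loaded[slot].wait_passes(parity)
print("slot/parity bookkeeping consistent for", k_iters, "iterations")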
diff --git a/examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py b/examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py
new file mode 100644
index 000000000..e95784e6b
--- /dev/null
+++ b/examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py
@@ -0,0 +1,154 @@
+# Persistent, 1-SM, num_epi_stages = 2
+
+import torch
+import tilelang
+import tilelang.language as T
+from tilelang.carver.arch import driver
+from tilelang.profiler import do_bench
+
+
+@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_LOOP_UNSWITCHING: True})
+def gemm(
+    A,
+    B,
+    block_M,
+    block_N,
+    store_block_N,  # block_N for C_shared
+    block_K,
+    in_dtype,
+    out_dtype,
+    accum_dtype,
+    num_stages,
+    use_tma_store=True,
+):
+    M, N, K = T.const("M, N, K")
+
+    A: T.Tensor[[M, K], in_dtype]
+    B: T.Tensor[[K, N], in_dtype]
+    C = T.empty((M, N), out_dtype)
+
+    sm_num = driver.get_num_sms()
+    m_blocks = T.ceildiv(M, block_M)
+    n_blocks = T.ceildiv(N, block_N)
+    assert K % (2 * block_K) == 0  # for simplicity
+    k_blocks = T.ceildiv(K, block_K)
+    waves = T.ceildiv(m_blocks * n_blocks, sm_num)
+    group_size = 8
+
+    with T.Kernel(sm_num, threads=256) as (block_id):
+        A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype)
+        B_shared = T.alloc_shared((num_stages, block_K, block_N), in_dtype)
+        C_tmem_0 = T.alloc_tmem([block_M, block_N], accum_dtype)
+        C_tmem_1 = T.alloc_tmem([block_M, block_N], accum_dtype)
+        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+        C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype)
+        C_shared = T.alloc_shared((block_M, store_block_N), out_dtype)
+        loaded = T.alloc_barrier([32] * num_stages)
+        consumed = T.alloc_barrier([1] * num_stages)
+        tmem_full = T.alloc_barrier([1] * 2)
+        tmem_empty = T.alloc_barrier([128] * 2)
+
+        tx = T.get_thread_binding()
+
+        if tx < 32:  # warp 0: issue tma
+            for w in T.unroll(waves):
+                tile_id = sm_num * w + block_id
+                bx = (tile_id // group_size) % m_blocks
+                by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size
+
+                if bx * block_M < M and by * block_N < N:
+                    for k in T.serial(k_blocks):
+                        T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1)
+                        T.copy(
+                            A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], A_shared[k % num_stages, :, :]
+                        )  # cannot use BufferLoad here
+                        T.copy(B[k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared[k % num_stages, :, :])
+                        T.mbarrier_arrive(loaded[k % num_stages])
+
+        elif tx < 64:  # warp 1: issue tcgen5
+            for w in T.unroll(waves):
+                tile_id = sm_num * w + block_id
+                bx = (tile_id // group_size) % m_blocks
+                by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size
+
+                if bx * block_M < M and by * block_N < N:
+                    T.mbarrier_wait_parity(tmem_empty[w & 1], ((w // 2) & 1) ^ 1)
+                    for k in T.serial(k_blocks):
+                        T.mbarrier_wait_parity(loaded[k % num_stages], (k // num_stages) & 1)
+                        if w & 1 == 0:
+                            T.gemm(
+                                A_shared[k % num_stages, :, :],
+                                B_shared[k % num_stages, :, :],
+                                C_tmem_0,
+                                False,
+                                False,
+                                mbar=consumed[k % num_stages],
+                                wg_wait=-1,
+                                clear_accum=k == 0,
+                            )
+                        else:
+                            T.gemm(
+                                A_shared[k % num_stages, :, :],
+                                B_shared[k % num_stages, :, :],
+                                C_tmem_1,
+                                False,
+                                False,
+                                mbar=consumed[k % num_stages],
+                                wg_wait=-1,
+                                clear_accum=k == 0,
+                            )
+                    T.tcgen05_mma_arrive(tmem_full[w & 1])
+
+        elif 128 <= tx < 256:  # warp 4~7: epilogue
+            for w in T.unroll(waves):
+                tile_id = sm_num * w + block_id
+                bx = (tile_id // group_size) % m_blocks
+                by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size
+
+                if bx * block_M < M and by * block_N < N:
+                    T.mbarrier_wait_parity(tmem_full[w & 1], (w // 2) & 1)
+                    T.sync_threads(1, 128)
+                    if (w & 1) == 0:
+                        T.copy(C_tmem_0, C_local)
+                    else:
+                        T.copy(C_tmem_1, C_local)
+                    T.mbarrier_arrive(tmem_empty[w & 1])
+
+                    if use_tma_store:
+                        for i in T.unroll(T.ceildiv(block_N, store_block_N)):
+                            T.copy(C_local[:, i * store_block_N : (i + 1) * store_block_N], C_shared)
+                            T.copy(C_shared, C[bx * block_M, by * block_N + i * store_block_N])
+                    else:
+                        T.copy(C_local, C_local_cast)
+                        T.copy(C_local_cast, C[bx * block_M, by * block_N])
+    return C
+
+
+def main():
+    M, N, K = 8192, 8192, 8192
+    block_M, block_N, block_K = 128, 256, 64
+    store_block_N = 128
+    in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float
+    num_stages = 4
+
+    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
+    print(gemm.get_kernel_source(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages))
+    c = gemm(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages)
+
+    ref_c = (a.to(torch.float) @ b.to(torch.float)).to(torch.bfloat16)
+    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
+    print("All checks passed. ✅")
+
+    tl_latency = do_bench(
+        lambda: gemm(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages), backend="cupti"
+    )
+    torch_latency = do_bench(lambda: a @ b, backend="cupti")
+    print(f"Tilelang latency: {tl_latency} ms")
+    print(f"Flops: {2 * M * N * K / (tl_latency / 1e3) / 1e12} TFLOPS")
+    print(f"Torch latency: {torch_latency} ms")
+    print(f"Flops: {2 * M * N * K / (torch_latency / 1e3) / 1e12} TFLOPS")
+
+
+if __name__ == "__main__":
+    main()
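In the persistent kernel each SM loops over `waves` tiles, and the linear `tile_id` is swizzled into `(bx, by)` with `group_size = 8` so that eight consecutive tile ids share the same `bx` (the same block-row of A), which tends to improve L2 reuse. The arithmetic is easy to sanity-check on the host. The snippet below replays the same formula in plain Python and verifies that every output tile of the 8192x8192 problem is scheduled exactly by the guarded loop; the SM count of 148 is only an assumption for illustration, the kernel itself queries `driver.get_num_sms()`.

# Host-side replay of the persistent tile rasterization used above.
def ceildiv(a, b):
    return (a + b - 1) // b


M = N = 8192
block_M, block_N = 128, 256
sm_num = 148          # assumed SM count, for illustration only
group_size = 8

m_blocks = ceildiv(M, block_M)
n_blocks = ceildiv(N, block_N)
waves = ceildiv(m_blocks * n_blocks, sm_num)

seen = set()
for w in range(waves):
    for block_id in range(sm_num):
        tile_id = sm_num * w + block_id
        bx = (tile_id // group_size) % m_blocks
        by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size
        if bx * block_M < M and by * block_N < N:   # same guard as the kernel
            seen.add((bx, by))

assert len(seen) == m_blocks * n_blocks, "some output tiles were never scheduled"
print(f"{len(seen)} tiles covered in {waves} waves on {sm_num} SMs")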
diff --git a/src/op/gemm.cc b/src/op/gemm.cc
index 24836f5ba..18c8f9ad2 100644
--- a/src/op/gemm.cc
+++ b/src/op/gemm.cc
@@ -82,11 +82,7 @@ Gemm::Gemm(Array args, Map annotations) {
   }
   if (args.size() > 16) {
     if (const auto *load = args[16].as<BufferLoadNode>()) {
-      node->mbarRegion_ =
-          NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
-      node->mbar_ = node->mbarRegion_->buffer;
-    } else {
-      node->mbar_ = std::nullopt;
+      node->mbar_ = Downcast<BufferLoad>(args[16]);
     }
   }
   node->cCoords_ = Array<PrimExpr>(
@@ -461,7 +457,7 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     ICHECK(can_use_tcgen5mma);
     ICHECK(b_.scope() == "shared.dyn" || b_.scope() == "shared");
     ICHECK(c_.scope() == "shared.tmem");
-    ICHECK(mbar_.has_value()) << "mbar must be provided for TCGEN5MMA";
+    ICHECK(mbar_.defined()) << "mbar must be provided for TCGEN5MMA";
     if (a_.scope() == "shared.tmem") {
       op_name = "tl::tcgen5mma_gemm_ts";
     } else if (a_.scope() == "shared.dyn" || a_.scope() == "shared") {
@@ -492,8 +488,7 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   auto C_buffer = T.buffer_remap.count(c_) ? T.buffer_remap[c_] : c_;
 
   Array<PrimExpr> new_args;
-  auto mbarPtr =
-      MakeAccessPtrFromRegion(mbarRegion_, /*rw*/ 3, /*require_2d*/ true);
+  auto mbarPtr = MakeAccessPtrFromBufferLoad(mbar_, /*rw*/ 3);
   new_args.push_back(StringImm(ss.str()));
   new_args.push_back(Aptr);
   new_args.push_back(Bptr);
diff --git a/src/op/gemm.h b/src/op/gemm.h
index fd2733882..cbd07e86b 100644
--- a/src/op/gemm.h
+++ b/src/op/gemm.h
@@ -130,8 +130,7 @@
   // only will be enabled under cdna mfma instructions
   int kPack_ = 1;
   int wgWait_ = 0;
-  BufferRegion mbarRegion_;
-  std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA
+  tir::BufferLoad mbar_; // mbar is optional, only used for TCGEN5MMA
   Array<PrimExpr> cCoords_;
   mutable GemmWarpPolicy policy_;
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Gemm", GemmNode, TileOperatorNode);
@@ -157,6 +156,8 @@
         .def_ro("clearAccum", &GemmNode::clearAccum_)
         .def_ro("kPack", &GemmNode::kPack_)
         .def_ro("wgWait", &GemmNode::wgWait_)
+        .def_ro("mbar", &GemmNode::mbar_)
+        .def_ro("cCoords", &GemmNode::cCoords_)
         .def_ro("policy", &GemmNode::policy_);
   }

diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc
index 3d017538b..cede81bd6 100644
--- a/src/op/gemm_py.cc
+++ b/src/op/gemm_py.cc
@@ -82,9 +82,7 @@ GemmPy::GemmPy(Array args, Map annotations) {
   }
   if (args.size() > 16) {
     if (const auto *load = args[16].as<BufferLoadNode>()) {
-      node->mbarRegion_ =
-          NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
-      node->mbar_ = node->mbarRegion_->buffer;
+      node->mbar_ = Downcast<BufferLoad>(args[16]);
     }
   }
   node->cCoords_ = Array<PrimExpr>(
diff --git a/src/op/gemm_py.h b/src/op/gemm_py.h
index d6468a0bf..0ad555ea8 100644
--- a/src/op/gemm_py.h
+++ b/src/op/gemm_py.h
@@ -29,8 +29,7 @@ class GemmPyNode : public TileOperatorNode {
   int strideA_, strideB_;
   int offsetA_, offsetB_;
   PrimExpr clearAccum_ = const_false();
-  BufferRegion mbarRegion_;
-  tir::Buffer mbar_; // mbar is optional, only used for TCGEN5MMA
+  tir::BufferLoad mbar_; // mbar is optional, only used for TCGEN5MMA
   Array<PrimExpr> cCoords_;
   // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
   // only will be enabled under cdna mfma instructions
@@ -59,7 +58,6 @@
         .def_ro("offsetA", &GemmPyNode::offsetA_)
         .def_ro("offsetB", &GemmPyNode::offsetB_)
         .def_ro("clearAccum", &GemmPyNode::clearAccum_)
-        .def_ro("mbarRegion", &GemmPyNode::mbarRegion_)
         .def_ro("mbar", &GemmPyNode::mbar_)
         .def_ro("cCoords", &GemmPyNode::cCoords_)
         .def_ro("kPack", &GemmPyNode::kPack_)
diff --git a/src/op/utils.cc b/src/op/utils.cc
index 7f8c3c7c6..309d34662 100644
--- a/src/op/utils.cc
+++ b/src/op/utils.cc
@@ -103,6 +103,36 @@ PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region, int rw_mask,
   return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
 }

+PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load, int rw_mask) {
+  Buffer buf = load->buffer;
+  int ndim = static_cast<int>(buf->shape.size());
+
+  // Compute offset using row-major layout (iterate in reverse)
+  PrimExpr offset = 0;
+  PrimExpr stride = 1;
+
+  for (int i = ndim - 1; i >= 0; --i) {
+    const PrimExpr &index = load->indices[i];
+    if (const auto *ramp = index.as<RampNode>()) {
+      // For Ramp, use the base
+      offset = offset + ramp->base * stride;
+    } else {
+      // For scalar index (IntImm or other PrimExpr)
+      offset = offset + index * stride;
+    }
+    stride = stride * buf->shape[i];
+  }
+
+  // Extent is 1 element for a single BufferLoad access
+  PrimExpr extent = make_const(DataType::Int(32), 1);
+
+  // Build access_ptr
+  PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
+  Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
+                           IntImm(DataType::Int(32), rw_mask)};
+  return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
+}
+
 // Maps TVM DataType to CUDA's CUtensorMapDataType enum value.
 int to_CUtensorMapDataType(DataType dtype) {
   CUtensorMapDataType tp;
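`MakeAccessPtrFromBufferLoad` turns the mbarrier `BufferLoad` (for example `consumed[k % num_stages]`) into a `tvm_access_ptr` by flattening the load indices against the buffer shape in row-major order, with an extent of one element. The flattening itself is ordinary stride arithmetic; the small Python sketch below reproduces the equivalent computation, independent of TVM and purely for illustration.

# Row-major flattening equivalent to the offset loop in
# MakeAccessPtrFromBufferLoad: walk dimensions from innermost to outermost,
# accumulating index * stride while the stride grows by each extent.
def flatten_offset(indices, shape):
    assert len(indices) == len(shape)
    offset, stride = 0, 1
    for i in range(len(shape) - 1, -1, -1):
        offset += indices[i] * stride
        stride *= shape[i]
    return offset


# A 1-D barrier array such as consumed[k % num_stages]:
assert flatten_offset([3], [4]) == 3
# A 3-D staged shared buffer indexed as A_shared[stage, row, col]:
assert flatten_offset([2, 5, 7], [4, 128, 64]) == (2 * 128 + 5) * 64 + 7
print("row-major offsets match")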
diff --git a/src/op/utils.h b/src/op/utils.h
index 9fdb3b4af..f627d702f 100644
--- a/src/op/utils.h
+++ b/src/op/utils.h
@@ -43,6 +43,10 @@ TVM_DLL BufferRegion NormalizeToBufferRegion(const PrimExpr &arg);
 TVM_DLL PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
                                          int rw_mask, bool require_2d = false);

+// Build a tvm_access_ptr(handle) from a BufferLoad.
+TVM_DLL PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load,
+                                             int rw_mask);
+
 // Check if a buffer is a fragment buffer (scope == "local.fragment")
 inline bool IsFragmentBuffer(const Buffer &buffer) {
   return buffer.defined() && buffer.scope() == "local.fragment";
diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py
index fef5dc983..9df3a5498 100644
--- a/tilelang/language/builtin.py
+++ b/tilelang/language/builtin.py
@@ -805,15 +805,17 @@ def cp_async_barrier_noinc(barrier: BarrierType):
     return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier)


-def tcgen05_mma_arrive(mbar_ptr):
+def tcgen05_mma_arrive(mbar: tir.Buffer | BufferLoad | PrimExpr):
     """Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer.

     Parameters
     ----------
-    mbar_ptr : PrimExpr
-        Pointer to the mbarrier object in shared memory (e.g., Barrier*).
+    mbar: tir.Buffer | BufferLoad | PrimExpr
+        The mbarrier object in shared memory (e.g., Barrier*) or its address.
     """
-    return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr)
+    if isinstance(mbar, (tir.Buffer, BufferLoad)):
+        mbar = retrieve_ptr(mbar, access_type="rw")
+    return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar)


 def ptx_mma_sm70(
diff --git a/tilelang/language/gemm_op.py b/tilelang/language/gemm_op.py
index dc02639c6..fe7be5b37 100644
--- a/tilelang/language/gemm_op.py
+++ b/tilelang/language/gemm_op.py
@@ -101,13 +101,15 @@ def legalize_arguments(arg: BufferLikeType | tir.Var) -> BufferLikeType:
                 f"mbar for tcgen5mma must be a tir.Buffer or tir.BufferLoad, but got {type(mbar)}"
             )
         mbar = to_buffer_region(mbar, access_type="rw")
-    else:
-        mbar = tir.const(0, T.uint32)
     C_coords = [r.min for r in C_region.region]
     # Convert BufferRegion to tl.region calls for arguments
     A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape])
     B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape])
     C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape])
+    # When mbar is None, pass a placeholder constant (0).
+    # The C++ side checks if arg 16 is a BufferLoadNode before using it,
+    # so a non-BufferLoad value will be correctly ignored.
+    mbar_arg = mbar if mbar is not None else tir.const(0, dtype="int32")
     return tir.call_intrin(
         "handle",
         tir.op.Op.get(op_key),
@@ -127,7 +129,7 @@
             offset_b,
             k_pack,
             wg_wait,
-            mbar,
+            mbar_arg,
             C_coords[0],
             C_coords[1],
         )
diff --git a/tilelang/tileop/gemm/gemm_base.py b/tilelang/tileop/gemm/gemm_base.py
index 841183738..0bea9e559 100644
--- a/tilelang/tileop/gemm/gemm_base.py
+++ b/tilelang/tileop/gemm/gemm_base.py
@@ -128,7 +128,7 @@ def mbarptr(self) -> PrimExpr:
         return getattr(self.gemm_node, "mbarPtr", tvm.tir.const(0, T.uint32))

     @property
-    def mbar(self) -> tir.Buffer | tir.BufferLoad:
+    def mbar(self) -> tir.BufferLoad | None:
         return getattr(self.gemm_node, "mbar", None)

     @property
diff --git a/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/tileop/gemm/gemm_tcgen05.py
index 32bef1cbd..3bc05e4a5 100644
--- a/tilelang/tileop/gemm/gemm_tcgen05.py
+++ b/tilelang/tileop/gemm/gemm_tcgen05.py
@@ -5,6 +5,7 @@
     TensorCoreIntrinEmitter,
 )
 from tilelang import language as T
+from tilelang.utils.language import retrieve_ptr
 from tilelang.transform.simplify import _Simplify
 from tvm import tir
 from tvm.target import Target
@@ -93,10 +94,10 @@ def lower(self, layout_map: dict, target: Target, thread_bounds: Range, thread_v
             raise ValueError("TCGEN5MMA currently requires wg_wait == -1")

         mbar = self.mbar
-        if mbar == 0:
+        if mbar is None:
             raise ValueError("TCGEN5MMA requires a valid mbarrier")

-        mbarptr = mbar.access_ptr("rw")
+        mbarptr = retrieve_ptr(mbar, "rw")

         C_coords = self.C_coords
         if len(C_coords) != 2: