-
Notifications
You must be signed in to change notification settings - Fork 438
[Example][BugFix] 1SM GEMM example on Blackwell and fix handling of mbar
#1774
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
aa6a6fb
fb51314
4d3bcba
9be2578
f83d910
2c6ab8e
cac3e3f
4f78480
27755e8
50ca473
b14a738
4927749
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| # Non-persistent, 1-SM GEMM | ||
|
|
||
| import torch | ||
| import tilelang | ||
| import tilelang.language as T | ||
| from tilelang.profiler import do_bench | ||
|
|
||
|
|
||
| @tilelang.jit | ||
| def gemm(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages, use_tma_store=True): | ||
| M, N, K = T.const("M, N, K") | ||
|
|
||
| k_iters = T.ceildiv(K, block_K) | ||
|
|
||
| A: T.Tensor[[M, K], in_dtype] | ||
| B: T.Tensor[[K, N], in_dtype] | ||
| C = T.empty((M, N), out_dtype) | ||
|
|
||
| with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): | ||
| A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) | ||
| B_shared = T.alloc_shared((num_stages, block_K, block_N), in_dtype) | ||
| C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) | ||
| C_local = T.alloc_fragment((block_M, block_N), accum_dtype) | ||
| C_shared = T.alloc_shared((block_M, block_N), out_dtype) | ||
| C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype) | ||
| loaded = T.alloc_barrier([32] * num_stages) | ||
| consumed = T.alloc_barrier([1] * num_stages) | ||
| tmem_full = T.alloc_barrier([1]) | ||
|
|
||
| tx = T.get_thread_binding() | ||
|
|
||
| T.use_swizzle(8) | ||
|
|
||
| if tx < 32: # warp 0: issue tma | ||
| for k in T.serial(k_iters): | ||
| T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1) | ||
| T.copy(A[by * block_M : (by + 1) * block_M, k * block_K : (k + 1) * block_K], A_shared[k % num_stages, :, :]) | ||
| T.copy(B[k * block_K : (k + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[k % num_stages, :, :]) | ||
| T.mbarrier_arrive(loaded[k % num_stages]) | ||
| elif tx < 64: # warp 1: issue tcgen5 | ||
| for k in T.serial(k_iters): | ||
| T.mbarrier_wait_parity(loaded[k % num_stages], (k // num_stages) & 1) | ||
| T.gemm( | ||
| A_shared[k % num_stages, :, :], | ||
| B_shared[k % num_stages, :, :], | ||
| C_tmem, | ||
| mbar=consumed[k % num_stages], | ||
| wg_wait=-1, | ||
| clear_accum=k == 0, | ||
| ) | ||
| T.tcgen05_mma_arrive(tmem_full) | ||
|
|
||
| # Wait for all tcgen5 to finish | ||
| T.mbarrier_wait_parity(tmem_full, 0) | ||
|
|
||
| T.sync_threads() # TileLang won't generate this if not annotated | ||
| T.copy(C_tmem, C_local) | ||
| if use_tma_store: | ||
| T.copy(C_local, C_shared) | ||
| T.copy(C_shared, C[by * block_M, bx * block_N]) | ||
| else: | ||
| T.copy(C_local, C_local_cast) | ||
| T.copy(C_local_cast, C[by * block_M, bx * block_N]) | ||
| return C | ||
|
|
||
|
|
||
| def main(): | ||
| M, N, K = 8192, 8192, 8192 | ||
| block_M, block_N, block_K = 128, 256, 64 | ||
| in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float | ||
| num_stages = 4 | ||
|
|
||
| a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) | ||
| b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) | ||
| c = gemm(a, b, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages) | ||
| print(gemm.get_kernel_source(a, b, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages)) | ||
|
|
||
| ref_c = (a.to(torch.float) @ b.to(torch.float)).to(torch.bfloat16) | ||
| torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) | ||
| print("All checks passed. ✅") | ||
|
|
||
| tl_latency = do_bench(lambda: gemm(a, b, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages), backend="cupti") | ||
| torch_latency = do_bench(lambda: a @ b, backend="cupti") | ||
| print(f"Tilelang latency: {tl_latency} ms") | ||
| print(f"Flops: {2 * M * N * K / (tl_latency / 1e3) / 1e12} TFLOPS") | ||
| print(f"Torch latency: {torch_latency} ms") | ||
| print(f"Flops: {2 * M * N * K / (torch_latency / 1e3) / 1e12} TFLOPS") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,154 @@ | ||
| # Persistent, 1-SM, num_epi_stages = 2 | ||
|
|
||
| import torch | ||
| import tilelang | ||
| import tilelang.language as T | ||
| from tilelang.carver.arch import driver | ||
| from tilelang.profiler import do_bench | ||
|
|
||
|
|
||
| @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_LOOP_UNSWITCHING: True}) | ||
| def gemm( | ||
| A, | ||
| B, | ||
| block_M, | ||
| block_N, | ||
| store_block_N, # block_N for C_shared | ||
| block_K, | ||
| in_dtype, | ||
| out_dtype, | ||
| accum_dtype, | ||
| num_stages, | ||
| use_tma_store=True, | ||
| ): | ||
| M, N, K = T.const("M, N, K") | ||
|
|
||
| A: T.Tensor[[M, K], in_dtype] | ||
| B: T.Tensor[[K, N], in_dtype] | ||
| C = T.empty((M, N), out_dtype) | ||
|
|
||
| sm_num = driver.get_num_sms() | ||
| m_blocks = T.ceildiv(M, block_M) | ||
| n_blocks = T.ceildiv(N, block_N) | ||
| assert K % (2 * block_K) == 0 # for simplicity | ||
|
Comment on lines
+31
to
+33
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: cd /tmp/repo && find . -name "gemm_tcgen5mma_ws_persistent.py" -type fRepository: tile-ai/tilelang Length of output: 119 🏁 Script executed: head -150 examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py | cat -nRepository: tile-ai/tilelang Length of output: 7610 Add tiling precondition checks to prevent OOB loads/stores. The kernel loads/stores full tiles without boundary masking. If 💡 Suggested guardrails- assert K % (2 * block_K) == 0 # for simplicity
+ assert K % (2 * block_K) == 0, "K must be divisible by 2 * block_K"
+ assert M % block_M == 0 and N % block_N == 0, "M/N must be multiples of block_M/block_N"
+ assert store_block_N <= block_N and block_N % store_block_N == 0, "store_block_N must evenly divide block_N"🤖 Prompt for AI Agents |
||
| k_blocks = T.ceildiv(K, block_K) | ||
| waves = T.ceildiv(m_blocks * n_blocks, sm_num) | ||
| group_size = 8 | ||
|
|
||
| with T.Kernel(sm_num, threads=256) as (block_id): | ||
| A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) | ||
| B_shared = T.alloc_shared((num_stages, block_K, block_N), in_dtype) | ||
| C_tmem_0 = T.alloc_tmem([block_M, block_N], accum_dtype) | ||
| C_tmem_1 = T.alloc_tmem([block_M, block_N], accum_dtype) | ||
| C_local = T.alloc_fragment((block_M, block_N), accum_dtype) | ||
| C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype) | ||
| C_shared = T.alloc_shared((block_M, store_block_N), out_dtype) | ||
| loaded = T.alloc_barrier([32] * num_stages) | ||
| consumed = T.alloc_barrier([1] * num_stages) | ||
| tmem_full = T.alloc_barrier([1] * 2) | ||
| tmem_empty = T.alloc_barrier([128] * 2) | ||
|
|
||
| tx = T.get_thread_binding() | ||
|
|
||
| if tx < 32: # warp 0: issue tma | ||
| for w in T.unroll(waves): | ||
| tile_id = sm_num * w + block_id | ||
| bx = (tile_id // group_size) % m_blocks | ||
| by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size | ||
|
|
||
| if bx * block_M < M and by * block_N < N: | ||
| for k in T.serial(k_blocks): | ||
| T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1) | ||
| T.copy( | ||
| A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], A_shared[k % num_stages, :, :] | ||
| ) # cannot use BufferLoad here | ||
| T.copy(B[k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared[k % num_stages, :, :]) | ||
| T.mbarrier_arrive(loaded[k % num_stages]) | ||
|
|
||
| elif tx < 64: # warp 1: issue tcgen5 | ||
| for w in T.unroll(waves): | ||
| tile_id = sm_num * w + block_id | ||
| bx = (tile_id // group_size) % m_blocks | ||
| by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size | ||
|
|
||
| if bx * block_M < M and by * block_N < N: | ||
| T.mbarrier_wait_parity(tmem_empty[w & 1], ((w // 2) & 1) ^ 1) | ||
| for k in T.serial(k_blocks): | ||
| T.mbarrier_wait_parity(loaded[k % num_stages], (k // num_stages) & 1) | ||
| if w & 1 == 0: | ||
| T.gemm( | ||
| A_shared[k % num_stages, :, :], | ||
| B_shared[k % num_stages, :, :], | ||
| C_tmem_0, | ||
| False, | ||
| False, | ||
| mbar=consumed[k % num_stages], | ||
| wg_wait=-1, | ||
| clear_accum=k == 0, | ||
| ) | ||
| else: | ||
| T.gemm( | ||
| A_shared[k % num_stages, :, :], | ||
| B_shared[k % num_stages, :, :], | ||
| C_tmem_1, | ||
| False, | ||
| False, | ||
| mbar=consumed[k % num_stages], | ||
| wg_wait=-1, | ||
| clear_accum=k == 0, | ||
| ) | ||
| T.tcgen05_mma_arrive(tmem_full[w & 1]) | ||
|
|
||
| elif 128 <= tx < 256: # warp 4~7: epilogue | ||
| for w in T.unroll(waves): | ||
| tile_id = sm_num * w + block_id | ||
| bx = (tile_id // group_size) % m_blocks | ||
| by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size | ||
|
|
||
| if bx * block_M < M and by * block_N < N: | ||
| T.mbarrier_wait_parity(tmem_full[w & 1], (w // 2) & 1) | ||
| T.sync_threads(1, 128) | ||
| if (w & 1) == 0: | ||
| T.copy(C_tmem_0, C_local) | ||
| else: | ||
| T.copy(C_tmem_1, C_local) | ||
| T.mbarrier_arrive(tmem_empty[w & 1]) | ||
|
|
||
| if use_tma_store: | ||
| for i in T.unroll(T.ceildiv(block_N, store_block_N)): | ||
| T.copy(C_local[:, i * store_block_N : (i + 1) * store_block_N], C_shared) | ||
| T.copy(C_shared, C[bx * block_M, by * block_N + i * store_block_N]) | ||
| else: | ||
| T.copy(C_local, C_local_cast) | ||
| T.copy(C_local_cast, C[bx * block_M, by * block_N]) | ||
| return C | ||
|
|
||
|
|
||
| def main(): | ||
| M, N, K = 8192, 8192, 8192 | ||
| block_M, block_N, block_K = 128, 256, 64 | ||
| store_block_N = 128 | ||
| in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float | ||
| num_stages = 4 | ||
|
|
||
| a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) | ||
| b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) | ||
| print(gemm.get_kernel_source(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages)) | ||
| c = gemm(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages) | ||
|
|
||
| ref_c = (a.to(torch.float) @ b.to(torch.float)).to(torch.bfloat16) | ||
| torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) | ||
| print("All checks passed. ✅") | ||
|
|
||
| tl_latency = do_bench( | ||
| lambda: gemm(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages), backend="cupti" | ||
| ) | ||
| torch_latency = do_bench(lambda: a @ b, backend="cupti") | ||
| print(f"Tilelang latency: {tl_latency} ms") | ||
| print(f"Flops: {2 * M * N * K / (tl_latency / 1e3) / 1e12} TFLOPS") | ||
| print(f"Torch latency: {torch_latency} ms") | ||
| print(f"Flops: {2 * M * N * K / (torch_latency / 1e3) / 1e12} TFLOPS") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -82,11 +82,7 @@ Gemm::Gemm(Array<PrimExpr> args, Map<String, ObjectRef> annotations) { | |||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| if (args.size() > 16) { | ||||||||||||||||||||||||||||||||||||
| if (const auto *load = args[16].as<BufferLoadNode>()) { | ||||||||||||||||||||||||||||||||||||
| node->mbarRegion_ = | ||||||||||||||||||||||||||||||||||||
| NormalizeToBufferRegion(Downcast<BufferLoad>(args[16])); | ||||||||||||||||||||||||||||||||||||
| node->mbar_ = node->mbarRegion_->buffer; | ||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||
| node->mbar_ = std::nullopt; | ||||||||||||||||||||||||||||||||||||
| node->mbar_ = Downcast<BufferLoad>(args[16]); | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
83
to
86
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Validate If 🔧 Suggested guard if (args.size() > 16) {
- if (const auto *load = args[16].as<BufferLoadNode>()) {
- node->mbar_ = Downcast<BufferLoad>(args[16]);
- } else {
- node->mbar_ = std::nullopt;
- }
+ const auto *load = args[16].as<BufferLoadNode>();
+ ICHECK(load || !args[16].defined())
+ << "mbar must be a BufferLoad or null when provided";
+ node->mbar_ =
+ load ? std::optional<BufferLoad>(Downcast<BufferLoad>(args[16]))
+ : std::nullopt;
}📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| node->cCoords_ = Array<PrimExpr>( | ||||||||||||||||||||||||||||||||||||
|
|
@@ -461,7 +457,7 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { | |||||||||||||||||||||||||||||||||||
| ICHECK(can_use_tcgen5mma); | ||||||||||||||||||||||||||||||||||||
| ICHECK(b_.scope() == "shared.dyn" || b_.scope() == "shared"); | ||||||||||||||||||||||||||||||||||||
| ICHECK(c_.scope() == "shared.tmem"); | ||||||||||||||||||||||||||||||||||||
| ICHECK(mbar_.has_value()) << "mbar must be provided for TCGEN5MMA"; | ||||||||||||||||||||||||||||||||||||
| ICHECK(mbar_.defined()) << "mbar must be provided for TCGEN5MMA"; | ||||||||||||||||||||||||||||||||||||
| if (a_.scope() == "shared.tmem") { | ||||||||||||||||||||||||||||||||||||
| op_name = "tl::tcgen5mma_gemm_ts"; | ||||||||||||||||||||||||||||||||||||
| } else if (a_.scope() == "shared.dyn" || a_.scope() == "shared") { | ||||||||||||||||||||||||||||||||||||
|
|
@@ -492,8 +488,7 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { | |||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| auto C_buffer = T.buffer_remap.count(c_) ? T.buffer_remap[c_] : c_; | ||||||||||||||||||||||||||||||||||||
| Array<PrimExpr> new_args; | ||||||||||||||||||||||||||||||||||||
| auto mbarPtr = | ||||||||||||||||||||||||||||||||||||
| MakeAccessPtrFromRegion(mbarRegion_, /*rw*/ 3, /*require_2d*/ true); | ||||||||||||||||||||||||||||||||||||
| auto mbarPtr = MakeAccessPtrFromBufferLoad(mbar_, /*rw*/ 3); | ||||||||||||||||||||||||||||||||||||
| new_args.push_back(StringImm(ss.str())); | ||||||||||||||||||||||||||||||||||||
| new_args.push_back(Aptr); | ||||||||||||||||||||||||||||||||||||
| new_args.push_back(Bptr); | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -82,9 +82,7 @@ GemmPy::GemmPy(Array<PrimExpr> args, Map<String, ObjectRef> annotations) { | |||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| if (args.size() > 16) { | ||||||||||||||||||||||||||
| if (const auto *load = args[16].as<BufferLoadNode>()) { | ||||||||||||||||||||||||||
| node->mbarRegion_ = | ||||||||||||||||||||||||||
| NormalizeToBufferRegion(Downcast<BufferLoad>(args[16])); | ||||||||||||||||||||||||||
| node->mbar_ = node->mbarRegion_->buffer; | ||||||||||||||||||||||||||
| node->mbar_ = Downcast<BufferLoad>(args[16]); | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
Comment on lines
83
to
86
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fail fast when Silently ignoring non-BufferLoad inputs can mask call-site errors and lead to a missing barrier later; an explicit check makes failures clearer. 💡 Suggested fix- if (args.size() > 16) {
- if (const auto *load = args[16].as<BufferLoadNode>()) {
- node->mbar_ = Downcast<BufferLoad>(args[16]);
- }
- }
+ if (args.size() > 16) {
+ ICHECK(args[16].as<BufferLoadNode>())
+ << "mbar must be provided as a BufferLoad when present";
+ node->mbar_ = Downcast<BufferLoad>(args[16]);
+ }📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| node->cCoords_ = Array<PrimExpr>( | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,8 +29,7 @@ class GemmPyNode : public TileOperatorNode { | |
| int strideA_, strideB_; | ||
| int offsetA_, offsetB_; | ||
| PrimExpr clearAccum_ = const_false(); | ||
| BufferRegion mbarRegion_; | ||
| tir::Buffer mbar_; // mbar is optional, only used for TCGEN5MMA | ||
| tir::BufferLoad mbar_; // mbar is optional, only used for TCGEN5MMA | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# Verify mbar_ usage for null-safety and initialization paths
rg -n -C3 'mbar_' src/op/gemm_py.cc src/op/gemm_py.hRepository: tile-ai/tilelang Length of output: 1338 🏁 Script executed: # Also check the GemmNode definition to see the original type
rg -n -C3 'struct GemmNode\|class GemmNode' src/op/gemm_py.h src/op/gemm_py.ccRepository: tile-ai/tilelang Length of output: 42 🏁 Script executed: # Look for mbar_ initialization in constructors or Create methods
rg -n 'mbar_\s*=' src/op/gemm_py.cc src/op/gemm_py.hRepository: tile-ai/tilelang Length of output: 132 🏁 Script executed: # Check if .defined() or other null checks are used with mbar_
rg -n 'mbar_.*defined|if.*mbar_' src/op/gemm_py.cc src/op/gemm_py.hRepository: tile-ai/tilelang Length of output: 42 🏁 Script executed: # Search for how mbar_ is used/read (not just assigned)
rg -n 'mbar_' src/op/gemm_py.cc src/op/gemm_py.h -A 5 -B 2Repository: tile-ai/tilelang Length of output: 1511 🏁 Script executed: # Look for uses of mbar in generated IR or kernel code
rg -n 'mbar' src/op/gemm_py.cc src/op/gemm_py.hRepository: tile-ai/tilelang Length of output: 286 🏁 Script executed: # Check if there's a TVM BufferLoad definition to understand default initialization
fd -t f '\.h$' -x grep -l 'class BufferLoad\|struct BufferLoad' {} \; | head -5Repository: tile-ai/tilelang Length of output: 42 🏁 Script executed: # Search for mbar usage in gemm_py.cc to find where it's read/accessed
rg -n 'mbar' src/op/gemm_py.cc | head -20Repository: tile-ai/tilelang Length of output: 114 🏁 Script executed: # Search for GemmPyNode usage in other files to see how mbar is accessed
rg -n 'GemmPyNode\|mbar' src/ --include='*.cc' --include='*.h' | grep -v 'gemm_py' | head -20Repository: tile-ai/tilelang Length of output: 501 🏁 Script executed: # Check the entire gemm_py.cc to see function context where mbar is used
wc -l src/op/gemm_py.ccRepository: tile-ai/tilelang Length of output: 81 🏁 Script executed: cat -n src/op/gemm_py.ccRepository: tile-ai/tilelang Length of output: 15164 Resolve "optional" mismatch for
Suggestion: Make this 🤖 Prompt for AI Agents |
||
| Array<PrimExpr> cCoords_; | ||
| // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack | ||
| // only will be enabled under cdna mfma instructions | ||
|
|
@@ -59,7 +58,6 @@ class GemmPyNode : public TileOperatorNode { | |
| .def_ro("offsetA", &GemmPyNode::offsetA_) | ||
| .def_ro("offsetB", &GemmPyNode::offsetB_) | ||
| .def_ro("clearAccum", &GemmPyNode::clearAccum_) | ||
| .def_ro("mbarRegion", &GemmPyNode::mbarRegion_) | ||
| .def_ro("mbar", &GemmPyNode::mbar_) | ||
| .def_ro("cCoords", &GemmPyNode::cCoords_) | ||
| .def_ro("kPack", &GemmPyNode::kPack_) | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -103,6 +103,36 @@ PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion, int rw_mask, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load, int rw_mask) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Buffer buf = load->buffer; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int ndim = static_cast<int>(buf->shape.size()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Compute offset using row-major layout (iterate in reverse) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr offset = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr stride = 1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int i = ndim - 1; i >= 0; --i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const PrimExpr &index = load->indices[i]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (const auto *ramp = index.as<RampNode>()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // For Ramp, use the base | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| offset = offset + ramp->base * stride; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // For scalar index (IntImm or other PrimExpr) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| offset = offset + index * stride; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stride = stride * buf->shape[i]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Extent is 1 element for a single BufferLoad access | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr extent = make_const(DataType::Int(32), 1); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Build access_ptr | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr ptype = tir::TypeAnnotation(buf->dtype); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Array<PrimExpr> acc_args{ptype, buf->data, offset, extent, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IntImm(DataType::Int(32), rw_mask)}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
106
to
133
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid int32 offset/stride to prevent overflow on large buffers.
💡 Suggested fix- // Compute offset using row-major layout (iterate in reverse)
- PrimExpr offset = make_const(DataType::Int(32), 0);
- PrimExpr stride = make_const(DataType::Int(32), 1);
+ // Compute offset using row-major layout (iterate in reverse)
+ DataType idx_dtype = buf->shape[0].dtype();
+ PrimExpr offset = make_const(idx_dtype, 0);
+ PrimExpr stride = make_const(idx_dtype, 1);
@@
- // Extent is 1 element for a single BufferLoad access
- PrimExpr extent = make_const(DataType::Int(32), 1);
+ // Extent is 1 element for a single BufferLoad access
+ PrimExpr extent = make_const(idx_dtype, 1);📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Maps TVM DataType to CUDA's CUtensorMapDataType enum value. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int to_CUtensorMapDataType(DataType dtype) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| CUtensorMapDataType tp; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: tile-ai/tilelang
Length of output: 100
🏁 Script executed:
Repository: tile-ai/tilelang
Length of output: 1244
🏁 Script executed:
Repository: tile-ai/tilelang
Length of output: 1130
🏁 Script executed:
# Examine the file mentioned in the review cat -n examples/gemm_sm100/gemm_tcgen5mma_ws.pyRepository: tile-ai/tilelang
Length of output: 4292
🏁 Script executed:
# Also search for other gemm_sm100 files to get context ls -la examples/gemm_sm100/Repository: tile-ai/tilelang
Length of output: 434
🏁 Script executed:
Repository: tile-ai/tilelang
Length of output: 2122
🏁 Script executed:
# Look at other SM100 GEMM variants to see if they have guards cat -n examples/gemm_sm100/gemm_tcgen5mma.pyRepository: tile-ai/tilelang
Length of output: 3381
🏁 Script executed:
Repository: tile-ai/tilelang
Length of output: 3576
Add explicit precondition to guard against non‑divisible M/N/K dimensions.
The kernel uses
T.ceildivfor grid dimensions, and line 68 copiesC_shared(shape[block_M, block_N]) toC[by * block_M, bx * block_N]without bounds checking. If M/N/K are not exact multiples of their block sizes, the last threadblocks will write out of bounds. The test case passes only because M=N=K=16384 are exact multiples of their block sizes. Add an assertion to enforce this precondition, as done in other GEMM kernels in the codebase.🛡️ Suggested guard (adjust to TileLang-supported assertion APIs)
def gemm( A, B, @@ ): M, N, K = T.const("M, N, K") + # Kernel requires exact tile coverage to avoid OOB access at line 68. + assert ( + M % block_M == 0 + and N % block_N == 0 + and K % block_K == 0 + ), "M/N/K must be divisible by block_M/block_N/block_K"🤖 Prompt for AI Agents