91 changes: 91 additions & 0 deletions examples/gemm_sm100/gemm_tcgen5mma_ws.py
@@ -0,0 +1,91 @@
# Non-persistent, 1-SM GEMM

import torch
import tilelang
import tilelang.language as T
from tilelang.profiler import do_bench


@tilelang.jit
def gemm(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages, use_tma_store=True):
    M, N, K = T.const("M, N, K")

    k_iters = T.ceildiv(K, block_K)

    A: T.Tensor[[M, K], in_dtype]
    B: T.Tensor[[K, N], in_dtype]
    C = T.empty((M, N), out_dtype)

    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype)
        B_shared = T.alloc_shared((num_stages, block_K, block_N), in_dtype)
        C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
        C_shared = T.alloc_shared((block_M, block_N), out_dtype)
        C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype)
        loaded = T.alloc_barrier([32] * num_stages)
        consumed = T.alloc_barrier([1] * num_stages)
        tmem_full = T.alloc_barrier([1])

        tx = T.get_thread_binding()

        T.use_swizzle(8)

        if tx < 32:  # warp 0: issue tma
            for k in T.serial(k_iters):
                T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1)
                T.copy(A[by * block_M : (by + 1) * block_M, k * block_K : (k + 1) * block_K], A_shared[k % num_stages, :, :])
                T.copy(B[k * block_K : (k + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[k % num_stages, :, :])
Comment on lines 13 to 38
Contributor
⚠️ Potential issue | 🟠 Major

Add explicit precondition to guard against non‑divisible M/N/K dimensions.
The kernel uses T.ceildiv for grid dimensions, and line 68 copies C_shared (shape [block_M, block_N]) to C[by * block_M, bx * block_N] without bounds checking. If M, N, or K is not an exact multiple of its block size, the last thread blocks will write out of bounds. The example passes only because M = N = K = 8192 are exact multiples of the block sizes. Add an assertion to enforce this precondition, as is done in other GEMM kernels in the codebase.

🛡️ Suggested guard (adjust to TileLang-supported assertion APIs)
 def gemm(
     A,
     B,
@@
 ):
     M, N, K = T.const("M, N, K")
+    # Kernel requires exact tile coverage to avoid OOB access at line 68.
+    assert (
+        M % block_M == 0
+        and N % block_N == 0
+        and K % block_K == 0
+    ), "M/N/K must be divisible by block_M/block_N/block_K"
🤖 Prompt for AI Agents
In `@examples/gemm_sm100/gemm_tcgen5mma_ws.py` around lines 23-47, add an explicit precondition that M, N, and K are exact multiples of block_M, block_N, and block_K respectively to prevent out-of-bounds writes when the kernel (which computes k_iters via T.ceildiv and writes C_shared -> C at by*block_M, bx*block_N) runs. Insert an assertion near the top of the function, before the T.Kernel block (where k_iters, A/B/C and the shared buffers are set up), that checks M % block_M == 0, N % block_N == 0, and K % block_K == 0 and fails early if not, mirroring the guards used in other GEMM kernels.
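
To make the reviewed failure mode concrete, here is a small worked check with hypothetical sizes (not the example's defaults):

```python
# Hypothetical sizes illustrating the OOB risk flagged above.
M, block_M = 100, 64

m_tiles = -(-M // block_M)  # ceildiv(100, 64) == 2 tiles along M
last_tile_end = m_tiles * block_M  # second tile covers rows [64, 128)

assert last_tile_end > M  # rows 100..127 fall outside C
# The suggested guard rejects such shapes up front:
# assert M % block_M == 0, "M must be divisible by block_M"
```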

                T.mbarrier_arrive(loaded[k % num_stages])
        elif tx < 64:  # warp 1: issue tcgen5
            for k in T.serial(k_iters):
                T.mbarrier_wait_parity(loaded[k % num_stages], (k // num_stages) & 1)
                T.gemm(
                    A_shared[k % num_stages, :, :],
                    B_shared[k % num_stages, :, :],
                    C_tmem,
                    mbar=consumed[k % num_stages],
                    wg_wait=-1,
                    clear_accum=k == 0,
                )
            T.tcgen05_mma_arrive(tmem_full)

        # Wait for all tcgen5 to finish
        T.mbarrier_wait_parity(tmem_full, 0)

        T.sync_threads()  # TileLang won't generate this if not annotated
        T.copy(C_tmem, C_local)
        if use_tma_store:
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])
        else:
            T.copy(C_local, C_local_cast)
            T.copy(C_local_cast, C[by * block_M, bx * block_N])
    return C


def main():
    M, N, K = 8192, 8192, 8192
    block_M, block_N, block_K = 128, 256, 64
    in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float
    num_stages = 4

    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    c = gemm(a, b, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages)
    print(gemm.get_kernel_source(a, b, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages))

    ref_c = (a.to(torch.float) @ b.to(torch.float)).to(torch.bfloat16)
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("All checks passed. ✅")

    tl_latency = do_bench(lambda: gemm(a, b, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages), backend="cupti")
    torch_latency = do_bench(lambda: a @ b, backend="cupti")
    print(f"Tilelang latency: {tl_latency} ms")
    print(f"Flops: {2 * M * N * K / (tl_latency / 1e3) / 1e12} TFLOPS")
    print(f"Torch latency: {torch_latency} ms")
    print(f"Flops: {2 * M * N * K / (torch_latency / 1e3) / 1e12} TFLOPS")


if __name__ == "__main__":
    main()
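
A note on the barrier protocol in this example: warp 0 (the TMA producer) waits on `consumed` before overwriting a pipeline stage, warp 1 (the MMA consumer) waits on `loaded` before reading it, and the parity each side waits on flips every time the stage index wraps past num_stages. A plain-Python model of the phase arithmetic (illustrative only, not TileLang API):

```python
# Model of the producer/consumer parity scheme used by the kernel above.
num_stages, k_iters = 4, 10

for k in range(k_iters):
    stage = k % num_stages         # shared-memory slot for this iteration
    phase = (k // num_stages) & 1  # flips each time the slots wrap around

    producer_wait = phase ^ 1  # warp 0: wait until the slot was drained
    consumer_wait = phase      # warp 1: wait until the slot was filled

    print(f"k={k}: stage={stage} producer_parity={producer_wait} consumer_parity={consumer_wait}")
```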
154 changes: 154 additions & 0 deletions examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py
@@ -0,0 +1,154 @@
# Persistent, 1-SM, num_epi_stages = 2

import torch
import tilelang
import tilelang.language as T
from tilelang.carver.arch import driver
from tilelang.profiler import do_bench


@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_LOOP_UNSWITCHING: True})
def gemm(
    A,
    B,
    block_M,
    block_N,
    store_block_N,  # block_N for C_shared
    block_K,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    use_tma_store=True,
):
    M, N, K = T.const("M, N, K")

    A: T.Tensor[[M, K], in_dtype]
    B: T.Tensor[[K, N], in_dtype]
    C = T.empty((M, N), out_dtype)

    sm_num = driver.get_num_sms()
    m_blocks = T.ceildiv(M, block_M)
    n_blocks = T.ceildiv(N, block_N)
    assert K % (2 * block_K) == 0  # for simplicity
Comment on lines +31 to +33
Contributor
⚠️ Potential issue | 🟠 Major

Add tiling precondition checks to prevent OOB loads/stores.

The kernel loads/stores full tiles without boundary masking. If M or N aren't multiples of block_M or block_N, the slices on lines 63 and 65 will read past array bounds. Similarly, on line 119, if store_block_N doesn't evenly divide block_N, the loop will access out-of-bounds memory in C_local. While the test case uses compatible parameters (8192 with block size 128/256, store_block_N 128), the kernel interface accepts these as runtime parameters, allowing unsafe calls.

💡 Suggested guardrails
-    assert K % (2 * block_K) == 0  # for simplicity
+    assert K % (2 * block_K) == 0, "K must be divisible by 2 * block_K"
+    assert M % block_M == 0 and N % block_N == 0, "M/N must be multiples of block_M/block_N"
+    assert store_block_N <= block_N and block_N % store_block_N == 0, "store_block_N must evenly divide block_N"
🤖 Prompt for AI Agents
In `@examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py` around lines 31-33, add explicit tiling precondition checks to prevent out-of-bounds tile loads/stores: validate that M % block_M == 0 and N % block_N == 0 so the full-tile slices used when loading and storing (referencing m_blocks and n_blocks) stay in bounds, and also assert that block_N % store_block_N == 0 to prevent OOB access into C_local when iterating store_block_N across a block_N tile. Implement these as early assertions/argument checks (or raise clear exceptions) before any tensor slicing or the compute loop, so callers supplying M, N, block_M, block_N, and store_block_N cannot trigger OOB accesses.
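
A quick script (hypothetical parameter sets) shows which calls the suggested guardrails would reject:

```python
# Checks mirroring the suggested guardrails, applied to hypothetical shapes.
def tiling_ok(M, N, K, block_M, block_N, block_K, store_block_N):
    return (K % (2 * block_K) == 0
            and M % block_M == 0
            and N % block_N == 0
            and store_block_N <= block_N
            and block_N % store_block_N == 0)

# The example's defaults pass:
assert tiling_ok(8192, 8192, 8192, 128, 256, 64, 128)
# A store_block_N that does not divide block_N would read past C_local:
assert not tiling_ok(8192, 8192, 8192, 128, 256, 64, 96)
```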

    k_blocks = T.ceildiv(K, block_K)
    waves = T.ceildiv(m_blocks * n_blocks, sm_num)
    group_size = 8

    with T.Kernel(sm_num, threads=256) as (block_id):
        A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype)
        B_shared = T.alloc_shared((num_stages, block_K, block_N), in_dtype)
        C_tmem_0 = T.alloc_tmem([block_M, block_N], accum_dtype)
        C_tmem_1 = T.alloc_tmem([block_M, block_N], accum_dtype)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
        C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype)
        C_shared = T.alloc_shared((block_M, store_block_N), out_dtype)
        loaded = T.alloc_barrier([32] * num_stages)
        consumed = T.alloc_barrier([1] * num_stages)
        tmem_full = T.alloc_barrier([1] * 2)
        tmem_empty = T.alloc_barrier([128] * 2)

        tx = T.get_thread_binding()

        if tx < 32:  # warp 0: issue tma
            for w in T.unroll(waves):
                tile_id = sm_num * w + block_id
                bx = (tile_id // group_size) % m_blocks
                by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size

                if bx * block_M < M and by * block_N < N:
                    for k in T.serial(k_blocks):
                        T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1)
                        T.copy(
                            A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], A_shared[k % num_stages, :, :]
                        )  # cannot use BufferLoad here
                        T.copy(B[k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared[k % num_stages, :, :])
                        T.mbarrier_arrive(loaded[k % num_stages])

        elif tx < 64:  # warp 1: issue tcgen5
            for w in T.unroll(waves):
                tile_id = sm_num * w + block_id
                bx = (tile_id // group_size) % m_blocks
                by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size

                if bx * block_M < M and by * block_N < N:
                    T.mbarrier_wait_parity(tmem_empty[w & 1], ((w // 2) & 1) ^ 1)
                    for k in T.serial(k_blocks):
                        T.mbarrier_wait_parity(loaded[k % num_stages], (k // num_stages) & 1)
                        if w & 1 == 0:
                            T.gemm(
                                A_shared[k % num_stages, :, :],
                                B_shared[k % num_stages, :, :],
                                C_tmem_0,
                                False,
                                False,
                                mbar=consumed[k % num_stages],
                                wg_wait=-1,
                                clear_accum=k == 0,
                            )
                        else:
                            T.gemm(
                                A_shared[k % num_stages, :, :],
                                B_shared[k % num_stages, :, :],
                                C_tmem_1,
                                False,
                                False,
                                mbar=consumed[k % num_stages],
                                wg_wait=-1,
                                clear_accum=k == 0,
                            )
                    T.tcgen05_mma_arrive(tmem_full[w & 1])

        elif 128 <= tx < 256:  # warp 4~7: epilogue
            for w in T.unroll(waves):
                tile_id = sm_num * w + block_id
                bx = (tile_id // group_size) % m_blocks
                by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size

                if bx * block_M < M and by * block_N < N:
                    T.mbarrier_wait_parity(tmem_full[w & 1], (w // 2) & 1)
                    T.sync_threads(1, 128)
                    if (w & 1) == 0:
                        T.copy(C_tmem_0, C_local)
                    else:
                        T.copy(C_tmem_1, C_local)
                    T.mbarrier_arrive(tmem_empty[w & 1])

                    if use_tma_store:
                        for i in T.unroll(T.ceildiv(block_N, store_block_N)):
                            T.copy(C_local[:, i * store_block_N : (i + 1) * store_block_N], C_shared)
                            T.copy(C_shared, C[bx * block_M, by * block_N + i * store_block_N])
                    else:
                        T.copy(C_local, C_local_cast)
                        T.copy(C_local_cast, C[bx * block_M, by * block_N])
    return C


def main():
    M, N, K = 8192, 8192, 8192
    block_M, block_N, block_K = 128, 256, 64
    store_block_N = 128
    in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float
    num_stages = 4

    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    print(gemm.get_kernel_source(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages))
    c = gemm(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages)

    ref_c = (a.to(torch.float) @ b.to(torch.float)).to(torch.bfloat16)
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("All checks passed. ✅")

    tl_latency = do_bench(
        lambda: gemm(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages), backend="cupti"
    )
    torch_latency = do_bench(lambda: a @ b, backend="cupti")
    print(f"Tilelang latency: {tl_latency} ms")
    print(f"Flops: {2 * M * N * K / (tl_latency / 1e3) / 1e12} TFLOPS")
    print(f"Torch latency: {torch_latency} ms")
    print(f"Flops: {2 * M * N * K / (torch_latency / 1e3) / 1e12} TFLOPS")


if __name__ == "__main__":
    main()
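
The grouped rasterization above (`group_size = 8`) hands out output tiles so that consecutive `tile_id`s touch neighboring columns, which helps L2 reuse. A plain-Python model of the same mapping on a hypothetical small grid confirms every tile is produced exactly once, with out-of-range candidates filtered by the `bx`/`by` bounds guard:

```python
# Plain-Python model of the persistent tile schedule used above.
m_blocks, n_blocks, group_size = 2, 8, 8   # hypothetical small grid
sm_num, waves = 4, 4                       # sm_num * waves >= m_blocks * n_blocks

seen = []
for block_id in range(sm_num):
    for w in range(waves):
        tile_id = sm_num * w + block_id
        bx = (tile_id // group_size) % m_blocks
        by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size
        if bx < m_blocks and by < n_blocks:  # mirrors the bounds guard
            seen.append((bx, by))

assert sorted(seen) == [(bx, by) for bx in range(m_blocks) for by in range(n_blocks)]
```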
11 changes: 3 additions & 8 deletions src/op/gemm.cc
@@ -82,11 +82,7 @@ Gemm::Gemm(Array<PrimExpr> args, Map<String, ObjectRef> annotations) {
   }
   if (args.size() > 16) {
     if (const auto *load = args[16].as<BufferLoadNode>()) {
-      node->mbarRegion_ =
-          NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
-      node->mbar_ = node->mbarRegion_->buffer;
-    } else {
-      node->mbar_ = std::nullopt;
+      node->mbar_ = Downcast<BufferLoad>(args[16]);
     }
Comment on lines 83 to 86
Contributor
⚠️ Potential issue | 🟡 Minor

Validate mbar arg type instead of silently dropping it.

If args[16] is non-null but not a BufferLoad, this currently becomes nullopt and can hide frontend regressions (e.g., still passing a region). Consider an explicit check to fail fast.

🔧 Suggested guard
   if (args.size() > 16) {
-    if (const auto *load = args[16].as<BufferLoadNode>()) {
-      node->mbar_ = Downcast<BufferLoad>(args[16]);
-    } else {
-      node->mbar_ = std::nullopt;
-    }
+    const auto *load = args[16].as<BufferLoadNode>();
+    ICHECK(load || !args[16].defined())
+        << "mbar must be a BufferLoad or null when provided";
+    node->mbar_ =
+        load ? std::optional<BufferLoad>(Downcast<BufferLoad>(args[16]))
+             : std::nullopt;
   }
🤖 Prompt for AI Agents
In `@src/op/gemm.cc` around lines 83-88, the code silently leaves node->mbar_ unset when args[16] exists but is not a BufferLoadNode; instead, validate the type and fail fast: when args.size() > 16 and args[16] is non-null, check that args[16].as<BufferLoadNode>() is non-null and, if it isn't, raise an explicit error (e.g., TVM_PANIC/LOG(FATAL)/ICHECK/throw) describing the unexpected arg type for mbar_; otherwise keep the existing Downcast<BufferLoad>(args[16]) assignment to node->mbar_.
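
On the frontend side, the same fail-fast idea could be applied before the argument reaches the C++ op. A hedged sketch (this helper is hypothetical, not existing TileLang code; `tir.BufferLoad` is a real TVM class):

```python
from tvm import tir

def check_mbar_arg(mbar):
    """Hypothetical frontend guard mirroring the suggested ICHECK."""
    if mbar is None:
        return None  # mbar is optional; absent is fine
    if isinstance(mbar, tir.BufferLoad):
        return mbar  # e.g. consumed[k % num_stages] in the examples above
    raise TypeError(f"mbar must be a BufferLoad or None, got {type(mbar)!r}")
```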

   }
   node->cCoords_ = Array<PrimExpr>(
@@ -461,7 +457,7 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   ICHECK(can_use_tcgen5mma);
   ICHECK(b_.scope() == "shared.dyn" || b_.scope() == "shared");
   ICHECK(c_.scope() == "shared.tmem");
-  ICHECK(mbar_.has_value()) << "mbar must be provided for TCGEN5MMA";
+  ICHECK(mbar_.defined()) << "mbar must be provided for TCGEN5MMA";
   if (a_.scope() == "shared.tmem") {
     op_name = "tl::tcgen5mma_gemm_ts";
   } else if (a_.scope() == "shared.dyn" || a_.scope() == "shared") {
@@ -492,8 +488,7 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {

   auto C_buffer = T.buffer_remap.count(c_) ? T.buffer_remap[c_] : c_;
   Array<PrimExpr> new_args;
-  auto mbarPtr =
-      MakeAccessPtrFromRegion(mbarRegion_, /*rw*/ 3, /*require_2d*/ true);
+  auto mbarPtr = MakeAccessPtrFromBufferLoad(mbar_, /*rw*/ 3);
   new_args.push_back(StringImm(ss.str()));
   new_args.push_back(Aptr);
   new_args.push_back(Bptr);
5 changes: 3 additions & 2 deletions src/op/gemm.h
@@ -130,8 +130,7 @@ class GemmNode : public TileOperatorNode {
   // only will be enabled under cdna mfma instructions
   int kPack_ = 1;
   int wgWait_ = 0;
-  BufferRegion mbarRegion_;
-  std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA
+  tir::BufferLoad mbar_; // mbar is optional, only used for TCGEN5MMA
   Array<PrimExpr> cCoords_;
   mutable GemmWarpPolicy policy_;
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Gemm", GemmNode, TileOperatorNode);
@@ -157,6 +156,8 @@ class GemmNode : public TileOperatorNode {
         .def_ro("clearAccum", &GemmNode::clearAccum_)
         .def_ro("kPack", &GemmNode::kPack_)
         .def_ro("wgWait", &GemmNode::wgWait_)
+        .def_ro("mbar", &GemmNode::mbar_)
+        .def_ro("cCoords", &GemmNode::cCoords_)
         .def_ro("policy", &GemmNode::policy_);
   }
4 changes: 1 addition & 3 deletions src/op/gemm_py.cc
@@ -82,9 +82,7 @@ GemmPy::GemmPy(Array<PrimExpr> args, Map<String, ObjectRef> annotations) {
   }
   if (args.size() > 16) {
     if (const auto *load = args[16].as<BufferLoadNode>()) {
-      node->mbarRegion_ =
-          NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
-      node->mbar_ = node->mbarRegion_->buffer;
+      node->mbar_ = Downcast<BufferLoad>(args[16]);
     }
Comment on lines 83 to 86
Contributor
⚠️ Potential issue | 🟡 Minor

Fail fast when mbar is present but not a BufferLoad.

Silently ignoring non-BufferLoad inputs can mask call-site errors and lead to a missing barrier later; an explicit check makes failures clearer.

💡 Suggested fix
-  if (args.size() > 16) {
-    if (const auto *load = args[16].as<BufferLoadNode>()) {
-      node->mbar_ = Downcast<BufferLoad>(args[16]);
-    }
-  }
+  if (args.size() > 16) {
+    ICHECK(args[16].as<BufferLoadNode>())
+        << "mbar must be provided as a BufferLoad when present";
+    node->mbar_ = Downcast<BufferLoad>(args[16]);
+  }
🤖 Prompt for AI Agents
In `@src/op/gemm_py.cc` around lines 83-86, when handling the optional mbar argument in the arg-parsing block, don't silently ignore non-BufferLoad values: if args.size() > 16, attempt the BufferLoadNode cast as currently done (check for BufferLoadNode and set node->mbar_ via Downcast<BufferLoad>), but add an else branch that fails fast (throw or LOG(FATAL)/CHECK), reporting that arg 16 was expected to be a BufferLoad and including the actual argument's type/name (use whatever runtime type introspection is available on the Expr/Node), so callers get a clear error instead of silently skipping mbar_.

   }
   node->cCoords_ = Array<PrimExpr>(
4 changes: 1 addition & 3 deletions src/op/gemm_py.h
@@ -29,8 +29,7 @@ class GemmPyNode : public TileOperatorNode {
   int strideA_, strideB_;
   int offsetA_, offsetB_;
   PrimExpr clearAccum_ = const_false();
-  BufferRegion mbarRegion_;
-  tir::Buffer mbar_; // mbar is optional, only used for TCGEN5MMA
+  tir::BufferLoad mbar_; // mbar is optional, only used for TCGEN5MMA
Contributor
⚠️ Potential issue | 🟠 Major

Resolve "optional" mismatch for mbar_ (type-contract inconsistency).

mbar_ is declared as non-optional tir::BufferLoad, but the comment marks it as optional and initialization is conditional (line 85: only assigned if args.size() > 16 and args[16] is a BufferLoadNode). This creates a type-contract mismatch: the field can remain uninitialized while the type signature suggests it's always present.

Suggestion: Make this std::optional<tir::BufferLoad> to match the semantic intent, or guarantee initialization and remove the optional comment. Verify Python-side code handles undefined mbar_ safely.

🤖 Prompt for AI Agents
In `@src/op/gemm_py.h` at line 32, mbar_ is declared as a non-optional
tir::BufferLoad but is only conditionally assigned (when args.size() > 16 and
args[16] is a BufferLoadNode), causing a type-contract mismatch; change the
field declaration from tir::BufferLoad mbar_ to std::optional<tir::BufferLoad>
mbar_, update the parser/initializer (where args is inspected) to emplace/assign
mbar_ only in the conditional branch, and adjust any uses of mbar_ (check
has_value() or use value_or) so code and the Python bindings safely handle the
absent case; alternatively, if you prefer non-optional, ensure mbar_ is
unconditionally initialized in the same constructor code path and remove the
"optional" comment.

   Array<PrimExpr> cCoords_;
   // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
   // only will be enabled under cdna mfma instructions
@@ -59,7 +58,6 @@ class GemmPyNode : public TileOperatorNode {
.def_ro("offsetA", &GemmPyNode::offsetA_)
.def_ro("offsetB", &GemmPyNode::offsetB_)
.def_ro("clearAccum", &GemmPyNode::clearAccum_)
.def_ro("mbarRegion", &GemmPyNode::mbarRegion_)
.def_ro("mbar", &GemmPyNode::mbar_)
.def_ro("cCoords", &GemmPyNode::cCoords_)
.def_ro("kPack", &GemmPyNode::kPack_)
Expand Down
30 changes: 30 additions & 0 deletions src/op/utils.cc
@@ -103,6 +103,36 @@ PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region, int rw_mask,
   return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
 }
 
+PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load, int rw_mask) {
+  Buffer buf = load->buffer;
+  int ndim = static_cast<int>(buf->shape.size());
+
+  // Compute offset using row-major layout (iterate in reverse)
+  PrimExpr offset = 0;
+  PrimExpr stride = 1;
+
+  for (int i = ndim - 1; i >= 0; --i) {
+    const PrimExpr &index = load->indices[i];
+    if (const auto *ramp = index.as<RampNode>()) {
+      // For Ramp, use the base
+      offset = offset + ramp->base * stride;
+    } else {
+      // For scalar index (IntImm or other PrimExpr)
+      offset = offset + index * stride;
+    }
+    stride = stride * buf->shape[i];
+  }
+
+  // Extent is 1 element for a single BufferLoad access
+  PrimExpr extent = make_const(DataType::Int(32), 1);
+
+  // Build access_ptr
+  PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
+  Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
+                           IntImm(DataType::Int(32), rw_mask)};
+  return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
Comment on lines 106 to 133
Contributor
⚠️ Potential issue | 🟠 Major

Avoid int32 offset/stride to prevent overflow on large buffers.

DataType::Int(32) can overflow for large shapes and diverges from MakeAccessPtrFromRegion. Use the buffer index dtype for offset/stride/extent.

💡 Suggested fix
-  // Compute offset using row-major layout (iterate in reverse)
-  PrimExpr offset = 0;
-  PrimExpr stride = 1;
+  // Compute offset using row-major layout (iterate in reverse)
+  DataType idx_dtype = buf->shape[0].dtype();
+  PrimExpr offset = make_const(idx_dtype, 0);
+  PrimExpr stride = make_const(idx_dtype, 1);
@@
-  // Extent is 1 element for a single BufferLoad access
-  PrimExpr extent = make_const(DataType::Int(32), 1);
+  // Extent is 1 element for a single BufferLoad access
+  PrimExpr extent = make_const(idx_dtype, 1);
🤖 Prompt for AI Agents
In `@src/op/utils.cc` around lines 95-122, the function MakeAccessPtrFromBufferLoad uses hard-coded int32 constants for offset, stride and extent, which can overflow for large buffers; change these (and the IntImm for rw_mask) to use the buffer's index dtype: initialize offset and stride with make_const(idx_dtype, 0/1), compute the offset/stride arithmetic in that dtype, set extent using make_const(idx_dtype, 1), and construct the rw_mask as an IntImm of the same dtype when building acc_args; update the references inside MakeAccessPtrFromBufferLoad (offset, stride, extent, acc_args) accordingly.
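
The overflow concern is easy to see with concrete numbers; here is a plain-Python model of the same row-major linearization on a hypothetical large buffer:

```python
# Row-major linearization as computed in MakeAccessPtrFromBufferLoad.
def flat_offset(indices, shape):
    offset, stride = 0, 1
    for i in reversed(range(len(shape))):  # innermost dimension first
        offset += indices[i] * stride
        stride *= shape[i]
    return offset

shape = (65536, 65536)  # hypothetical large 2-D buffer
off = flat_offset((65535, 65535), shape)
assert off == 2**32 - 1  # exceeds 2**31 - 1, i.e. overflows a signed int32
```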

+}
+
 // Maps TVM DataType to CUDA's CUtensorMapDataType enum value.
 int to_CUtensorMapDataType(DataType dtype) {
   CUtensorMapDataType tp;
4 changes: 4 additions & 0 deletions src/op/utils.h
@@ -43,6 +43,10 @@ TVM_DLL BufferRegion NormalizeToBufferRegion(const PrimExpr &arg);
 TVM_DLL PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
                                          int rw_mask, bool require_2d = false);
 
+// Build a tvm_access_ptr(handle) from a BufferLoad.
+TVM_DLL PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load,
+                                             int rw_mask);
+
 // Check if a buffer is a fragment buffer (scope == "local.fragment")
 inline bool IsFragmentBuffer(const Buffer &buffer) {
   return buffer.defined() && buffer.scope() == "local.fragment";
10 changes: 6 additions & 4 deletions tilelang/language/builtin.py
@@ -805,15 +805,17 @@ def cp_async_barrier_noinc(barrier: BarrierType):
     return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier)
 
 
-def tcgen05_mma_arrive(mbar_ptr):
+def tcgen05_mma_arrive(mbar: tir.Buffer | BufferLoad | PrimExpr):
     """Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer.
 
     Parameters
     ----------
-    mbar_ptr : PrimExpr
-        Pointer to the mbarrier object in shared memory (e.g., Barrier*).
+    mbar: tir.Buffer | BufferLoad | PrimExpr
+        The mbarrier object in shared memory (e.g., Barrier*) or its address.
     """
-    return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr)
+    if isinstance(mbar, (tir.Buffer, BufferLoad)):
+        mbar = retrieve_ptr(mbar, access_type="rw")
+    return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar)


 def ptx_mma_sm70(
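
Both accepted object forms appear in the examples above: the non-persistent kernel passes the whole barrier buffer (`T.tcgen05_mma_arrive(tmem_full)`), while the persistent kernel passes a single element (`T.tcgen05_mma_arrive(tmem_full[w & 1])`). An illustrative fragment of the two call forms (to be placed inside a `T.Kernel` body; not a standalone program):

```python
bar = T.alloc_barrier([1] * 2)
T.tcgen05_mma_arrive(bar[0])  # BufferLoad: one barrier element

single = T.alloc_barrier([1])
T.tcgen05_mma_arrive(single)  # tir.Buffer: whole single-element barrier
# A raw PrimExpr pointer is still passed through to the intrinsic unchanged.
```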