[Example][BugFix] 1SM GEMM example on Blackwell and fix handling of mbar#1774
[Example][BugFix] 1SM GEMM example on Blackwell and fix handling of mbar#1774Rachmanino wants to merge 12 commits intotile-ai:mainfrom
mbar#1774Conversation
📝 WalkthroughWalkthroughThis PR refactors the mbar (memory barrier) representation throughout the codebase, converting from a BufferRegion-based system to a direct BufferLoad approach. It adds a new utility function Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
There was a problem hiding this comment.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@src/op/gemm_py.cc`:
- Around line 83-86: When handling the optional mbar argument in the arg-parsing
block, don't silently ignore non-BufferLoad values: if args.size() > 16 then
attempt the BufferLoadNode cast as currently done (check for BufferLoadNode and
set node->mbar_ via Downcast<BufferLoad>), but add an else branch that fails
fast (throw or LOG(FATAL)/CHECK) reporting that arg 16 was expected to be a
BufferLoad and include the actual argument's type/name (use whatever runtime
type introspection is available on the Expr/Node to include in the message) so
callers get a clear error instead of silently skipping mbar_.
In `@src/op/gemm_py.h`:
- Line 32: mbar_ is declared as a non-optional tir::BufferLoad but is only
conditionally assigned (when args.size() > 16 and args[16] is a BufferLoadNode),
causing a type-contract mismatch; change the field declaration from
tir::BufferLoad mbar_ to std::optional<tir::BufferLoad> mbar_, update the
parser/initializer (where args is inspected) to emplace/assign mbar_ only in the
conditional branch, and adjust any uses of mbar_ (check has_value() or use
value_or) so code and the Python bindings safely handle the absent case;
alternatively, if you prefer non-optional, ensure mbar_ is unconditionally
initialized in the same constructor code path and remove the "optional" comment.
In `@src/op/utils.cc`:
- Around line 95-122: The function MakeAccessPtrFromBufferLoad uses hard-coded
DataType::Int(32) for offset, stride and extent which can overflow for large
buffers; change all occurrences of make_const(DataType::Int(32), ...) and the
IntImm for rw_mask to use the buffer's index dtype (buf->index_dtype) instead:
initialize offset and stride with make_const(buf->index_dtype, 0/1), compute
offset/stride arithmetic with that dtype, set extent using
make_const(buf->index_dtype, 1), and construct the rw_mask as
IntImm(buf->index_dtype, rw_mask) when building acc_args; update references
inside MakeAccessPtrFromBufferLoad (offset, stride, extent, acc_args)
accordingly.
🧹 Nitpick comments (1)
src/op/gemm.cc (1)
84-88: Consider hard-failing on unexpected mbar argument type.When
args.size() > 16, a non-BufferLoad value is silently dropped. That can mask frontend mismatches. Prefer anICHECKto fail early when the caller passes the wrong type.Suggested tightening
if (args.size() > 16) { - if (const auto *load = args[16].as<BufferLoadNode>()) { - node->mbar_ = Downcast<BufferLoad>(args[16]); - } else { - node->mbar_ = std::nullopt; - } + ICHECK(args[16].as<BufferLoadNode>()) + << "mbar must be provided as BufferLoad when present"; + node->mbar_ = Downcast<BufferLoad>(args[16]); }
| if (args.size() > 16) { | ||
| if (const auto *load = args[16].as<BufferLoadNode>()) { | ||
| node->mbarRegion_ = | ||
| NormalizeToBufferRegion(Downcast<BufferLoad>(args[16])); | ||
| node->mbar_ = node->mbarRegion_->buffer; | ||
| node->mbar_ = Downcast<BufferLoad>(args[16]); | ||
| } |
There was a problem hiding this comment.
Fail fast when mbar is present but not a BufferLoad.
Silently ignoring non-BufferLoad inputs can mask call-site errors and lead to a missing barrier later; an explicit check makes failures clearer.
💡 Suggested fix
- if (args.size() > 16) {
- if (const auto *load = args[16].as<BufferLoadNode>()) {
- node->mbar_ = Downcast<BufferLoad>(args[16]);
- }
- }
+ if (args.size() > 16) {
+ ICHECK(args[16].as<BufferLoadNode>())
+ << "mbar must be provided as a BufferLoad when present";
+ node->mbar_ = Downcast<BufferLoad>(args[16]);
+ }📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if (args.size() > 16) { | |
| if (const auto *load = args[16].as<BufferLoadNode>()) { | |
| node->mbarRegion_ = | |
| NormalizeToBufferRegion(Downcast<BufferLoad>(args[16])); | |
| node->mbar_ = node->mbarRegion_->buffer; | |
| node->mbar_ = Downcast<BufferLoad>(args[16]); | |
| } | |
| if (args.size() > 16) { | |
| ICHECK(args[16].as<BufferLoadNode>()) | |
| << "mbar must be provided as a BufferLoad when present"; | |
| node->mbar_ = Downcast<BufferLoad>(args[16]); | |
| } |
🤖 Prompt for AI Agents
In `@src/op/gemm_py.cc` around lines 83 - 86, When handling the optional mbar
argument in the arg-parsing block, don't silently ignore non-BufferLoad values:
if args.size() > 16 then attempt the BufferLoadNode cast as currently done
(check for BufferLoadNode and set node->mbar_ via Downcast<BufferLoad>), but add
an else branch that fails fast (throw or LOG(FATAL)/CHECK) reporting that arg 16
was expected to be a BufferLoad and include the actual argument's type/name (use
whatever runtime type introspection is available on the Expr/Node to include in
the message) so callers get a clear error instead of silently skipping mbar_.
| PrimExpr clearAccum_ = const_false(); | ||
| BufferRegion mbarRegion_; | ||
| tir::Buffer mbar_; // mbar is optional, only used for TCGEN5MMA | ||
| tir::BufferLoad mbar_; // mbar is optional, only used for TCGEN5MMA |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify mbar_ usage for null-safety and initialization paths
rg -n -C3 'mbar_' src/op/gemm_py.cc src/op/gemm_py.hRepository: tile-ai/tilelang
Length of output: 1338
🏁 Script executed:
# Also check the GemmNode definition to see the original type
rg -n -C3 'struct GemmNode\|class GemmNode' src/op/gemm_py.h src/op/gemm_py.ccRepository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Look for mbar_ initialization in constructors or Create methods
rg -n 'mbar_\s*=' src/op/gemm_py.cc src/op/gemm_py.hRepository: tile-ai/tilelang
Length of output: 132
🏁 Script executed:
# Check if .defined() or other null checks are used with mbar_
rg -n 'mbar_.*defined|if.*mbar_' src/op/gemm_py.cc src/op/gemm_py.hRepository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Search for how mbar_ is used/read (not just assigned)
rg -n 'mbar_' src/op/gemm_py.cc src/op/gemm_py.h -A 5 -B 2Repository: tile-ai/tilelang
Length of output: 1511
🏁 Script executed:
# Look for uses of mbar in generated IR or kernel code
rg -n 'mbar' src/op/gemm_py.cc src/op/gemm_py.hRepository: tile-ai/tilelang
Length of output: 286
🏁 Script executed:
# Check if there's a TVM BufferLoad definition to understand default initialization
fd -t f '\.h$' -x grep -l 'class BufferLoad\|struct BufferLoad' {} \; | head -5Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Search for mbar usage in gemm_py.cc to find where it's read/accessed
rg -n 'mbar' src/op/gemm_py.cc | head -20Repository: tile-ai/tilelang
Length of output: 114
🏁 Script executed:
# Search for GemmPyNode usage in other files to see how mbar is accessed
rg -n 'GemmPyNode\|mbar' src/ --include='*.cc' --include='*.h' | grep -v 'gemm_py' | head -20Repository: tile-ai/tilelang
Length of output: 501
🏁 Script executed:
# Check the entire gemm_py.cc to see function context where mbar is used
wc -l src/op/gemm_py.ccRepository: tile-ai/tilelang
Length of output: 81
🏁 Script executed:
cat -n src/op/gemm_py.ccRepository: tile-ai/tilelang
Length of output: 15164
Resolve "optional" mismatch for mbar_ (type-contract inconsistency).
mbar_ is declared as non-optional tir::BufferLoad, but the comment marks it as optional and initialization is conditional (line 85: only assigned if args.size() > 16 and args[16] is a BufferLoadNode). This creates a type-contract mismatch: the field can remain uninitialized while the type signature suggests it's always present.
Suggestion: Make this std::optional<tir::BufferLoad> to match the semantic intent, or guarantee initialization and remove the optional comment. Verify Python-side code handles undefined mbar_ safely.
🤖 Prompt for AI Agents
In `@src/op/gemm_py.h` at line 32, mbar_ is declared as a non-optional
tir::BufferLoad but is only conditionally assigned (when args.size() > 16 and
args[16] is a BufferLoadNode), causing a type-contract mismatch; change the
field declaration from tir::BufferLoad mbar_ to std::optional<tir::BufferLoad>
mbar_, update the parser/initializer (where args is inspected) to emplace/assign
mbar_ only in the conditional branch, and adjust any uses of mbar_ (check
has_value() or use value_or) so code and the Python bindings safely handle the
absent case; alternatively, if you prefer non-optional, ensure mbar_ is
unconditionally initialized in the same constructor code path and remove the
"optional" comment.
| PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load, int rw_mask) { | ||
| Buffer buf = load->buffer; | ||
| int ndim = static_cast<int>(buf->shape.size()); | ||
|
|
||
| // Compute offset using row-major layout (iterate in reverse) | ||
| PrimExpr offset = make_const(DataType::Int(32), 0); | ||
| PrimExpr stride = make_const(DataType::Int(32), 1); | ||
|
|
||
| for (int i = ndim - 1; i >= 0; --i) { | ||
| const PrimExpr &index = load->indices[i]; | ||
| if (const auto *ramp = index.as<RampNode>()) { | ||
| // For Ramp, use the base | ||
| offset = offset + ramp->base * stride; | ||
| } else { | ||
| // For scalar index (IntImm or other PrimExpr) | ||
| offset = offset + index * stride; | ||
| } | ||
| stride = stride * buf->shape[i]; | ||
| } | ||
|
|
||
| // Extent is 1 element for a single BufferLoad access | ||
| PrimExpr extent = make_const(DataType::Int(32), 1); | ||
|
|
||
| // Build access_ptr | ||
| PrimExpr ptype = tir::TypeAnnotation(buf->dtype); | ||
| Array<PrimExpr> acc_args{ptype, buf->data, offset, extent, | ||
| IntImm(DataType::Int(32), rw_mask)}; | ||
| return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); |
There was a problem hiding this comment.
Avoid int32 offset/stride to prevent overflow on large buffers.
DataType::Int(32) can overflow for large shapes and diverges from MakeAccessPtrFromRegion. Use the buffer index dtype for offset/stride/extent.
💡 Suggested fix
- // Compute offset using row-major layout (iterate in reverse)
- PrimExpr offset = make_const(DataType::Int(32), 0);
- PrimExpr stride = make_const(DataType::Int(32), 1);
+ // Compute offset using row-major layout (iterate in reverse)
+ DataType idx_dtype = buf->shape[0].dtype();
+ PrimExpr offset = make_const(idx_dtype, 0);
+ PrimExpr stride = make_const(idx_dtype, 1);
@@
- // Extent is 1 element for a single BufferLoad access
- PrimExpr extent = make_const(DataType::Int(32), 1);
+ // Extent is 1 element for a single BufferLoad access
+ PrimExpr extent = make_const(idx_dtype, 1);📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load, int rw_mask) { | |
| Buffer buf = load->buffer; | |
| int ndim = static_cast<int>(buf->shape.size()); | |
| // Compute offset using row-major layout (iterate in reverse) | |
| PrimExpr offset = make_const(DataType::Int(32), 0); | |
| PrimExpr stride = make_const(DataType::Int(32), 1); | |
| for (int i = ndim - 1; i >= 0; --i) { | |
| const PrimExpr &index = load->indices[i]; | |
| if (const auto *ramp = index.as<RampNode>()) { | |
| // For Ramp, use the base | |
| offset = offset + ramp->base * stride; | |
| } else { | |
| // For scalar index (IntImm or other PrimExpr) | |
| offset = offset + index * stride; | |
| } | |
| stride = stride * buf->shape[i]; | |
| } | |
| // Extent is 1 element for a single BufferLoad access | |
| PrimExpr extent = make_const(DataType::Int(32), 1); | |
| // Build access_ptr | |
| PrimExpr ptype = tir::TypeAnnotation(buf->dtype); | |
| Array<PrimExpr> acc_args{ptype, buf->data, offset, extent, | |
| IntImm(DataType::Int(32), rw_mask)}; | |
| return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); | |
| PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load, int rw_mask) { | |
| Buffer buf = load->buffer; | |
| int ndim = static_cast<int>(buf->shape.size()); | |
| // Compute offset using row-major layout (iterate in reverse) | |
| DataType idx_dtype = buf->shape[0].dtype(); | |
| PrimExpr offset = make_const(idx_dtype, 0); | |
| PrimExpr stride = make_const(idx_dtype, 1); | |
| for (int i = ndim - 1; i >= 0; --i) { | |
| const PrimExpr &index = load->indices[i]; | |
| if (const auto *ramp = index.as<RampNode>()) { | |
| // For Ramp, use the base | |
| offset = offset + ramp->base * stride; | |
| } else { | |
| // For scalar index (IntImm or other PrimExpr) | |
| offset = offset + index * stride; | |
| } | |
| stride = stride * buf->shape[i]; | |
| } | |
| // Extent is 1 element for a single BufferLoad access | |
| PrimExpr extent = make_const(idx_dtype, 1); | |
| // Build access_ptr | |
| PrimExpr ptype = tir::TypeAnnotation(buf->dtype); | |
| Array<PrimExpr> acc_args{ptype, buf->data, offset, extent, | |
| IntImm(DataType::Int(32), rw_mask)}; | |
| return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); | |
| } |
🤖 Prompt for AI Agents
In `@src/op/utils.cc` around lines 95 - 122, The function
MakeAccessPtrFromBufferLoad uses hard-coded DataType::Int(32) for offset, stride
and extent which can overflow for large buffers; change all occurrences of
make_const(DataType::Int(32), ...) and the IntImm for rw_mask to use the
buffer's index dtype (buf->index_dtype) instead: initialize offset and stride
with make_const(buf->index_dtype, 0/1), compute offset/stride arithmetic with
that dtype, set extent using make_const(buf->index_dtype, 1), and construct the
rw_mask as IntImm(buf->index_dtype, rw_mask) when building acc_args; update
references inside MakeAccessPtrFromBufferLoad (offset, stride, extent, acc_args)
accordingly.
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@examples/gemm_sm100/gemm_tcgen5mma_ws.py`:
- Around line 23-47: Add an explicit precondition that M, N, and K are exact
multiples of block_M, block_N, and block_K respectively to prevent out-of-bounds
writes when the kernel (which computes k_iters via T.ceildiv and writes C_shared
-> C at by*block_M, bx*block_N) runs; insert an assertion near the top of the
function before the T.Kernel block (where k_iters, A/B/C and shared buffers are
set up) that checks M % block_M == 0, N % block_N == 0, and K % block_K == 0 and
fail early if not, mirroring the guards used in other GEMM kernels.
In `@tilelang/language/builtin.py`:
- Around line 812-814: The code references mbar_ptr in the tir.call_intrin call
but only sets mbar_ptr inside the isinstance(mbar, (tir.Buffer, BufferLoad))
branch, causing UnboundLocalError when a raw PrimExpr is passed; fix by ensuring
mbar_ptr is always defined: after the existing if-block set mbar_ptr = mbar for
the PrimExpr case (or otherwise convert the PrimExpr to the expected pointer
form) so that mbar_ptr is available before calling tir.call_intrin("void",
tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr); modify the logic around the
mbar/mbar_ptr handling in the function in tilelang/language/builtin.py to cover
both tir.Buffer/BufferLoad and PrimExpr inputs.
| k_iters = T.ceildiv(K, block_K) | ||
|
|
||
| A: T.Tensor[[M, K], in_dtype] | ||
| B: T.Tensor[[K, N], in_dtype] | ||
| C = T.empty((M, N), out_dtype) | ||
|
|
||
| with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): | ||
| A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) | ||
| B_shared = T.alloc_shared((num_stages, block_K, block_N), in_dtype) | ||
| C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) | ||
| C_local = T.alloc_fragment((block_M, block_N), accum_dtype) | ||
| C_shared = T.alloc_shared((block_M, block_N), out_dtype) | ||
| loaded = T.alloc_barrier([32] * num_stages) | ||
| consumed = T.alloc_barrier([1] * num_stages) | ||
| tmem_full = T.alloc_barrier([1]) | ||
|
|
||
| tx = T.get_thread_binding() | ||
|
|
||
| T.use_swizzle(8) | ||
|
|
||
| if tx < 32: # warp 0: issue tma | ||
| for k in T.serial(k_iters): | ||
| T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1) | ||
| T.copy(A[by * block_M : (by + 1) * block_M, k * block_K : (k + 1) * block_K], A_shared[k % num_stages, :, :]) | ||
| T.copy(B[k * block_K : (k + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[k % num_stages, :, :]) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# Find the file mentioned in the review
fd "gemm_tcgen5mma_ws.py" --type fRepository: tile-ai/tilelang
Length of output: 100
🏁 Script executed:
# List the repository structure to understand the codebase layout
ls -la examples/ 2>/dev/null | head -20Repository: tile-ai/tilelang
Length of output: 1244
🏁 Script executed:
# Search for other GEMM examples to understand the pattern
fd "gemm.*\.py" examples/ --type f | head -20Repository: tile-ai/tilelang
Length of output: 1130
🏁 Script executed:
# Examine the file mentioned in the review
cat -n examples/gemm_sm100/gemm_tcgen5mma_ws.pyRepository: tile-ai/tilelang
Length of output: 4292
🏁 Script executed:
# Also search for other gemm_sm100 files to get context
ls -la examples/gemm_sm100/Repository: tile-ai/tilelang
Length of output: 434
🏁 Script executed:
# Search for other GEMM examples to understand how they handle non-divisible dimensions
grep -r "T.copy.*\[.*\*.*block_" examples/gemm* --include="*.py" | head -20Repository: tile-ai/tilelang
Length of output: 2122
🏁 Script executed:
# Look at other SM100 GEMM variants to see if they have guards
cat -n examples/gemm_sm100/gemm_tcgen5mma.pyRepository: tile-ai/tilelang
Length of output: 3381
🏁 Script executed:
# Search for any guard patterns (assertions, checks) in GEMM examples
rg "assert.*%|assert.*divisible|if.*%" examples/gemm* -A 2 -B 2 | head -40Repository: tile-ai/tilelang
Length of output: 3576
Add explicit precondition to guard against non‑divisible M/N/K dimensions.
The kernel uses T.ceildiv for grid dimensions, and line 68 copies C_shared (shape [block_M, block_N]) to C[by * block_M, bx * block_N] without bounds checking. If M/N/K are not exact multiples of their block sizes, the last threadblocks will write out of bounds. The test case passes only because M=N=K=16384 are exact multiples of their block sizes. Add an assertion to enforce this precondition, as done in other GEMM kernels in the codebase.
🛡️ Suggested guard (adjust to TileLang-supported assertion APIs)
def gemm(
A,
B,
@@
):
M, N, K = T.const("M, N, K")
+ # Kernel requires exact tile coverage to avoid OOB access at line 68.
+ assert (
+ M % block_M == 0
+ and N % block_N == 0
+ and K % block_K == 0
+ ), "M/N/K must be divisible by block_M/block_N/block_K"🤖 Prompt for AI Agents
In `@examples/gemm_sm100/gemm_tcgen5mma_ws.py` around lines 23 - 47, Add an
explicit precondition that M, N, and K are exact multiples of block_M, block_N,
and block_K respectively to prevent out-of-bounds writes when the kernel (which
computes k_iters via T.ceildiv and writes C_shared -> C at by*block_M,
bx*block_N) runs; insert an assertion near the top of the function before the
T.Kernel block (where k_iters, A/B/C and shared buffers are set up) that checks
M % block_M == 0, N % block_N == 0, and K % block_K == 0 and fail early if not,
mirroring the guards used in other GEMM kernels.
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py`:
- Around line 31-33: Add explicit tiling precondition checks to prevent
out-of-bounds tile loads/stores: validate at runtime that M % block_M == 0 and N
% block_N == 0 to ensure full tiles for the slices used when loading/storing
(referencing m_blocks and n_blocks), and also assert that block_N %
store_block_N == 0 (to prevent OOB access into C_local when iterating
store_block_N across a block_N tile). Implement these as early
assertions/argument checks (or raise clear exceptions) before any tensor slicing
or the compute loop so callers supplying M, N, block_M, block_N, and
store_block_N cannot trigger OOB accesses.
| m_blocks = T.ceildiv(M, block_M) | ||
| n_blocks = T.ceildiv(N, block_N) | ||
| assert K % (2 * block_K) == 0 # for simplicity |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
cd /tmp/repo && find . -name "gemm_tcgen5mma_ws_persistent.py" -type fRepository: tile-ai/tilelang
Length of output: 119
🏁 Script executed:
head -150 examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py | cat -nRepository: tile-ai/tilelang
Length of output: 7610
Add tiling precondition checks to prevent OOB loads/stores.
The kernel loads/stores full tiles without boundary masking. If M or N aren't multiples of block_M or block_N, the slices on lines 63 and 65 will read past array bounds. Similarly, on line 119, if store_block_N doesn't evenly divide block_N, the loop will access out-of-bounds memory in C_local. While the test case uses compatible parameters (8192 with block size 128/256, store_block_N 128), the kernel interface accepts these as runtime parameters, allowing unsafe calls.
💡 Suggested guardrails
- assert K % (2 * block_K) == 0 # for simplicity
+ assert K % (2 * block_K) == 0, "K must be divisible by 2 * block_K"
+ assert M % block_M == 0 and N % block_N == 0, "M/N must be multiples of block_M/block_N"
+ assert store_block_N <= block_N and block_N % store_block_N == 0, "store_block_N must evenly divide block_N"🤖 Prompt for AI Agents
In `@examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py` around lines 31 - 33,
Add explicit tiling precondition checks to prevent out-of-bounds tile
loads/stores: validate at runtime that M % block_M == 0 and N % block_N == 0 to
ensure full tiles for the slices used when loading/storing (referencing m_blocks
and n_blocks), and also assert that block_N % store_block_N == 0 (to prevent OOB
access into C_local when iterating store_block_N across a block_N tile).
Implement these as early assertions/argument checks (or raise clear exceptions)
before any tensor slicing or the compute loop so callers supplying M, N,
block_M, block_N, and store_block_N cannot trigger OOB accesses.
There was a problem hiding this comment.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/gemm.cc (1)
90-91:⚠️ Potential issue | 🟡 MinorPotential out-of-bounds access on
args[17]andargs[18].
cCoords_initialization accessesargs[17]andargs[18]unconditionally, but the preceding check only validatesargs.size() > 16. If the caller provides exactly 17 elements, this will cause an out-of-bounds access.Consider adding a bounds check or documenting the minimum required args size:
Suggested fix
+ ICHECK(args.size() >= 19) << "Expected at least 19 arguments for Gemm, got " << args.size(); node->cCoords_ = Array<PrimExpr>( {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
🧹 Nitpick comments (2)
src/op/gemm_py.cc (2)
83-86: Design note: Silent ignore is intentional for placeholder constants.The Python side (
gemm_op.py, lines 109-112) passestir.const(0, dtype="int32")as a placeholder whenmbarisNone. The current behavior of silently ignoring non-BufferLoadvalues is intentional to handle this placeholder case correctly.However, for debugging purposes, you might consider adding a debug log when a non-BufferLoad is encountered to help catch genuine call-site errors vs. expected placeholder values.
💡 Optional: Add debug logging
if (args.size() > 16) { if (const auto *load = args[16].as<BufferLoadNode>()) { node->mbar_ = Downcast<BufferLoad>(args[16]); + } else { + DLOG(INFO) << "mbar arg is not BufferLoad (placeholder or error): " + << args[16]->GetTypeKey(); } }
88-89: Consider adding bounds check forcCoords_access.Lines 88-89 access
args[17]andargs[18]unconditionally after only checkingargs.size() > 16. While the Python call site currently always passes 19 arguments, adding a bounds check would make this more robust against future API changes.💡 Suggested defensive check
+ ICHECK(args.size() >= 19) << "Expected at least 19 arguments for GemmPy, got " << args.size(); node->cCoords_ = Array<PrimExpr>( {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/language/gemm_op.py (1)
99-113:⚠️ Potential issue | 🔴 CriticalConvert
mbardirectly totir.BufferLoad;to_buffer_region()returnsBufferRegionwhich fails the C++BufferLoadNodecheck.When
mbaris converted viato_buffer_region(mbar, access_type="rw")without extents, it returns aBufferRegion. The C++ side at line 84 ofsrc/op/gemm.ccchecksif (const auto *load = args[16].as<BufferLoadNode>()). SinceBufferRegionis not aBufferLoadNode, this check fails andmbar_is set tostd::nullopt, disabling barrier synchronization. This later causes a crash when line 493 tries to callMakeAccessPtrFromBufferLoad(mbar_.value()).🔧 Suggested fix
if mbar is not None: assert isinstance(mbar, (tir.Buffer, tir.BufferLoad)), ( f"mbar for tcgen5mma must be a tir.Buffer or tir.BufferLoad, but got {type(mbar)}" ) - mbar = to_buffer_region(mbar, access_type="rw") + if isinstance(mbar, tir.Buffer): + mbar = tir.BufferLoad(mbar, [0])
🤖 Fix all issues with AI agents
In `@src/op/gemm.cc`:
- Around line 83-88: The code silently sets node->mbar_ to nullopt when args[16]
exists but is not a BufferLoadNode; instead, validate the type and fail fast:
when args.size() > 16 and args[16] is non-null, check that
args[16].as<BufferLoadNode>() is non-null and if it isn't, raise an explicit
error (e.g., TVM_PANIC/LOG(FATAL)/ICHECK/throw) describing the unexpected arg
type for mbar_; otherwise keep the existing Downcast<BufferLoad>(args[16])
assignment to node->mbar_.
| if (args.size() > 16) { | ||
| if (const auto *load = args[16].as<BufferLoadNode>()) { | ||
| node->mbarRegion_ = | ||
| NormalizeToBufferRegion(Downcast<BufferLoad>(args[16])); | ||
| node->mbar_ = node->mbarRegion_->buffer; | ||
| node->mbar_ = Downcast<BufferLoad>(args[16]); | ||
| } else { | ||
| node->mbar_ = std::nullopt; | ||
| } |
There was a problem hiding this comment.
Validate mbar arg type instead of silently dropping it.
If args[16] is non-null but not a BufferLoad, this currently becomes nullopt and can hide frontend regressions (e.g., still passing a region). Consider an explicit check to fail fast.
🔧 Suggested guard
if (args.size() > 16) {
- if (const auto *load = args[16].as<BufferLoadNode>()) {
- node->mbar_ = Downcast<BufferLoad>(args[16]);
- } else {
- node->mbar_ = std::nullopt;
- }
+ const auto *load = args[16].as<BufferLoadNode>();
+ ICHECK(load || !args[16].defined())
+ << "mbar must be a BufferLoad or null when provided";
+ node->mbar_ =
+ load ? std::optional<BufferLoad>(Downcast<BufferLoad>(args[16]))
+ : std::nullopt;
}📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if (args.size() > 16) { | |
| if (const auto *load = args[16].as<BufferLoadNode>()) { | |
| node->mbarRegion_ = | |
| NormalizeToBufferRegion(Downcast<BufferLoad>(args[16])); | |
| node->mbar_ = node->mbarRegion_->buffer; | |
| node->mbar_ = Downcast<BufferLoad>(args[16]); | |
| } else { | |
| node->mbar_ = std::nullopt; | |
| } | |
| if (args.size() > 16) { | |
| const auto *load = args[16].as<BufferLoadNode>(); | |
| ICHECK(load || !args[16].defined()) | |
| << "mbar must be a BufferLoad or null when provided"; | |
| node->mbar_ = | |
| load ? std::optional<BufferLoad>(Downcast<BufferLoad>(args[16])) | |
| : std::nullopt; | |
| } |
🤖 Prompt for AI Agents
In `@src/op/gemm.cc` around lines 83 - 88, The code silently sets node->mbar_ to
nullopt when args[16] exists but is not a BufferLoadNode; instead, validate the
type and fail fast: when args.size() > 16 and args[16] is non-null, check that
args[16].as<BufferLoadNode>() is non-null and if it isn't, raise an explicit
error (e.g., TVM_PANIC/LOG(FATAL)/ICHECK/throw) describing the unexpected arg
type for mbar_; otherwise keep the existing Downcast<BufferLoad>(args[16])
assignment to node->mbar_.
|
@LeiWang1999 fixed |
There was a problem hiding this comment.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/language/gemm_op.py (1)
99-112:⚠️ Potential issue | 🔴 Criticalmbar is incorrectly converted to BufferRegion but C++ expects BufferLoadNode.
Line 103 converts
mbarto atir.BufferRegionviato_buffer_region(mbar, access_type="rw"). However, line 112 passes thisBufferRegiondirectly asmbar_argtocall_intrin. The C++ side (gemm_py.cc lines 84–86) checksargs[16].as<BufferLoadNode>(), which will fail becauseBufferRegionis not aBufferLoadNode. This meansmbar_will remain uninitialized even when a valid barrier is provided.For consistency with
A_arg,B_arg, andC_arg(which usebuffer_region_to_tile_regionto wrap intl.regioncalls), either:
- Apply
buffer_region_to_tile_regiontombaras well, or- Pass the original
mbarwithout converting toBufferRegionon line 103, or- Extract the
BufferLoadfrom theBufferRegionbefore passing it
🧹 Nitpick comments (1)
examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py (1)
53-66: Tile-ID to (bx, by) mapping is duplicated three times.The
tile_id → (bx, by)computation (lines 55–57, 70–72, 104–106) is identical across all three warp branches. Consider extracting it into a helper or computing it once before the branch to reduce duplication and the risk of divergence if the mapping logic changes.
mbarasBufferLoadto avoid missing indexThanks @Hamerlate for providing the dev machine.
Summary by CodeRabbit
New Features
Refactor